bpo-38870: Refactor delimiting with context managers in ast.unparse (GH-17612)

Co-Authored-By: Victor Stinner <vstinner@python.org>
Co-authored-by: Pablo Galindo <pablogsal@gmail.com>
This commit is contained in:
Batuhan Taşkaya 2019-12-23 19:11:00 +03:00 committed by Pablo Galindo
parent 9f9dac0a4e
commit 4b3b1226e8

View file

@ -26,6 +26,7 @@
""" """
import sys import sys
from _ast import * from _ast import *
from contextlib import contextmanager, nullcontext
def parse(source, filename='<unknown>', mode='exec', *, def parse(source, filename='<unknown>', mode='exec', *,
@ -613,6 +614,21 @@ class _Unparser(NodeVisitor):
def block(self): def block(self):
return self._Block(self) return self._Block(self)
@contextmanager
def delimit(self, start, end):
"""A context manager for preparing the source for expressions. It adds
*start* to the buffer and enters, after exit it adds *end*."""
self.write(start)
yield
self.write(end)
def delimit_if(self, start, end, condition):
if condition:
return self.delimit(start, end)
else:
return nullcontext()
def traverse(self, node): def traverse(self, node):
if isinstance(node, list): if isinstance(node, list):
for item in node: for item in node:
@ -636,11 +652,10 @@ class _Unparser(NodeVisitor):
self.traverse(node.value) self.traverse(node.value)
def visit_NamedExpr(self, node): def visit_NamedExpr(self, node):
self.write("(") with self.delimit("(", ")"):
self.traverse(node.target) self.traverse(node.target)
self.write(" := ") self.write(" := ")
self.traverse(node.value) self.traverse(node.value)
self.write(")")
def visit_Import(self, node): def visit_Import(self, node):
self.fill("import ") self.fill("import ")
@ -669,11 +684,8 @@ class _Unparser(NodeVisitor):
def visit_AnnAssign(self, node): def visit_AnnAssign(self, node):
self.fill() self.fill()
if not node.simple and isinstance(node.target, Name): with self.delimit_if("(", ")", not node.simple and isinstance(node.target, Name)):
self.write("(") self.traverse(node.target)
self.traverse(node.target)
if not node.simple and isinstance(node.target, Name):
self.write(")")
self.write(": ") self.write(": ")
self.traverse(node.annotation) self.traverse(node.annotation)
if node.value: if node.value:
@ -715,28 +727,25 @@ class _Unparser(NodeVisitor):
self.interleave(lambda: self.write(", "), self.write, node.names) self.interleave(lambda: self.write(", "), self.write, node.names)
def visit_Await(self, node): def visit_Await(self, node):
self.write("(") with self.delimit("(", ")"):
self.write("await") self.write("await")
if node.value: if node.value:
self.write(" ") self.write(" ")
self.traverse(node.value) self.traverse(node.value)
self.write(")")
def visit_Yield(self, node): def visit_Yield(self, node):
self.write("(") with self.delimit("(", ")"):
self.write("yield") self.write("yield")
if node.value: if node.value:
self.write(" ") self.write(" ")
self.traverse(node.value) self.traverse(node.value)
self.write(")")
def visit_YieldFrom(self, node): def visit_YieldFrom(self, node):
self.write("(") with self.delimit("(", ")"):
self.write("yield from") self.write("yield from")
if node.value: if node.value:
self.write(" ") self.write(" ")
self.traverse(node.value) self.traverse(node.value)
self.write(")")
def visit_Raise(self, node): def visit_Raise(self, node):
self.fill("raise") self.fill("raise")
@ -782,21 +791,20 @@ class _Unparser(NodeVisitor):
self.fill("@") self.fill("@")
self.traverse(deco) self.traverse(deco)
self.fill("class " + node.name) self.fill("class " + node.name)
self.write("(") with self.delimit("(", ")"):
comma = False comma = False
for e in node.bases: for e in node.bases:
if comma: if comma:
self.write(", ") self.write(", ")
else: else:
comma = True comma = True
self.traverse(e) self.traverse(e)
for e in node.keywords: for e in node.keywords:
if comma: if comma:
self.write(", ") self.write(", ")
else: else:
comma = True comma = True
self.traverse(e) self.traverse(e)
self.write(")")
with self.block(): with self.block():
self.traverse(node.body) self.traverse(node.body)
@ -812,10 +820,10 @@ class _Unparser(NodeVisitor):
for deco in node.decorator_list: for deco in node.decorator_list:
self.fill("@") self.fill("@")
self.traverse(deco) self.traverse(deco)
def_str = fill_suffix + " " + node.name + "(" def_str = fill_suffix + " " + node.name
self.fill(def_str) self.fill(def_str)
self.traverse(node.args) with self.delimit("(", ")"):
self.write(")") self.traverse(node.args)
if node.returns: if node.returns:
self.write(" -> ") self.write(" -> ")
self.traverse(node.returns) self.traverse(node.returns)
@ -931,13 +939,12 @@ class _Unparser(NodeVisitor):
def visit_Constant(self, node): def visit_Constant(self, node):
value = node.value value = node.value
if isinstance(value, tuple): if isinstance(value, tuple):
self.write("(") with self.delimit("(", ")"):
if len(value) == 1: if len(value) == 1:
self._write_constant(value[0]) self._write_constant(value[0])
self.write(",") self.write(",")
else: else:
self.interleave(lambda: self.write(", "), self._write_constant, value) self.interleave(lambda: self.write(", "), self._write_constant, value)
self.write(")")
elif value is ...: elif value is ...:
self.write("...") self.write("...")
else: else:
@ -946,39 +953,34 @@ class _Unparser(NodeVisitor):
self._write_constant(node.value) self._write_constant(node.value)
def visit_List(self, node): def visit_List(self, node):
self.write("[") with self.delimit("[", "]"):
self.interleave(lambda: self.write(", "), self.traverse, node.elts) self.interleave(lambda: self.write(", "), self.traverse, node.elts)
self.write("]")
def visit_ListComp(self, node): def visit_ListComp(self, node):
self.write("[") with self.delimit("[", "]"):
self.traverse(node.elt) self.traverse(node.elt)
for gen in node.generators: for gen in node.generators:
self.traverse(gen) self.traverse(gen)
self.write("]")
def visit_GeneratorExp(self, node): def visit_GeneratorExp(self, node):
self.write("(") with self.delimit("(", ")"):
self.traverse(node.elt) self.traverse(node.elt)
for gen in node.generators: for gen in node.generators:
self.traverse(gen) self.traverse(gen)
self.write(")")
def visit_SetComp(self, node): def visit_SetComp(self, node):
self.write("{") with self.delimit("{", "}"):
self.traverse(node.elt) self.traverse(node.elt)
for gen in node.generators: for gen in node.generators:
self.traverse(gen) self.traverse(gen)
self.write("}")
def visit_DictComp(self, node): def visit_DictComp(self, node):
self.write("{") with self.delimit("{", "}"):
self.traverse(node.key) self.traverse(node.key)
self.write(": ") self.write(": ")
self.traverse(node.value) self.traverse(node.value)
for gen in node.generators: for gen in node.generators:
self.traverse(gen) self.traverse(gen)
self.write("}")
def visit_comprehension(self, node): def visit_comprehension(self, node):
if node.is_async: if node.is_async:
@ -993,24 +995,20 @@ class _Unparser(NodeVisitor):
self.traverse(if_clause) self.traverse(if_clause)
def visit_IfExp(self, node): def visit_IfExp(self, node):
self.write("(") with self.delimit("(", ")"):
self.traverse(node.body) self.traverse(node.body)
self.write(" if ") self.write(" if ")
self.traverse(node.test) self.traverse(node.test)
self.write(" else ") self.write(" else ")
self.traverse(node.orelse) self.traverse(node.orelse)
self.write(")")
def visit_Set(self, node): def visit_Set(self, node):
if not node.elts: if not node.elts:
raise ValueError("Set node should has at least one item") raise ValueError("Set node should has at least one item")
self.write("{") with self.delimit("{", "}"):
self.interleave(lambda: self.write(", "), self.traverse, node.elts) self.interleave(lambda: self.write(", "), self.traverse, node.elts)
self.write("}")
def visit_Dict(self, node): def visit_Dict(self, node):
self.write("{")
def write_key_value_pair(k, v): def write_key_value_pair(k, v):
self.traverse(k) self.traverse(k)
self.write(": ") self.write(": ")
@ -1026,29 +1024,27 @@ class _Unparser(NodeVisitor):
else: else:
write_key_value_pair(k, v) write_key_value_pair(k, v)
self.interleave( with self.delimit("{", "}"):
lambda: self.write(", "), write_item, zip(node.keys, node.values) self.interleave(
) lambda: self.write(", "), write_item, zip(node.keys, node.values)
self.write("}") )
def visit_Tuple(self, node): def visit_Tuple(self, node):
self.write("(") with self.delimit("(", ")"):
if len(node.elts) == 1: if len(node.elts) == 1:
elt = node.elts[0] elt = node.elts[0]
self.traverse(elt) self.traverse(elt)
self.write(",") self.write(",")
else: else:
self.interleave(lambda: self.write(", "), self.traverse, node.elts) self.interleave(lambda: self.write(", "), self.traverse, node.elts)
self.write(")")
unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
def visit_UnaryOp(self, node): def visit_UnaryOp(self, node):
self.write("(") with self.delimit("(", ")"):
self.write(self.unop[node.op.__class__.__name__]) self.write(self.unop[node.op.__class__.__name__])
self.write(" ") self.write(" ")
self.traverse(node.operand) self.traverse(node.operand)
self.write(")")
binop = { binop = {
"Add": "+", "Add": "+",
@ -1067,11 +1063,10 @@ class _Unparser(NodeVisitor):
} }
def visit_BinOp(self, node): def visit_BinOp(self, node):
self.write("(") with self.delimit("(", ")"):
self.traverse(node.left) self.traverse(node.left)
self.write(" " + self.binop[node.op.__class__.__name__] + " ") self.write(" " + self.binop[node.op.__class__.__name__] + " ")
self.traverse(node.right) self.traverse(node.right)
self.write(")")
cmpops = { cmpops = {
"Eq": "==", "Eq": "==",
@ -1087,20 +1082,18 @@ class _Unparser(NodeVisitor):
} }
def visit_Compare(self, node): def visit_Compare(self, node):
self.write("(") with self.delimit("(", ")"):
self.traverse(node.left) self.traverse(node.left)
for o, e in zip(node.ops, node.comparators): for o, e in zip(node.ops, node.comparators):
self.write(" " + self.cmpops[o.__class__.__name__] + " ") self.write(" " + self.cmpops[o.__class__.__name__] + " ")
self.traverse(e) self.traverse(e)
self.write(")")
boolops = {And: "and", Or: "or"} boolops = {"And": "and", "Or": "or"}
def visit_BoolOp(self, node): def visit_BoolOp(self, node):
self.write("(") with self.delimit("(", ")"):
s = " %s " % self.boolops[node.op.__class__] s = " %s " % self.boolops[node.op.__class__.__name__]
self.interleave(lambda: self.write(s), self.traverse, node.values) self.interleave(lambda: self.write(s), self.traverse, node.values)
self.write(")")
def visit_Attribute(self, node): def visit_Attribute(self, node):
self.traverse(node.value) self.traverse(node.value)
@ -1114,27 +1107,25 @@ class _Unparser(NodeVisitor):
def visit_Call(self, node): def visit_Call(self, node):
self.traverse(node.func) self.traverse(node.func)
self.write("(") with self.delimit("(", ")"):
comma = False comma = False
for e in node.args: for e in node.args:
if comma: if comma:
self.write(", ") self.write(", ")
else: else:
comma = True comma = True
self.traverse(e) self.traverse(e)
for e in node.keywords: for e in node.keywords:
if comma: if comma:
self.write(", ") self.write(", ")
else: else:
comma = True comma = True
self.traverse(e) self.traverse(e)
self.write(")")
def visit_Subscript(self, node): def visit_Subscript(self, node):
self.traverse(node.value) self.traverse(node.value)
self.write("[") with self.delimit("[", "]"):
self.traverse(node.slice) self.traverse(node.slice)
self.write("]")
def visit_Starred(self, node): def visit_Starred(self, node):
self.write("*") self.write("*")
@ -1225,12 +1216,11 @@ class _Unparser(NodeVisitor):
self.traverse(node.value) self.traverse(node.value)
def visit_Lambda(self, node): def visit_Lambda(self, node):
self.write("(") with self.delimit("(", ")"):
self.write("lambda ") self.write("lambda ")
self.traverse(node.args) self.traverse(node.args)
self.write(": ") self.write(": ")
self.traverse(node.body) self.traverse(node.body)
self.write(")")
def visit_alias(self, node): def visit_alias(self, node):
self.write(node.name) self.write(node.name)