bpo-38870: Implement round tripping support for typed AST in ast.unparse (GH-17797)

This commit is contained in:
Batuhan Taskaya 2020-05-17 02:04:12 +03:00 committed by GitHub
parent e966af7cff
commit dff92bb31f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 8 deletions

View file

@ -648,6 +648,7 @@ class _Unparser(NodeVisitor):
self._source = [] self._source = []
self._buffer = [] self._buffer = []
self._precedences = {} self._precedences = {}
self._type_ignores = {}
self._indent = 0 self._indent = 0
def interleave(self, inter, f, seq): def interleave(self, inter, f, seq):
@ -697,11 +698,15 @@ class _Unparser(NodeVisitor):
return value return value
@contextmanager @contextmanager
def block(self): def block(self, *, extra = None):
"""A context manager for preparing the source for blocks. It adds """A context manager for preparing the source for blocks. It adds
the character':', increases the indentation on enter and decreases the character':', increases the indentation on enter and decreases
the indentation on exit.""" the indentation on exit. If *extra* is given, it will be directly
appended after the colon character.
"""
self.write(":") self.write(":")
if extra:
self.write(extra)
self._indent += 1 self._indent += 1
yield yield
self._indent -= 1 self._indent -= 1
@ -748,6 +753,11 @@ class _Unparser(NodeVisitor):
if isinstance(node, Constant) and isinstance(node.value, str): if isinstance(node, Constant) and isinstance(node.value, str):
return node return node
def get_type_comment(self, node):
comment = self._type_ignores.get(node.lineno) or node.type_comment
if comment is not None:
return f" # type: {comment}"
def traverse(self, node): def traverse(self, node):
if isinstance(node, list): if isinstance(node, list):
for item in node: for item in node:
@ -770,7 +780,12 @@ class _Unparser(NodeVisitor):
self.traverse(node.body) self.traverse(node.body)
def visit_Module(self, node): def visit_Module(self, node):
self._type_ignores = {
ignore.lineno: f"ignore{ignore.tag}"
for ignore in node.type_ignores
}
self._write_docstring_and_traverse_body(node) self._write_docstring_and_traverse_body(node)
self._type_ignores.clear()
def visit_FunctionType(self, node): def visit_FunctionType(self, node):
with self.delimit("(", ")"): with self.delimit("(", ")"):
@ -811,6 +826,8 @@ class _Unparser(NodeVisitor):
self.traverse(target) self.traverse(target)
self.write(" = ") self.write(" = ")
self.traverse(node.value) self.traverse(node.value)
if type_comment := self.get_type_comment(node):
self.write(type_comment)
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
self.fill() self.fill()
@ -966,7 +983,7 @@ class _Unparser(NodeVisitor):
if node.returns: if node.returns:
self.write(" -> ") self.write(" -> ")
self.traverse(node.returns) self.traverse(node.returns)
with self.block(): with self.block(extra=self.get_type_comment(node)):
self._write_docstring_and_traverse_body(node) self._write_docstring_and_traverse_body(node)
def visit_For(self, node): def visit_For(self, node):
@ -980,7 +997,7 @@ class _Unparser(NodeVisitor):
self.traverse(node.target) self.traverse(node.target)
self.write(" in ") self.write(" in ")
self.traverse(node.iter) self.traverse(node.iter)
with self.block(): with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body) self.traverse(node.body)
if node.orelse: if node.orelse:
self.fill("else") self.fill("else")
@ -1018,13 +1035,13 @@ class _Unparser(NodeVisitor):
def visit_With(self, node): def visit_With(self, node):
self.fill("with ") self.fill("with ")
self.interleave(lambda: self.write(", "), self.traverse, node.items) self.interleave(lambda: self.write(", "), self.traverse, node.items)
with self.block(): with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body) self.traverse(node.body)
def visit_AsyncWith(self, node): def visit_AsyncWith(self, node):
self.fill("async with ") self.fill("async with ")
self.interleave(lambda: self.write(", "), self.traverse, node.items) self.interleave(lambda: self.write(", "), self.traverse, node.items)
with self.block(): with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body) self.traverse(node.body)
def visit_JoinedStr(self, node): def visit_JoinedStr(self, node):

View file

@ -108,12 +108,12 @@ with f() as x, g() as y:
suite1 suite1
""" """
docstring_prefixes = [ docstring_prefixes = (
"", "",
"class foo:\n ", "class foo:\n ",
"def foo():\n ", "def foo():\n ",
"async def foo():\n ", "async def foo():\n ",
] )
class ASTTestCase(unittest.TestCase): class ASTTestCase(unittest.TestCase):
def assertASTEqual(self, ast1, ast2): def assertASTEqual(self, ast1, ast2):
@ -340,6 +340,37 @@ class UnparseTestCase(ASTTestCase):
): ):
self.check_ast_roundtrip(function_type, mode="func_type") self.check_ast_roundtrip(function_type, mode="func_type")
def test_type_comments(self):
for statement in (
"a = 5 # type:",
"a = 5 # type: int",
"a = 5 # type: int and more",
"def x(): # type: () -> None\n\tpass",
"def x(y): # type: (int) -> None and more\n\tpass",
"async def x(): # type: () -> None\n\tpass",
"async def x(y): # type: (int) -> None and more\n\tpass",
"for x in y: # type: int\n\tpass",
"async for x in y: # type: int\n\tpass",
"with x(): # type: int\n\tpass",
"async with x(): # type: int\n\tpass"
):
self.check_ast_roundtrip(statement, type_comments=True)
def test_type_ignore(self):
for statement in (
"a = 5 # type: ignore",
"a = 5 # type: ignore and more",
"def x(): # type: ignore\n\tpass",
"def x(y): # type: ignore and more\n\tpass",
"async def x(): # type: ignore\n\tpass",
"async def x(y): # type: ignore and more\n\tpass",
"for x in y: # type: ignore\n\tpass",
"async for x in y: # type: ignore\n\tpass",
"with x(): # type: ignore\n\tpass",
"async with x(): # type: ignore\n\tpass"
):
self.check_ast_roundtrip(statement, type_comments=True)
class CosmeticTestCase(ASTTestCase): class CosmeticTestCase(ASTTestCase):
"""Test if there are cosmetic issues caused by unnecesary additions""" """Test if there are cosmetic issues caused by unnecesary additions"""