mirror of
https://github.com/python/cpython.git
synced 2025-08-04 17:08:35 +00:00
bpo-38870: Implement round tripping support for typed AST in ast.unparse (GH-17797)
This commit is contained in:
parent
e966af7cff
commit
dff92bb31f
2 changed files with 56 additions and 8 deletions
29
Lib/ast.py
29
Lib/ast.py
|
@ -648,6 +648,7 @@ class _Unparser(NodeVisitor):
|
|||
self._source = []
|
||||
self._buffer = []
|
||||
self._precedences = {}
|
||||
self._type_ignores = {}
|
||||
self._indent = 0
|
||||
|
||||
def interleave(self, inter, f, seq):
|
||||
|
@ -697,11 +698,15 @@ class _Unparser(NodeVisitor):
|
|||
return value
|
||||
|
||||
@contextmanager
|
||||
def block(self):
|
||||
def block(self, *, extra = None):
|
||||
"""A context manager for preparing the source for blocks. It adds
|
||||
the character':', increases the indentation on enter and decreases
|
||||
the indentation on exit."""
|
||||
the indentation on exit. If *extra* is given, it will be directly
|
||||
appended after the colon character.
|
||||
"""
|
||||
self.write(":")
|
||||
if extra:
|
||||
self.write(extra)
|
||||
self._indent += 1
|
||||
yield
|
||||
self._indent -= 1
|
||||
|
@ -748,6 +753,11 @@ class _Unparser(NodeVisitor):
|
|||
if isinstance(node, Constant) and isinstance(node.value, str):
|
||||
return node
|
||||
|
||||
def get_type_comment(self, node):
|
||||
comment = self._type_ignores.get(node.lineno) or node.type_comment
|
||||
if comment is not None:
|
||||
return f" # type: {comment}"
|
||||
|
||||
def traverse(self, node):
|
||||
if isinstance(node, list):
|
||||
for item in node:
|
||||
|
@ -770,7 +780,12 @@ class _Unparser(NodeVisitor):
|
|||
self.traverse(node.body)
|
||||
|
||||
def visit_Module(self, node):
|
||||
self._type_ignores = {
|
||||
ignore.lineno: f"ignore{ignore.tag}"
|
||||
for ignore in node.type_ignores
|
||||
}
|
||||
self._write_docstring_and_traverse_body(node)
|
||||
self._type_ignores.clear()
|
||||
|
||||
def visit_FunctionType(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
|
@ -811,6 +826,8 @@ class _Unparser(NodeVisitor):
|
|||
self.traverse(target)
|
||||
self.write(" = ")
|
||||
self.traverse(node.value)
|
||||
if type_comment := self.get_type_comment(node):
|
||||
self.write(type_comment)
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
self.fill()
|
||||
|
@ -966,7 +983,7 @@ class _Unparser(NodeVisitor):
|
|||
if node.returns:
|
||||
self.write(" -> ")
|
||||
self.traverse(node.returns)
|
||||
with self.block():
|
||||
with self.block(extra=self.get_type_comment(node)):
|
||||
self._write_docstring_and_traverse_body(node)
|
||||
|
||||
def visit_For(self, node):
|
||||
|
@ -980,7 +997,7 @@ class _Unparser(NodeVisitor):
|
|||
self.traverse(node.target)
|
||||
self.write(" in ")
|
||||
self.traverse(node.iter)
|
||||
with self.block():
|
||||
with self.block(extra=self.get_type_comment(node)):
|
||||
self.traverse(node.body)
|
||||
if node.orelse:
|
||||
self.fill("else")
|
||||
|
@ -1018,13 +1035,13 @@ class _Unparser(NodeVisitor):
|
|||
def visit_With(self, node):
|
||||
self.fill("with ")
|
||||
self.interleave(lambda: self.write(", "), self.traverse, node.items)
|
||||
with self.block():
|
||||
with self.block(extra=self.get_type_comment(node)):
|
||||
self.traverse(node.body)
|
||||
|
||||
def visit_AsyncWith(self, node):
|
||||
self.fill("async with ")
|
||||
self.interleave(lambda: self.write(", "), self.traverse, node.items)
|
||||
with self.block():
|
||||
with self.block(extra=self.get_type_comment(node)):
|
||||
self.traverse(node.body)
|
||||
|
||||
def visit_JoinedStr(self, node):
|
||||
|
|
|
@ -108,12 +108,12 @@ with f() as x, g() as y:
|
|||
suite1
|
||||
"""
|
||||
|
||||
docstring_prefixes = [
|
||||
docstring_prefixes = (
|
||||
"",
|
||||
"class foo:\n ",
|
||||
"def foo():\n ",
|
||||
"async def foo():\n ",
|
||||
]
|
||||
)
|
||||
|
||||
class ASTTestCase(unittest.TestCase):
|
||||
def assertASTEqual(self, ast1, ast2):
|
||||
|
@ -340,6 +340,37 @@ class UnparseTestCase(ASTTestCase):
|
|||
):
|
||||
self.check_ast_roundtrip(function_type, mode="func_type")
|
||||
|
||||
def test_type_comments(self):
|
||||
for statement in (
|
||||
"a = 5 # type:",
|
||||
"a = 5 # type: int",
|
||||
"a = 5 # type: int and more",
|
||||
"def x(): # type: () -> None\n\tpass",
|
||||
"def x(y): # type: (int) -> None and more\n\tpass",
|
||||
"async def x(): # type: () -> None\n\tpass",
|
||||
"async def x(y): # type: (int) -> None and more\n\tpass",
|
||||
"for x in y: # type: int\n\tpass",
|
||||
"async for x in y: # type: int\n\tpass",
|
||||
"with x(): # type: int\n\tpass",
|
||||
"async with x(): # type: int\n\tpass"
|
||||
):
|
||||
self.check_ast_roundtrip(statement, type_comments=True)
|
||||
|
||||
def test_type_ignore(self):
|
||||
for statement in (
|
||||
"a = 5 # type: ignore",
|
||||
"a = 5 # type: ignore and more",
|
||||
"def x(): # type: ignore\n\tpass",
|
||||
"def x(y): # type: ignore and more\n\tpass",
|
||||
"async def x(): # type: ignore\n\tpass",
|
||||
"async def x(y): # type: ignore and more\n\tpass",
|
||||
"for x in y: # type: ignore\n\tpass",
|
||||
"async for x in y: # type: ignore\n\tpass",
|
||||
"with x(): # type: ignore\n\tpass",
|
||||
"async with x(): # type: ignore\n\tpass"
|
||||
):
|
||||
self.check_ast_roundtrip(statement, type_comments=True)
|
||||
|
||||
|
||||
class CosmeticTestCase(ASTTestCase):
|
||||
"""Test if there are cosmetic issues caused by unnecesary additions"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue