bpo-38870: Expose a function to unparse an ast object in the ast module (GH-17302)

Add ast.unparse() as a function in the ast module that can be used to unparse an
ast.AST object and produce a string with code that would produce an equivalent ast.AST
object when parsed.
This commit is contained in:
Pablo Galindo 2019-11-24 23:02:40 +00:00 committed by GitHub
parent 6bf644ec82
commit 27fc3b6f3f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 772 additions and 751 deletions

View file

@ -161,6 +161,19 @@ and classes for traversing abstract syntax trees:
Added ``type_comments``, ``mode='func_type'`` and ``feature_version``. Added ``type_comments``, ``mode='func_type'`` and ``feature_version``.
.. function:: unparse(ast_obj)
Unparse an :class:`ast.AST` object and generate a string with code
that would produce an equivalent :class:`ast.AST` object if parsed
back with :func:`ast.parse`.
.. warning::
The produced code string will not necesarily be equal to the original
code that generated the :class:`ast.AST` object.
.. versionadded:: 3.9
.. function:: literal_eval(node_or_string) .. function:: literal_eval(node_or_string)
Safely evaluate an expression node or a string containing a Python literal or Safely evaluate an expression node or a string containing a Python literal or

View file

@ -121,6 +121,11 @@ Added the *indent* option to :func:`~ast.dump` which allows it to produce a
multiline indented output. multiline indented output.
(Contributed by Serhiy Storchaka in :issue:`37995`.) (Contributed by Serhiy Storchaka in :issue:`37995`.)
Added the :func:`ast.unparse` as a function in the :mod:`ast` module that can
be used to unparse an :class:`ast.AST` object and produce a string with code
that would produce an equivalent :class:`ast.AST` object when parsed.
(Contributed by Pablo Galindo and Batuhan Taskaya in :issue:`38870`.)
asyncio asyncio
------- -------

View file

@ -24,7 +24,9 @@
:copyright: Copyright 2008 by Armin Ronacher. :copyright: Copyright 2008 by Armin Ronacher.
:license: Python License. :license: Python License.
""" """
import sys
from _ast import * from _ast import *
from contextlib import contextmanager
def parse(source, filename='<unknown>', mode='exec', *, def parse(source, filename='<unknown>', mode='exec', *,
@ -551,6 +553,697 @@ _const_node_type_names = {
type(...): 'Ellipsis', type(...): 'Ellipsis',
} }
# Large float and imaginary literals get turned into infinities in the AST.
# We unparse those infinities to INFSTR.
_INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
class _Unparser(NodeVisitor):
"""Methods in this class recursively traverse an AST and
output source code for the abstract syntax; original formatting
is disregarded."""
def __init__(self):
self._source = []
self._buffer = []
self._indent = 0
def interleave(self, inter, f, seq):
"""Call f on each item in seq, calling inter() in between."""
seq = iter(seq)
try:
f(next(seq))
except StopIteration:
pass
else:
for x in seq:
inter()
f(x)
def fill(self, text=""):
"""Indent a piece of text and append it, according to the current
indentation level"""
self.write("\n" + " " * self._indent + text)
def write(self, text):
"""Append a piece of text"""
self._source.append(text)
def buffer_writer(self, text):
self._buffer.append(text)
@property
def buffer(self):
value = "".join(self._buffer)
self._buffer.clear()
return value
@contextmanager
def block(self):
"""A context manager for preparing the source for blocks. It adds
the character':', increases the indentation on enter and decreases
the indentation on exit."""
self.write(":")
self._indent += 1
yield
self._indent -= 1
def traverse(self, node):
if isinstance(node, list):
for item in node:
self.traverse(item)
else:
super().visit(node)
def visit(self, node):
"""Outputs a source code string that, if converted back to an ast
(using ast.parse) will generate an AST equivalent to *node*"""
self._source = []
self.traverse(node)
return "".join(self._source)
def visit_Module(self, node):
for subnode in node.body:
self.traverse(subnode)
def visit_Expr(self, node):
self.fill()
self.traverse(node.value)
def visit_NamedExpr(self, node):
self.write("(")
self.traverse(node.target)
self.write(" := ")
self.traverse(node.value)
self.write(")")
def visit_Import(self, node):
self.fill("import ")
self.interleave(lambda: self.write(", "), self.traverse, node.names)
def visit_ImportFrom(self, node):
self.fill("from ")
self.write("." * node.level)
if node.module:
self.write(node.module)
self.write(" import ")
self.interleave(lambda: self.write(", "), self.traverse, node.names)
def visit_Assign(self, node):
self.fill()
for target in node.targets:
self.traverse(target)
self.write(" = ")
self.traverse(node.value)
def visit_AugAssign(self, node):
self.fill()
self.traverse(node.target)
self.write(" " + self.binop[node.op.__class__.__name__] + "= ")
self.traverse(node.value)
def visit_AnnAssign(self, node):
self.fill()
if not node.simple and isinstance(node.target, Name):
self.write("(")
self.traverse(node.target)
if not node.simple and isinstance(node.target, Name):
self.write(")")
self.write(": ")
self.traverse(node.annotation)
if node.value:
self.write(" = ")
self.traverse(node.value)
def visit_Return(self, node):
self.fill("return")
if node.value:
self.write(" ")
self.traverse(node.value)
def visit_Pass(self, node):
self.fill("pass")
def visit_Break(self, node):
self.fill("break")
def visit_Continue(self, node):
self.fill("continue")
def visit_Delete(self, node):
self.fill("del ")
self.interleave(lambda: self.write(", "), self.traverse, node.targets)
def visit_Assert(self, node):
self.fill("assert ")
self.traverse(node.test)
if node.msg:
self.write(", ")
self.traverse(node.msg)
def visit_Global(self, node):
self.fill("global ")
self.interleave(lambda: self.write(", "), self.write, node.names)
def visit_Nonlocal(self, node):
self.fill("nonlocal ")
self.interleave(lambda: self.write(", "), self.write, node.names)
def visit_Await(self, node):
self.write("(")
self.write("await")
if node.value:
self.write(" ")
self.traverse(node.value)
self.write(")")
def visit_Yield(self, node):
self.write("(")
self.write("yield")
if node.value:
self.write(" ")
self.traverse(node.value)
self.write(")")
def visit_YieldFrom(self, node):
self.write("(")
self.write("yield from")
if node.value:
self.write(" ")
self.traverse(node.value)
self.write(")")
def visit_Raise(self, node):
self.fill("raise")
if not node.exc:
if node.cause:
raise ValueError(f"Node can't use cause without an exception.")
return
self.write(" ")
self.traverse(node.exc)
if node.cause:
self.write(" from ")
self.traverse(node.cause)
def visit_Try(self, node):
self.fill("try")
with self.block():
self.traverse(node.body)
for ex in node.handlers:
self.traverse(ex)
if node.orelse:
self.fill("else")
with self.block():
self.traverse(node.orelse)
if node.finalbody:
self.fill("finally")
with self.block():
self.traverse(node.finalbody)
def visit_ExceptHandler(self, node):
self.fill("except")
if node.type:
self.write(" ")
self.traverse(node.type)
if node.name:
self.write(" as ")
self.write(node.name)
with self.block():
self.traverse(node.body)
def visit_ClassDef(self, node):
self.write("\n")
for deco in node.decorator_list:
self.fill("@")
self.traverse(deco)
self.fill("class " + node.name)
self.write("(")
comma = False
for e in node.bases:
if comma:
self.write(", ")
else:
comma = True
self.traverse(e)
for e in node.keywords:
if comma:
self.write(", ")
else:
comma = True
self.traverse(e)
self.write(")")
with self.block():
self.traverse(node.body)
def visit_FunctionDef(self, node):
self.__FunctionDef_helper(node, "def")
def visit_AsyncFunctionDef(self, node):
self.__FunctionDef_helper(node, "async def")
def __FunctionDef_helper(self, node, fill_suffix):
self.write("\n")
for deco in node.decorator_list:
self.fill("@")
self.traverse(deco)
def_str = fill_suffix + " " + node.name + "("
self.fill(def_str)
self.traverse(node.args)
self.write(")")
if node.returns:
self.write(" -> ")
self.traverse(node.returns)
with self.block():
self.traverse(node.body)
def visit_For(self, node):
self.__For_helper("for ", node)
def visit_AsyncFor(self, node):
self.__For_helper("async for ", node)
def __For_helper(self, fill, node):
self.fill(fill)
self.traverse(node.target)
self.write(" in ")
self.traverse(node.iter)
with self.block():
self.traverse(node.body)
if node.orelse:
self.fill("else")
with self.block():
self.traverse(node.orelse)
def visit_If(self, node):
self.fill("if ")
self.traverse(node.test)
with self.block():
self.traverse(node.body)
# collapse nested ifs into equivalent elifs.
while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If):
node = node.orelse[0]
self.fill("elif ")
self.traverse(node.test)
with self.block():
self.traverse(node.body)
# final else
if node.orelse:
self.fill("else")
with self.block():
self.traverse(node.orelse)
def visit_While(self, node):
self.fill("while ")
self.traverse(node.test)
with self.block():
self.traverse(node.body)
if node.orelse:
self.fill("else")
with self.block():
self.traverse(node.orelse)
def visit_With(self, node):
self.fill("with ")
self.interleave(lambda: self.write(", "), self.traverse, node.items)
with self.block():
self.traverse(node.body)
def visit_AsyncWith(self, node):
self.fill("async with ")
self.interleave(lambda: self.write(", "), self.traverse, node.items)
with self.block():
self.traverse(node.body)
def visit_JoinedStr(self, node):
self.write("f")
self._fstring_JoinedStr(node, self.buffer_writer)
self.write(repr(self.buffer))
def visit_FormattedValue(self, node):
self.write("f")
self._fstring_FormattedValue(node, self.buffer_writer)
self.write(repr(self.buffer))
def _fstring_JoinedStr(self, node, write):
for value in node.values:
meth = getattr(self, "_fstring_" + type(value).__name__)
meth(value, write)
def _fstring_Constant(self, node, write):
if not isinstance(node.value, str):
raise ValueError("Constants inside JoinedStr should be a string.")
value = node.value.replace("{", "{{").replace("}", "}}")
write(value)
def _fstring_FormattedValue(self, node, write):
write("{")
expr = type(self)().visit(node.value).rstrip("\n")
if expr.startswith("{"):
write(" ") # Separate pair of opening brackets as "{ {"
write(expr)
if node.conversion != -1:
conversion = chr(node.conversion)
if conversion not in "sra":
raise ValueError("Unknown f-string conversion.")
write(f"!{conversion}")
if node.format_spec:
write(":")
meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
meth(node.format_spec, write)
write("}")
def visit_Name(self, node):
self.write(node.id)
def _write_constant(self, value):
if isinstance(value, (float, complex)):
# Substitute overflowing decimal literal for AST infinities.
self.write(repr(value).replace("inf", _INFSTR))
else:
self.write(repr(value))
def visit_Constant(self, node):
value = node.value
if isinstance(value, tuple):
self.write("(")
if len(value) == 1:
self._write_constant(value[0])
self.write(",")
else:
self.interleave(lambda: self.write(", "), self._write_constant, value)
self.write(")")
elif value is ...:
self.write("...")
else:
if node.kind == "u":
self.write("u")
self._write_constant(node.value)
def visit_List(self, node):
self.write("[")
self.interleave(lambda: self.write(", "), self.traverse, node.elts)
self.write("]")
def visit_ListComp(self, node):
self.write("[")
self.traverse(node.elt)
for gen in node.generators:
self.traverse(gen)
self.write("]")
def visit_GeneratorExp(self, node):
self.write("(")
self.traverse(node.elt)
for gen in node.generators:
self.traverse(gen)
self.write(")")
def visit_SetComp(self, node):
self.write("{")
self.traverse(node.elt)
for gen in node.generators:
self.traverse(gen)
self.write("}")
def visit_DictComp(self, node):
self.write("{")
self.traverse(node.key)
self.write(": ")
self.traverse(node.value)
for gen in node.generators:
self.traverse(gen)
self.write("}")
def visit_comprehension(self, node):
if node.is_async:
self.write(" async for ")
else:
self.write(" for ")
self.traverse(node.target)
self.write(" in ")
self.traverse(node.iter)
for if_clause in node.ifs:
self.write(" if ")
self.traverse(if_clause)
def visit_IfExp(self, node):
self.write("(")
self.traverse(node.body)
self.write(" if ")
self.traverse(node.test)
self.write(" else ")
self.traverse(node.orelse)
self.write(")")
def visit_Set(self, node):
if not node.elts:
raise ValueError("Set node should has at least one item")
self.write("{")
self.interleave(lambda: self.write(", "), self.traverse, node.elts)
self.write("}")
def visit_Dict(self, node):
self.write("{")
def write_key_value_pair(k, v):
self.traverse(k)
self.write(": ")
self.traverse(v)
def write_item(item):
k, v = item
if k is None:
# for dictionary unpacking operator in dicts {**{'y': 2}}
# see PEP 448 for details
self.write("**")
self.traverse(v)
else:
write_key_value_pair(k, v)
self.interleave(
lambda: self.write(", "), write_item, zip(node.keys, node.values)
)
self.write("}")
def visit_Tuple(self, node):
self.write("(")
if len(node.elts) == 1:
elt = node.elts[0]
self.traverse(elt)
self.write(",")
else:
self.interleave(lambda: self.write(", "), self.traverse, node.elts)
self.write(")")
unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
def visit_UnaryOp(self, node):
self.write("(")
self.write(self.unop[node.op.__class__.__name__])
self.write(" ")
self.traverse(node.operand)
self.write(")")
binop = {
"Add": "+",
"Sub": "-",
"Mult": "*",
"MatMult": "@",
"Div": "/",
"Mod": "%",
"LShift": "<<",
"RShift": ">>",
"BitOr": "|",
"BitXor": "^",
"BitAnd": "&",
"FloorDiv": "//",
"Pow": "**",
}
def visit_BinOp(self, node):
self.write("(")
self.traverse(node.left)
self.write(" " + self.binop[node.op.__class__.__name__] + " ")
self.traverse(node.right)
self.write(")")
cmpops = {
"Eq": "==",
"NotEq": "!=",
"Lt": "<",
"LtE": "<=",
"Gt": ">",
"GtE": ">=",
"Is": "is",
"IsNot": "is not",
"In": "in",
"NotIn": "not in",
}
def visit_Compare(self, node):
self.write("(")
self.traverse(node.left)
for o, e in zip(node.ops, node.comparators):
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
self.traverse(e)
self.write(")")
boolops = {And: "and", Or: "or"}
def visit_BoolOp(self, node):
self.write("(")
s = " %s " % self.boolops[node.op.__class__]
self.interleave(lambda: self.write(s), self.traverse, node.values)
self.write(")")
def visit_Attribute(self, node):
self.traverse(node.value)
# Special case: 3.__abs__() is a syntax error, so if node.value
# is an integer literal then we need to either parenthesize
# it or add an extra space to get 3 .__abs__().
if isinstance(node.value, Constant) and isinstance(node.value.value, int):
self.write(" ")
self.write(".")
self.write(node.attr)
def visit_Call(self, node):
self.traverse(node.func)
self.write("(")
comma = False
for e in node.args:
if comma:
self.write(", ")
else:
comma = True
self.traverse(e)
for e in node.keywords:
if comma:
self.write(", ")
else:
comma = True
self.traverse(e)
self.write(")")
def visit_Subscript(self, node):
self.traverse(node.value)
self.write("[")
self.traverse(node.slice)
self.write("]")
def visit_Starred(self, node):
self.write("*")
self.traverse(node.value)
def visit_Ellipsis(self, node):
self.write("...")
def visit_Index(self, node):
self.traverse(node.value)
def visit_Slice(self, node):
if node.lower:
self.traverse(node.lower)
self.write(":")
if node.upper:
self.traverse(node.upper)
if node.step:
self.write(":")
self.traverse(node.step)
def visit_ExtSlice(self, node):
self.interleave(lambda: self.write(", "), self.traverse, node.dims)
def visit_arg(self, node):
self.write(node.arg)
if node.annotation:
self.write(": ")
self.traverse(node.annotation)
def visit_arguments(self, node):
first = True
# normal arguments
all_args = node.posonlyargs + node.args
defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults
for index, elements in enumerate(zip(all_args, defaults), 1):
a, d = elements
if first:
first = False
else:
self.write(", ")
self.traverse(a)
if d:
self.write("=")
self.traverse(d)
if index == len(node.posonlyargs):
self.write(", /")
# varargs, or bare '*' if no varargs but keyword-only arguments present
if node.vararg or node.kwonlyargs:
if first:
first = False
else:
self.write(", ")
self.write("*")
if node.vararg:
self.write(node.vararg.arg)
if node.vararg.annotation:
self.write(": ")
self.traverse(node.vararg.annotation)
# keyword-only arguments
if node.kwonlyargs:
for a, d in zip(node.kwonlyargs, node.kw_defaults):
if first:
first = False
else:
self.write(", ")
self.traverse(a),
if d:
self.write("=")
self.traverse(d)
# kwargs
if node.kwarg:
if first:
first = False
else:
self.write(", ")
self.write("**" + node.kwarg.arg)
if node.kwarg.annotation:
self.write(": ")
self.traverse(node.kwarg.annotation)
def visit_keyword(self, node):
if node.arg is None:
self.write("**")
else:
self.write(node.arg)
self.write("=")
self.traverse(node.value)
def visit_Lambda(self, node):
self.write("(")
self.write("lambda ")
self.traverse(node.args)
self.write(": ")
self.traverse(node.body)
self.write(")")
def visit_alias(self, node):
self.write(node.name)
if node.asname:
self.write(" as " + node.asname)
def visit_withitem(self, node):
self.traverse(node.context_expr)
if node.optional_vars:
self.write(" as ")
self.traverse(node.optional_vars)
def unparse(ast_obj):
unparser = _Unparser()
return unparser.visit(ast_obj)
def main(): def main():
import argparse import argparse

View file

@ -3,19 +3,12 @@
import unittest import unittest
import test.support import test.support
import io import io
import os import pathlib
import random import random
import tokenize import tokenize
import ast import ast
import functools
from test.test_tools import basepath, toolsdir, skip_if_missing
skip_if_missing()
parser_path = os.path.join(toolsdir, "parser")
with test.support.DirsOnSysPath(parser_path):
import unparse
def read_pyfile(filename): def read_pyfile(filename):
"""Read and return the contents of a Python source file (as a """Read and return the contents of a Python source file (as a
@ -26,6 +19,7 @@ def read_pyfile(filename):
source = pyfile.read() source = pyfile.read()
return source return source
for_else = """\ for_else = """\
def f(): def f():
for x in range(10): for x in range(10):
@ -119,18 +113,21 @@ with f() as x, g() as y:
suite1 suite1
""" """
class ASTTestCase(unittest.TestCase): class ASTTestCase(unittest.TestCase):
def assertASTEqual(self, ast1, ast2): def assertASTEqual(self, ast1, ast2):
self.assertEqual(ast.dump(ast1), ast.dump(ast2)) self.assertEqual(ast.dump(ast1), ast.dump(ast2))
def check_roundtrip(self, code1, filename="internal"): def check_roundtrip(self, code1):
ast1 = compile(code1, filename, "exec", ast.PyCF_ONLY_AST) ast1 = ast.parse(code1)
unparse_buffer = io.StringIO() code2 = ast.unparse(ast1)
unparse.Unparser(ast1, unparse_buffer) ast2 = ast.parse(code2)
code2 = unparse_buffer.getvalue()
ast2 = compile(code2, filename, "exec", ast.PyCF_ONLY_AST)
self.assertASTEqual(ast1, ast2) self.assertASTEqual(ast1, ast2)
def check_invalid(self, node, raises=ValueError):
self.assertRaises(raises, ast.unparse, node)
class UnparseTestCase(ASTTestCase): class UnparseTestCase(ASTTestCase):
# Tests for specific bugs found in earlier versions of unparse # Tests for specific bugs found in earlier versions of unparse
@ -174,8 +171,8 @@ class UnparseTestCase(ASTTestCase):
self.check_roundtrip("-1e1000j") self.check_roundtrip("-1e1000j")
def test_min_int(self): def test_min_int(self):
self.check_roundtrip(str(-2**31)) self.check_roundtrip(str(-(2 ** 31)))
self.check_roundtrip(str(-2**63)) self.check_roundtrip(str(-(2 ** 63)))
def test_imaginary_literals(self): def test_imaginary_literals(self):
self.check_roundtrip("7j") self.check_roundtrip("7j")
@ -265,54 +262,67 @@ class UnparseTestCase(ASTTestCase):
self.check_roundtrip(r"""{**{'y': 2}, 'x': 1}""") self.check_roundtrip(r"""{**{'y': 2}, 'x': 1}""")
self.check_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""") self.check_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""")
def test_invalid_raise(self):
self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X")))
def test_invalid_fstring_constant(self):
self.check_invalid(ast.JoinedStr(values=[ast.Constant(value=100)]))
def test_invalid_fstring_conversion(self):
self.check_invalid(
ast.FormattedValue(
value=ast.Constant(value="a", kind=None),
conversion=ord("Y"), # random character
format_spec=None,
)
)
def test_invalid_set(self):
self.check_invalid(ast.Set(elts=[]))
class DirectoryTestCase(ASTTestCase): class DirectoryTestCase(ASTTestCase):
"""Test roundtrip behaviour on all files in Lib and Lib/test.""" """Test roundtrip behaviour on all files in Lib and Lib/test."""
NAMES = None
# test directories, relative to the root of the distribution lib_dir = pathlib.Path(__file__).parent / ".."
test_directories = 'Lib', os.path.join('Lib', 'test') test_directories = (lib_dir, lib_dir / "test")
skip_files = {"test_fstring.py"}
@classmethod @functools.cached_property
def get_names(cls): def files_to_test(self):
if cls.NAMES is not None: # bpo-31174: Use cached_property to store the names sample
return cls.NAMES # to always test the same files. It prevents false alarms
# when hunting reference leaks.
names = [] items = [
for d in cls.test_directories: item.resolve()
test_dir = os.path.join(basepath, d) for directory in self.test_directories
for n in os.listdir(test_dir): for item in directory.glob("*.py")
if n.endswith('.py') and not n.startswith('bad'): if not item.name.startswith("bad")
names.append(os.path.join(test_dir, n)) ]
# Test limited subset of files unless the 'cpu' resource is specified. # Test limited subset of files unless the 'cpu' resource is specified.
if not test.support.is_resource_enabled("cpu"): if not test.support.is_resource_enabled("cpu"):
names = random.sample(names, 10) items = random.sample(items, 10)
# bpo-31174: Store the names sample to always test the same files. return items
# It prevents false alarms when hunting reference leaks.
cls.NAMES = names
return names
def test_files(self): def test_files(self):
# get names of files to test for item in self.files_to_test:
names = self.get_names()
for filename in names:
if test.support.verbose: if test.support.verbose:
print('Testing %s' % filename) print(f"Testing {item.absolute()}")
# Some f-strings are not correctly round-tripped by # Some f-strings are not correctly round-tripped by
# Tools/parser/unparse.py. See issue 28002 for details. # Tools/parser/unparse.py. See issue 28002 for details.
# We need to skip files that contain such f-strings. # We need to skip files that contain such f-strings.
if os.path.basename(filename) in ('test_fstring.py', ): if item.name in self.skip_files:
if test.support.verbose: if test.support.verbose:
print(f'Skipping {filename}: see issue 28002') print(f"Skipping {item.absolute()}: see issue 28002")
continue continue
with self.subTest(filename=filename): with self.subTest(filename=item):
source = read_pyfile(filename) source = read_pyfile(item)
self.check_roundtrip(source) self.check_roundtrip(source)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -0,0 +1,4 @@
Expose :func:`ast.unparse` as a function of the :mod:`ast` module that can
be used to unparse an :class:`ast.AST` object and produce a string with code
that would produce an equivalent :class:`ast.AST` object when parsed. Patch
by Pablo Galindo and Batuhan Taskaya.

View file

@ -1,704 +0,0 @@
"Usage: unparse.py <path to source file>"
import sys
import ast
import tokenize
import io
import os
# Large float and imaginary literals get turned into infinities in the AST.
# We unparse those infinities to INFSTR.
INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
def interleave(inter, f, seq):
"""Call f on each item in seq, calling inter() in between.
"""
seq = iter(seq)
try:
f(next(seq))
except StopIteration:
pass
else:
for x in seq:
inter()
f(x)
class Unparser:
"""Methods in this class recursively traverse an AST and
output source code for the abstract syntax; original formatting
is disregarded. """
def __init__(self, tree, file = sys.stdout):
"""Unparser(tree, file=sys.stdout) -> None.
Print the source for tree to file."""
self.f = file
self._indent = 0
self.dispatch(tree)
print("", file=self.f)
self.f.flush()
def fill(self, text = ""):
"Indent a piece of text, according to the current indentation level"
self.f.write("\n"+" "*self._indent + text)
def write(self, text):
"Append a piece of text to the current line."
self.f.write(text)
def enter(self):
"Print ':', and increase the indentation."
self.write(":")
self._indent += 1
def leave(self):
"Decrease the indentation level."
self._indent -= 1
def dispatch(self, tree):
"Dispatcher function, dispatching tree type T to method _T."
if isinstance(tree, list):
for t in tree:
self.dispatch(t)
return
meth = getattr(self, "_"+tree.__class__.__name__)
meth(tree)
############### Unparsing methods ######################
# There should be one method per concrete grammar type #
# Constructors should be grouped by sum type. Ideally, #
# this would follow the order in the grammar, but #
# currently doesn't. #
########################################################
def _Module(self, tree):
for stmt in tree.body:
self.dispatch(stmt)
# stmt
def _Expr(self, tree):
self.fill()
self.dispatch(tree.value)
def _NamedExpr(self, tree):
self.write("(")
self.dispatch(tree.target)
self.write(" := ")
self.dispatch(tree.value)
self.write(")")
def _Import(self, t):
self.fill("import ")
interleave(lambda: self.write(", "), self.dispatch, t.names)
def _ImportFrom(self, t):
self.fill("from ")
self.write("." * t.level)
if t.module:
self.write(t.module)
self.write(" import ")
interleave(lambda: self.write(", "), self.dispatch, t.names)
def _Assign(self, t):
self.fill()
for target in t.targets:
self.dispatch(target)
self.write(" = ")
self.dispatch(t.value)
def _AugAssign(self, t):
self.fill()
self.dispatch(t.target)
self.write(" "+self.binop[t.op.__class__.__name__]+"= ")
self.dispatch(t.value)
def _AnnAssign(self, t):
self.fill()
if not t.simple and isinstance(t.target, ast.Name):
self.write('(')
self.dispatch(t.target)
if not t.simple and isinstance(t.target, ast.Name):
self.write(')')
self.write(": ")
self.dispatch(t.annotation)
if t.value:
self.write(" = ")
self.dispatch(t.value)
def _Return(self, t):
self.fill("return")
if t.value:
self.write(" ")
self.dispatch(t.value)
def _Pass(self, t):
self.fill("pass")
def _Break(self, t):
self.fill("break")
def _Continue(self, t):
self.fill("continue")
def _Delete(self, t):
self.fill("del ")
interleave(lambda: self.write(", "), self.dispatch, t.targets)
def _Assert(self, t):
self.fill("assert ")
self.dispatch(t.test)
if t.msg:
self.write(", ")
self.dispatch(t.msg)
def _Global(self, t):
self.fill("global ")
interleave(lambda: self.write(", "), self.write, t.names)
def _Nonlocal(self, t):
self.fill("nonlocal ")
interleave(lambda: self.write(", "), self.write, t.names)
def _Await(self, t):
self.write("(")
self.write("await")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
def _Yield(self, t):
self.write("(")
self.write("yield")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
def _YieldFrom(self, t):
self.write("(")
self.write("yield from")
if t.value:
self.write(" ")
self.dispatch(t.value)
self.write(")")
def _Raise(self, t):
self.fill("raise")
if not t.exc:
assert not t.cause
return
self.write(" ")
self.dispatch(t.exc)
if t.cause:
self.write(" from ")
self.dispatch(t.cause)
def _Try(self, t):
self.fill("try")
self.enter()
self.dispatch(t.body)
self.leave()
for ex in t.handlers:
self.dispatch(ex)
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
if t.finalbody:
self.fill("finally")
self.enter()
self.dispatch(t.finalbody)
self.leave()
def _ExceptHandler(self, t):
self.fill("except")
if t.type:
self.write(" ")
self.dispatch(t.type)
if t.name:
self.write(" as ")
self.write(t.name)
self.enter()
self.dispatch(t.body)
self.leave()
def _ClassDef(self, t):
self.write("\n")
for deco in t.decorator_list:
self.fill("@")
self.dispatch(deco)
self.fill("class "+t.name)
self.write("(")
comma = False
for e in t.bases:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
for e in t.keywords:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
self.write(")")
self.enter()
self.dispatch(t.body)
self.leave()
def _FunctionDef(self, t):
self.__FunctionDef_helper(t, "def")
def _AsyncFunctionDef(self, t):
self.__FunctionDef_helper(t, "async def")
def __FunctionDef_helper(self, t, fill_suffix):
self.write("\n")
for deco in t.decorator_list:
self.fill("@")
self.dispatch(deco)
def_str = fill_suffix+" "+t.name + "("
self.fill(def_str)
self.dispatch(t.args)
self.write(")")
if t.returns:
self.write(" -> ")
self.dispatch(t.returns)
self.enter()
self.dispatch(t.body)
self.leave()
def _For(self, t):
self.__For_helper("for ", t)
def _AsyncFor(self, t):
self.__For_helper("async for ", t)
def __For_helper(self, fill, t):
self.fill(fill)
self.dispatch(t.target)
self.write(" in ")
self.dispatch(t.iter)
self.enter()
self.dispatch(t.body)
self.leave()
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
def _If(self, t):
self.fill("if ")
self.dispatch(t.test)
self.enter()
self.dispatch(t.body)
self.leave()
# collapse nested ifs into equivalent elifs.
while (t.orelse and len(t.orelse) == 1 and
isinstance(t.orelse[0], ast.If)):
t = t.orelse[0]
self.fill("elif ")
self.dispatch(t.test)
self.enter()
self.dispatch(t.body)
self.leave()
# final else
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
def _While(self, t):
self.fill("while ")
self.dispatch(t.test)
self.enter()
self.dispatch(t.body)
self.leave()
if t.orelse:
self.fill("else")
self.enter()
self.dispatch(t.orelse)
self.leave()
def _With(self, t):
self.fill("with ")
interleave(lambda: self.write(", "), self.dispatch, t.items)
self.enter()
self.dispatch(t.body)
self.leave()
def _AsyncWith(self, t):
self.fill("async with ")
interleave(lambda: self.write(", "), self.dispatch, t.items)
self.enter()
self.dispatch(t.body)
self.leave()
# expr
def _JoinedStr(self, t):
self.write("f")
string = io.StringIO()
self._fstring_JoinedStr(t, string.write)
self.write(repr(string.getvalue()))
def _FormattedValue(self, t):
self.write("f")
string = io.StringIO()
self._fstring_FormattedValue(t, string.write)
self.write(repr(string.getvalue()))
def _fstring_JoinedStr(self, t, write):
for value in t.values:
meth = getattr(self, "_fstring_" + type(value).__name__)
meth(value, write)
def _fstring_Constant(self, t, write):
assert isinstance(t.value, str)
value = t.value.replace("{", "{{").replace("}", "}}")
write(value)
def _fstring_FormattedValue(self, t, write):
write("{")
expr = io.StringIO()
Unparser(t.value, expr)
expr = expr.getvalue().rstrip("\n")
if expr.startswith("{"):
write(" ") # Separate pair of opening brackets as "{ {"
write(expr)
if t.conversion != -1:
conversion = chr(t.conversion)
assert conversion in "sra"
write(f"!{conversion}")
if t.format_spec:
write(":")
meth = getattr(self, "_fstring_" + type(t.format_spec).__name__)
meth(t.format_spec, write)
write("}")
def _Name(self, t):
self.write(t.id)
def _write_constant(self, value):
if isinstance(value, (float, complex)):
# Substitute overflowing decimal literal for AST infinities.
self.write(repr(value).replace("inf", INFSTR))
else:
self.write(repr(value))
def _Constant(self, t):
value = t.value
if isinstance(value, tuple):
self.write("(")
if len(value) == 1:
self._write_constant(value[0])
self.write(",")
else:
interleave(lambda: self.write(", "), self._write_constant, value)
self.write(")")
elif value is ...:
self.write("...")
else:
if t.kind == "u":
self.write("u")
self._write_constant(t.value)
def _List(self, t):
self.write("[")
interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write("]")
def _ListComp(self, t):
self.write("[")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write("]")
def _GeneratorExp(self, t):
self.write("(")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write(")")
def _SetComp(self, t):
self.write("{")
self.dispatch(t.elt)
for gen in t.generators:
self.dispatch(gen)
self.write("}")
def _DictComp(self, t):
self.write("{")
self.dispatch(t.key)
self.write(": ")
self.dispatch(t.value)
for gen in t.generators:
self.dispatch(gen)
self.write("}")
def _comprehension(self, t):
if t.is_async:
self.write(" async for ")
else:
self.write(" for ")
self.dispatch(t.target)
self.write(" in ")
self.dispatch(t.iter)
for if_clause in t.ifs:
self.write(" if ")
self.dispatch(if_clause)
def _IfExp(self, t):
self.write("(")
self.dispatch(t.body)
self.write(" if ")
self.dispatch(t.test)
self.write(" else ")
self.dispatch(t.orelse)
self.write(")")
def _Set(self, t):
assert(t.elts) # should be at least one element
self.write("{")
interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write("}")
def _Dict(self, t):
self.write("{")
def write_key_value_pair(k, v):
self.dispatch(k)
self.write(": ")
self.dispatch(v)
def write_item(item):
k, v = item
if k is None:
# for dictionary unpacking operator in dicts {**{'y': 2}}
# see PEP 448 for details
self.write("**")
self.dispatch(v)
else:
write_key_value_pair(k, v)
interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values))
self.write("}")
def _Tuple(self, t):
self.write("(")
if len(t.elts) == 1:
elt = t.elts[0]
self.dispatch(elt)
self.write(",")
else:
interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write(")")
unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"}
def _UnaryOp(self, t):
self.write("(")
self.write(self.unop[t.op.__class__.__name__])
self.write(" ")
self.dispatch(t.operand)
self.write(")")
binop = { "Add":"+", "Sub":"-", "Mult":"*", "MatMult":"@", "Div":"/", "Mod":"%",
"LShift":"<<", "RShift":">>", "BitOr":"|", "BitXor":"^", "BitAnd":"&",
"FloorDiv":"//", "Pow": "**"}
def _BinOp(self, t):
self.write("(")
self.dispatch(t.left)
self.write(" " + self.binop[t.op.__class__.__name__] + " ")
self.dispatch(t.right)
self.write(")")
cmpops = {"Eq":"==", "NotEq":"!=", "Lt":"<", "LtE":"<=", "Gt":">", "GtE":">=",
"Is":"is", "IsNot":"is not", "In":"in", "NotIn":"not in"}
def _Compare(self, t):
self.write("(")
self.dispatch(t.left)
for o, e in zip(t.ops, t.comparators):
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
self.dispatch(e)
self.write(")")
boolops = {ast.And: 'and', ast.Or: 'or'}
def _BoolOp(self, t):
self.write("(")
s = " %s " % self.boolops[t.op.__class__]
interleave(lambda: self.write(s), self.dispatch, t.values)
self.write(")")
def _Attribute(self,t):
self.dispatch(t.value)
# Special case: 3.__abs__() is a syntax error, so if t.value
# is an integer literal then we need to either parenthesize
# it or add an extra space to get 3 .__abs__().
if isinstance(t.value, ast.Constant) and isinstance(t.value.value, int):
self.write(" ")
self.write(".")
self.write(t.attr)
def _Call(self, t):
self.dispatch(t.func)
self.write("(")
comma = False
for e in t.args:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
for e in t.keywords:
if comma: self.write(", ")
else: comma = True
self.dispatch(e)
self.write(")")
def _Subscript(self, t):
self.dispatch(t.value)
self.write("[")
self.dispatch(t.slice)
self.write("]")
def _Starred(self, t):
self.write("*")
self.dispatch(t.value)
# slice
def _Ellipsis(self, t):
self.write("...")
def _Index(self, t):
self.dispatch(t.value)
def _Slice(self, t):
if t.lower:
self.dispatch(t.lower)
self.write(":")
if t.upper:
self.dispatch(t.upper)
if t.step:
self.write(":")
self.dispatch(t.step)
def _ExtSlice(self, t):
interleave(lambda: self.write(', '), self.dispatch, t.dims)
# argument
def _arg(self, t):
self.write(t.arg)
if t.annotation:
self.write(": ")
self.dispatch(t.annotation)
# others
def _arguments(self, t):
first = True
# normal arguments
all_args = t.posonlyargs + t.args
defaults = [None] * (len(all_args) - len(t.defaults)) + t.defaults
for index, elements in enumerate(zip(all_args, defaults), 1):
a, d = elements
if first:first = False
else: self.write(", ")
self.dispatch(a)
if d:
self.write("=")
self.dispatch(d)
if index == len(t.posonlyargs):
self.write(", /")
# varargs, or bare '*' if no varargs but keyword-only arguments present
if t.vararg or t.kwonlyargs:
if first:first = False
else: self.write(", ")
self.write("*")
if t.vararg:
self.write(t.vararg.arg)
if t.vararg.annotation:
self.write(": ")
self.dispatch(t.vararg.annotation)
# keyword-only arguments
if t.kwonlyargs:
for a, d in zip(t.kwonlyargs, t.kw_defaults):
if first:first = False
else: self.write(", ")
self.dispatch(a),
if d:
self.write("=")
self.dispatch(d)
# kwargs
if t.kwarg:
if first:first = False
else: self.write(", ")
self.write("**"+t.kwarg.arg)
if t.kwarg.annotation:
self.write(": ")
self.dispatch(t.kwarg.annotation)
def _keyword(self, t):
if t.arg is None:
self.write("**")
else:
self.write(t.arg)
self.write("=")
self.dispatch(t.value)
def _Lambda(self, t):
self.write("(")
self.write("lambda ")
self.dispatch(t.args)
self.write(": ")
self.dispatch(t.body)
self.write(")")
def _alias(self, t):
self.write(t.name)
if t.asname:
self.write(" as "+t.asname)
def _withitem(self, t):
self.dispatch(t.context_expr)
if t.optional_vars:
self.write(" as ")
self.dispatch(t.optional_vars)
def roundtrip(filename, output=sys.stdout):
with open(filename, "rb") as pyfile:
encoding = tokenize.detect_encoding(pyfile.readline)[0]
with open(filename, "r", encoding=encoding) as pyfile:
source = pyfile.read()
tree = compile(source, filename, "exec", ast.PyCF_ONLY_AST)
Unparser(tree, output)
def testdir(a):
try:
names = [n for n in os.listdir(a) if n.endswith('.py')]
except OSError:
print("Directory not readable: %s" % a, file=sys.stderr)
else:
for n in names:
fullname = os.path.join(a, n)
if os.path.isfile(fullname):
output = io.StringIO()
print('Testing %s' % fullname)
try:
roundtrip(fullname, output)
except Exception as e:
print(' Failed to compile, exception is %s' % repr(e))
elif os.path.isdir(fullname):
testdir(fullname)
def main(args):
if args[0] == '--testdir':
for a in args[1:]:
testdir(a)
else:
for a in args:
roundtrip(a)
if __name__=='__main__':
main(sys.argv[1:])