bpo-38870: Don't start generated output with newlines in ast.unparse (GH-19636)

This commit is contained in:
Batuhan Taskaya 2020-05-03 20:11:51 +03:00 committed by GitHub
parent 3dd2157feb
commit 493bf1cc31
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 11 deletions

View file

@ -669,10 +669,16 @@ class _Unparser(NodeVisitor):
else: else:
self.interleave(lambda: self.write(", "), traverser, items) self.interleave(lambda: self.write(", "), traverser, items)
def maybe_newline(self):
"""Adds a newline if it isn't the start of generated source"""
if self._source:
self.write("\n")
def fill(self, text=""): def fill(self, text=""):
"""Indent a piece of text and append it, according to the current """Indent a piece of text and append it, according to the current
indentation level""" indentation level"""
self.write("\n" + " " * self._indent + text) self.maybe_newline()
self.write(" " * self._indent + text)
def write(self, text): def write(self, text):
"""Append a piece of text""" """Append a piece of text"""
@ -916,7 +922,7 @@ class _Unparser(NodeVisitor):
self.traverse(node.body) self.traverse(node.body)
def visit_ClassDef(self, node): def visit_ClassDef(self, node):
self.write("\n") self.maybe_newline()
for deco in node.decorator_list: for deco in node.decorator_list:
self.fill("@") self.fill("@")
self.traverse(deco) self.traverse(deco)
@ -946,7 +952,7 @@ class _Unparser(NodeVisitor):
self._function_helper(node, "async def") self._function_helper(node, "async def")
def _function_helper(self, node, fill_suffix): def _function_helper(self, node, fill_suffix):
self.write("\n") self.maybe_newline()
for deco in node.decorator_list: for deco in node.decorator_list:
self.fill("@") self.fill("@")
self.traverse(deco) self.traverse(deco)
@ -1043,7 +1049,7 @@ class _Unparser(NodeVisitor):
write("{") write("{")
unparser = type(self)() unparser = type(self)()
unparser.set_precedence(_Precedence.TEST.next(), node.value) unparser.set_precedence(_Precedence.TEST.next(), node.value)
expr = unparser.visit(node.value).rstrip("\n") expr = unparser.visit(node.value)
if expr.startswith("{"): if expr.startswith("{"):
write(" ") # Separate pair of opening brackets as "{ {" write(" ") # Separate pair of opening brackets as "{ {"
write(expr) write(expr)

View file

@ -128,19 +128,17 @@ class ASTTestCase(unittest.TestCase):
def check_invalid(self, node, raises=ValueError): def check_invalid(self, node, raises=ValueError):
self.assertRaises(raises, ast.unparse, node) self.assertRaises(raises, ast.unparse, node)
def get_source(self, code1, code2=None, strip=True): def get_source(self, code1, code2=None):
code2 = code2 or code1 code2 = code2 or code1
code1 = ast.unparse(ast.parse(code1)) code1 = ast.unparse(ast.parse(code1))
if strip:
code1 = code1.strip()
return code1, code2 return code1, code2
def check_src_roundtrip(self, code1, code2=None, strip=True): def check_src_roundtrip(self, code1, code2=None):
code1, code2 = self.get_source(code1, code2, strip) code1, code2 = self.get_source(code1, code2)
self.assertEqual(code2, code1) self.assertEqual(code2, code1)
def check_src_dont_roundtrip(self, code1, code2=None, strip=True): def check_src_dont_roundtrip(self, code1, code2=None):
code1, code2 = self.get_source(code1, code2, strip) code1, code2 = self.get_source(code1, code2)
self.assertNotEqual(code2, code1) self.assertNotEqual(code2, code1)
class UnparseTestCase(ASTTestCase): class UnparseTestCase(ASTTestCase):