gh-129598: ast: allow multi stmts for ast single with ';' (#129620)

This commit is contained in:
Tomasz Pytel 2025-03-19 18:29:40 -04:00 committed by GitHub
parent 20c5f969dd
commit a8cb5e4a43
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 164 additions and 30 deletions

View file

@ -142,13 +142,13 @@ class ASTTestCase(ASTTestMixin, unittest.TestCase):
with self.subTest(node=node):
self.assertRaises(raises, ast.unparse, node)
def get_source(self, code1, code2=None):
def get_source(self, code1, code2=None, **kwargs):
code2 = code2 or code1
code1 = ast.unparse(ast.parse(code1))
code1 = ast.unparse(ast.parse(code1, **kwargs))
return code1, code2
def check_src_roundtrip(self, code1, code2=None):
code1, code2 = self.get_source(code1, code2)
def check_src_roundtrip(self, code1, code2=None, **kwargs):
code1, code2 = self.get_source(code1, code2, **kwargs)
with self.subTest(code1=code1, code2=code2):
self.assertEqual(code2, code1)
@ -469,6 +469,120 @@ class UnparseTestCase(ASTTestCase):
):
self.check_ast_roundtrip(statement, type_comments=True)
def test_unparse_interactive_semicolons(self):
# gh-129598: Fix ast.unparse() when ast.Interactive contains multiple statements
self.check_src_roundtrip("i = 1; 'expr'; raise Exception", mode='single')
self.check_src_roundtrip("i: int = 1; j: float = 0; k += l", mode='single')
combinable = (
"'expr'",
"(i := 1)",
"import foo",
"from foo import bar",
"i = 1",
"i += 1",
"i: int = 1",
"return i",
"pass",
"break",
"continue",
"del i",
"assert i",
"global i",
"nonlocal j",
"await i",
"yield i",
"yield from i",
"raise i",
"type t[T] = ...",
"i",
)
for a in combinable:
for b in combinable:
self.check_src_roundtrip(f"{a}; {b}", mode='single')
def test_unparse_interactive_integrity_1(self):
# rest of unparse_interactive_integrity tests just make sure mode='single' parse and unparse didn't break
self.check_src_roundtrip(
"if i:\n 'expr'\nelse:\n raise Exception",
"if i:\n 'expr'\nelse:\n raise Exception",
mode='single'
)
self.check_src_roundtrip(
"@decorator1\n@decorator2\ndef func():\n 'docstring'\n i = 1; 'expr'; raise Exception",
'''@decorator1\n@decorator2\ndef func():\n """docstring"""\n i = 1\n 'expr'\n raise Exception''',
mode='single'
)
self.check_src_roundtrip(
"@decorator1\n@decorator2\nclass cls:\n 'docstring'\n i = 1; 'expr'; raise Exception",
'''@decorator1\n@decorator2\nclass cls:\n """docstring"""\n i = 1\n 'expr'\n raise Exception''',
mode='single'
)
def test_unparse_interactive_integrity_2(self):
for statement in (
"def x():\n pass",
"def x(y):\n pass",
"async def x():\n pass",
"async def x(y):\n pass",
"for x in y:\n pass",
"async for x in y:\n pass",
"with x():\n pass",
"async with x():\n pass",
"def f():\n pass",
"def f(a):\n pass",
"def f(b=2):\n pass",
"def f(a, b):\n pass",
"def f(a, b=2):\n pass",
"def f(a=5, b=2):\n pass",
"def f(*, a=1, b=2):\n pass",
"def f(*, a=1, b):\n pass",
"def f(*, a, b=2):\n pass",
"def f(a, b=None, *, c, **kwds):\n pass",
"def f(a=2, *args, c=5, d, **kwds):\n pass",
"def f(*args, **kwargs):\n pass",
"class cls:\n\n def f(self):\n pass",
"class cls:\n\n def f(self, a):\n pass",
"class cls:\n\n def f(self, b=2):\n pass",
"class cls:\n\n def f(self, a, b):\n pass",
"class cls:\n\n def f(self, a, b=2):\n pass",
"class cls:\n\n def f(self, a=5, b=2):\n pass",
"class cls:\n\n def f(self, *, a=1, b=2):\n pass",
"class cls:\n\n def f(self, *, a=1, b):\n pass",
"class cls:\n\n def f(self, *, a, b=2):\n pass",
"class cls:\n\n def f(self, a, b=None, *, c, **kwds):\n pass",
"class cls:\n\n def f(self, a=2, *args, c=5, d, **kwds):\n pass",
"class cls:\n\n def f(self, *args, **kwargs):\n pass",
):
self.check_src_roundtrip(statement, mode='single')
def test_unparse_interactive_integrity_3(self):
for statement in (
"def x():",
"def x(y):",
"async def x():",
"async def x(y):",
"for x in y:",
"async for x in y:",
"with x():",
"async with x():",
"def f():",
"def f(a):",
"def f(b=2):",
"def f(a, b):",
"def f(a, b=2):",
"def f(a=5, b=2):",
"def f(*, a=1, b=2):",
"def f(*, a=1, b):",
"def f(*, a, b=2):",
"def f(a, b=None, *, c, **kwds):",
"def f(a=2, *args, c=5, d, **kwds):",
"def f(*args, **kwargs):",
):
src = statement + '\n i=1;j=2'
out = statement + '\n i = 1\n j = 2'
self.check_src_roundtrip(src, out, mode='single')
class CosmeticTestCase(ASTTestCase):
"""Test if there are cosmetic issues caused by unnecessary additions"""