mirror of
https://github.com/python/cpython.git
synced 2025-08-04 00:48:58 +00:00
gh-100518: Add tests for ast.NodeTransformer
(#100521)
This commit is contained in:
parent
f63f525e16
commit
c1c5882359
3 changed files with 171 additions and 42 deletions
43
Lib/test/support/ast_helper.py
Normal file
43
Lib/test/support/ast_helper.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
import ast
|
||||
|
||||
class ASTTestMixin:
|
||||
"""Test mixing to have basic assertions for AST nodes."""
|
||||
|
||||
def assertASTEqual(self, ast1, ast2):
|
||||
# Ensure the comparisons start at an AST node
|
||||
self.assertIsInstance(ast1, ast.AST)
|
||||
self.assertIsInstance(ast2, ast.AST)
|
||||
|
||||
# An AST comparison routine modeled after ast.dump(), but
|
||||
# instead of string building, it traverses the two trees
|
||||
# in lock-step.
|
||||
def traverse_compare(a, b, missing=object()):
|
||||
if type(a) is not type(b):
|
||||
self.fail(f"{type(a)!r} is not {type(b)!r}")
|
||||
if isinstance(a, ast.AST):
|
||||
for field in a._fields:
|
||||
value1 = getattr(a, field, missing)
|
||||
value2 = getattr(b, field, missing)
|
||||
# Singletons are equal by definition, so further
|
||||
# testing can be skipped.
|
||||
if value1 is not value2:
|
||||
traverse_compare(value1, value2)
|
||||
elif isinstance(a, list):
|
||||
try:
|
||||
for node1, node2 in zip(a, b, strict=True):
|
||||
traverse_compare(node1, node2)
|
||||
except ValueError:
|
||||
# Attempt a "pretty" error ala assertSequenceEqual()
|
||||
len1 = len(a)
|
||||
len2 = len(b)
|
||||
if len1 > len2:
|
||||
what = "First"
|
||||
diff = len1 - len2
|
||||
else:
|
||||
what = "Second"
|
||||
diff = len2 - len1
|
||||
msg = f"{what} list contains {diff} additional elements."
|
||||
raise self.failureException(msg) from None
|
||||
elif a != b:
|
||||
self.fail(f"{a!r} != {b!r}")
|
||||
traverse_compare(ast1, ast2)
|
|
@ -11,6 +11,7 @@ import weakref
|
|||
from textwrap import dedent
|
||||
|
||||
from test import support
|
||||
from test.support.ast_helper import ASTTestMixin
|
||||
|
||||
def to_tuple(t):
|
||||
if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis:
|
||||
|
@ -2290,9 +2291,10 @@ class EndPositionTests(unittest.TestCase):
|
|||
self.assertIsNone(ast.get_source_segment(s, x))
|
||||
self.assertIsNone(ast.get_source_segment(s, y))
|
||||
|
||||
class NodeVisitorTests(unittest.TestCase):
|
||||
class BaseNodeVisitorCases:
|
||||
# Both `NodeVisitor` and `NodeTranformer` must raise these warnings:
|
||||
def test_old_constant_nodes(self):
|
||||
class Visitor(ast.NodeVisitor):
|
||||
class Visitor(self.visitor_class):
|
||||
def visit_Num(self, node):
|
||||
log.append((node.lineno, 'Num', node.n))
|
||||
def visit_Str(self, node):
|
||||
|
@ -2340,6 +2342,128 @@ class NodeVisitorTests(unittest.TestCase):
|
|||
])
|
||||
|
||||
|
||||
class NodeVisitorTests(BaseNodeVisitorCases, unittest.TestCase):
|
||||
visitor_class = ast.NodeVisitor
|
||||
|
||||
|
||||
class NodeTransformerTests(ASTTestMixin, BaseNodeVisitorCases, unittest.TestCase):
|
||||
visitor_class = ast.NodeTransformer
|
||||
|
||||
def assertASTTransformation(self, tranformer_class,
|
||||
initial_code, expected_code):
|
||||
initial_ast = ast.parse(dedent(initial_code))
|
||||
expected_ast = ast.parse(dedent(expected_code))
|
||||
|
||||
tranformer = tranformer_class()
|
||||
result_ast = ast.fix_missing_locations(tranformer.visit(initial_ast))
|
||||
|
||||
self.assertASTEqual(result_ast, expected_ast)
|
||||
|
||||
def test_node_remove_single(self):
|
||||
code = 'def func(arg) -> SomeType: ...'
|
||||
expected = 'def func(arg): ...'
|
||||
|
||||
# Since `FunctionDef.returns` is defined as a single value, we test
|
||||
# the `if isinstance(old_value, AST):` branch here.
|
||||
class SomeTypeRemover(ast.NodeTransformer):
|
||||
def visit_Name(self, node: ast.Name):
|
||||
self.generic_visit(node)
|
||||
if node.id == 'SomeType':
|
||||
return None
|
||||
return node
|
||||
|
||||
self.assertASTTransformation(SomeTypeRemover, code, expected)
|
||||
|
||||
def test_node_remove_from_list(self):
|
||||
code = """
|
||||
def func(arg):
|
||||
print(arg)
|
||||
yield arg
|
||||
"""
|
||||
expected = """
|
||||
def func(arg):
|
||||
print(arg)
|
||||
"""
|
||||
|
||||
# Since `FunctionDef.body` is defined as a list, we test
|
||||
# the `if isinstance(old_value, list):` branch here.
|
||||
class YieldRemover(ast.NodeTransformer):
|
||||
def visit_Expr(self, node: ast.Expr):
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.value, ast.Yield):
|
||||
return None # Remove `yield` from a function
|
||||
return node
|
||||
|
||||
self.assertASTTransformation(YieldRemover, code, expected)
|
||||
|
||||
def test_node_return_list(self):
|
||||
code = """
|
||||
class DSL(Base, kw1=True): ...
|
||||
"""
|
||||
expected = """
|
||||
class DSL(Base, kw1=True, kw2=True, kw3=False): ...
|
||||
"""
|
||||
|
||||
class ExtendKeywords(ast.NodeTransformer):
|
||||
def visit_keyword(self, node: ast.keyword):
|
||||
self.generic_visit(node)
|
||||
if node.arg == 'kw1':
|
||||
return [
|
||||
node,
|
||||
ast.keyword('kw2', ast.Constant(True)),
|
||||
ast.keyword('kw3', ast.Constant(False)),
|
||||
]
|
||||
return node
|
||||
|
||||
self.assertASTTransformation(ExtendKeywords, code, expected)
|
||||
|
||||
def test_node_mutate(self):
|
||||
code = """
|
||||
def func(arg):
|
||||
print(arg)
|
||||
"""
|
||||
expected = """
|
||||
def func(arg):
|
||||
log(arg)
|
||||
"""
|
||||
|
||||
class PrintToLog(ast.NodeTransformer):
|
||||
def visit_Call(self, node: ast.Call):
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.func, ast.Name) and node.func.id == 'print':
|
||||
node.func.id = 'log'
|
||||
return node
|
||||
|
||||
self.assertASTTransformation(PrintToLog, code, expected)
|
||||
|
||||
def test_node_replace(self):
|
||||
code = """
|
||||
def func(arg):
|
||||
print(arg)
|
||||
"""
|
||||
expected = """
|
||||
def func(arg):
|
||||
logger.log(arg, debug=True)
|
||||
"""
|
||||
|
||||
class PrintToLog(ast.NodeTransformer):
|
||||
def visit_Call(self, node: ast.Call):
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.func, ast.Name) and node.func.id == 'print':
|
||||
return ast.Call(
|
||||
func=ast.Attribute(
|
||||
ast.Name('logger', ctx=ast.Load()),
|
||||
attr='log',
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
args=node.args,
|
||||
keywords=[ast.keyword('debug', ast.Constant(True))],
|
||||
)
|
||||
return node
|
||||
|
||||
self.assertASTTransformation(PrintToLog, code, expected)
|
||||
|
||||
|
||||
@support.cpython_only
|
||||
class ModuleStateTests(unittest.TestCase):
|
||||
# bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.
|
||||
|
|
|
@ -6,6 +6,7 @@ import pathlib
|
|||
import random
|
||||
import tokenize
|
||||
import ast
|
||||
from test.support.ast_helper import ASTTestMixin
|
||||
|
||||
|
||||
def read_pyfile(filename):
|
||||
|
@ -128,46 +129,7 @@ docstring_prefixes = (
|
|||
"async def foo():\n ",
|
||||
)
|
||||
|
||||
class ASTTestCase(unittest.TestCase):
|
||||
def assertASTEqual(self, ast1, ast2):
|
||||
# Ensure the comparisons start at an AST node
|
||||
self.assertIsInstance(ast1, ast.AST)
|
||||
self.assertIsInstance(ast2, ast.AST)
|
||||
|
||||
# An AST comparison routine modeled after ast.dump(), but
|
||||
# instead of string building, it traverses the two trees
|
||||
# in lock-step.
|
||||
def traverse_compare(a, b, missing=object()):
|
||||
if type(a) is not type(b):
|
||||
self.fail(f"{type(a)!r} is not {type(b)!r}")
|
||||
if isinstance(a, ast.AST):
|
||||
for field in a._fields:
|
||||
value1 = getattr(a, field, missing)
|
||||
value2 = getattr(b, field, missing)
|
||||
# Singletons are equal by definition, so further
|
||||
# testing can be skipped.
|
||||
if value1 is not value2:
|
||||
traverse_compare(value1, value2)
|
||||
elif isinstance(a, list):
|
||||
try:
|
||||
for node1, node2 in zip(a, b, strict=True):
|
||||
traverse_compare(node1, node2)
|
||||
except ValueError:
|
||||
# Attempt a "pretty" error ala assertSequenceEqual()
|
||||
len1 = len(a)
|
||||
len2 = len(b)
|
||||
if len1 > len2:
|
||||
what = "First"
|
||||
diff = len1 - len2
|
||||
else:
|
||||
what = "Second"
|
||||
diff = len2 - len1
|
||||
msg = f"{what} list contains {diff} additional elements."
|
||||
raise self.failureException(msg) from None
|
||||
elif a != b:
|
||||
self.fail(f"{a!r} != {b!r}")
|
||||
traverse_compare(ast1, ast2)
|
||||
|
||||
class ASTTestCase(ASTTestMixin, unittest.TestCase):
|
||||
def check_ast_roundtrip(self, code1, **kwargs):
|
||||
with self.subTest(code1=code1, ast_parse_kwargs=kwargs):
|
||||
ast1 = ast.parse(code1, **kwargs)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue