gh-100518: Add tests for ast.NodeTransformer (#100521)

This commit is contained in:
Nikita Sobolev 2023-01-22 00:44:41 +03:00 committed by GitHub
parent f63f525e16
commit c1c5882359
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 171 additions and 42 deletions

View file

@ -11,6 +11,7 @@ import weakref
from textwrap import dedent
from test import support
from test.support.ast_helper import ASTTestMixin
def to_tuple(t):
if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis:
@ -2290,9 +2291,10 @@ class EndPositionTests(unittest.TestCase):
self.assertIsNone(ast.get_source_segment(s, x))
self.assertIsNone(ast.get_source_segment(s, y))
class NodeVisitorTests(unittest.TestCase):
class BaseNodeVisitorCases:
# Both `NodeVisitor` and `NodeTranformer` must raise these warnings:
def test_old_constant_nodes(self):
class Visitor(ast.NodeVisitor):
class Visitor(self.visitor_class):
def visit_Num(self, node):
log.append((node.lineno, 'Num', node.n))
def visit_Str(self, node):
@ -2340,6 +2342,128 @@ class NodeVisitorTests(unittest.TestCase):
])
class NodeVisitorTests(BaseNodeVisitorCases, unittest.TestCase):
visitor_class = ast.NodeVisitor
class NodeTransformerTests(ASTTestMixin, BaseNodeVisitorCases, unittest.TestCase):
visitor_class = ast.NodeTransformer
def assertASTTransformation(self, tranformer_class,
initial_code, expected_code):
initial_ast = ast.parse(dedent(initial_code))
expected_ast = ast.parse(dedent(expected_code))
tranformer = tranformer_class()
result_ast = ast.fix_missing_locations(tranformer.visit(initial_ast))
self.assertASTEqual(result_ast, expected_ast)
def test_node_remove_single(self):
code = 'def func(arg) -> SomeType: ...'
expected = 'def func(arg): ...'
# Since `FunctionDef.returns` is defined as a single value, we test
# the `if isinstance(old_value, AST):` branch here.
class SomeTypeRemover(ast.NodeTransformer):
def visit_Name(self, node: ast.Name):
self.generic_visit(node)
if node.id == 'SomeType':
return None
return node
self.assertASTTransformation(SomeTypeRemover, code, expected)
def test_node_remove_from_list(self):
code = """
def func(arg):
print(arg)
yield arg
"""
expected = """
def func(arg):
print(arg)
"""
# Since `FunctionDef.body` is defined as a list, we test
# the `if isinstance(old_value, list):` branch here.
class YieldRemover(ast.NodeTransformer):
def visit_Expr(self, node: ast.Expr):
self.generic_visit(node)
if isinstance(node.value, ast.Yield):
return None # Remove `yield` from a function
return node
self.assertASTTransformation(YieldRemover, code, expected)
def test_node_return_list(self):
code = """
class DSL(Base, kw1=True): ...
"""
expected = """
class DSL(Base, kw1=True, kw2=True, kw3=False): ...
"""
class ExtendKeywords(ast.NodeTransformer):
def visit_keyword(self, node: ast.keyword):
self.generic_visit(node)
if node.arg == 'kw1':
return [
node,
ast.keyword('kw2', ast.Constant(True)),
ast.keyword('kw3', ast.Constant(False)),
]
return node
self.assertASTTransformation(ExtendKeywords, code, expected)
def test_node_mutate(self):
code = """
def func(arg):
print(arg)
"""
expected = """
def func(arg):
log(arg)
"""
class PrintToLog(ast.NodeTransformer):
def visit_Call(self, node: ast.Call):
self.generic_visit(node)
if isinstance(node.func, ast.Name) and node.func.id == 'print':
node.func.id = 'log'
return node
self.assertASTTransformation(PrintToLog, code, expected)
def test_node_replace(self):
code = """
def func(arg):
print(arg)
"""
expected = """
def func(arg):
logger.log(arg, debug=True)
"""
class PrintToLog(ast.NodeTransformer):
def visit_Call(self, node: ast.Call):
self.generic_visit(node)
if isinstance(node.func, ast.Name) and node.func.id == 'print':
return ast.Call(
func=ast.Attribute(
ast.Name('logger', ctx=ast.Load()),
attr='log',
ctx=ast.Load(),
),
args=node.args,
keywords=[ast.keyword('debug', ast.Constant(True))],
)
return node
self.assertASTTransformation(PrintToLog, code, expected)
@support.cpython_only
class ModuleStateTests(unittest.TestCase):
# bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.