Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c1c5882

Browse files
authored
gh-100518: Add tests for ast.NodeTransformer (#100521)
1 parent f63f525 commit c1c5882

File tree

3 files changed

+171
-42
lines changed

3 files changed

+171
-42
lines changed

Lib/test/support/ast_helper.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import ast
2+
3+
class ASTTestMixin:
4+
"""Test mixing to have basic assertions for AST nodes."""
5+
6+
def assertASTEqual(self, ast1, ast2):
7+
# Ensure the comparisons start at an AST node
8+
self.assertIsInstance(ast1, ast.AST)
9+
self.assertIsInstance(ast2, ast.AST)
10+
11+
# An AST comparison routine modeled after ast.dump(), but
12+
# instead of string building, it traverses the two trees
13+
# in lock-step.
14+
def traverse_compare(a, b, missing=object()):
15+
if type(a) is not type(b):
16+
self.fail(f"{type(a)!r} is not {type(b)!r}")
17+
if isinstance(a, ast.AST):
18+
for field in a._fields:
19+
value1 = getattr(a, field, missing)
20+
value2 = getattr(b, field, missing)
21+
# Singletons are equal by definition, so further
22+
# testing can be skipped.
23+
if value1 is not value2:
24+
traverse_compare(value1, value2)
25+
elif isinstance(a, list):
26+
try:
27+
for node1, node2 in zip(a, b, strict=True):
28+
traverse_compare(node1, node2)
29+
except ValueError:
30+
# Attempt a "pretty" error ala assertSequenceEqual()
31+
len1 = len(a)
32+
len2 = len(b)
33+
if len1 > len2:
34+
what = "First"
35+
diff = len1 - len2
36+
else:
37+
what = "Second"
38+
diff = len2 - len1
39+
msg = f"{what} list contains {diff} additional elements."
40+
raise self.failureException(msg) from None
41+
elif a != b:
42+
self.fail(f"{a!r} != {b!r}")
43+
traverse_compare(ast1, ast2)

Lib/test/test_ast.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from textwrap import dedent
1212

1313
from test import support
14+
from test.support.ast_helper import ASTTestMixin
1415

1516
def to_tuple(t):
1617
if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis:
@@ -2290,9 +2291,10 @@ def test_source_segment_missing_info(self):
22902291
self.assertIsNone(ast.get_source_segment(s, x))
22912292
self.assertIsNone(ast.get_source_segment(s, y))
22922293

2293-
class NodeVisitorTests(unittest.TestCase):
2294+
class BaseNodeVisitorCases:
2295+
# Both `NodeVisitor` and `NodeTranformer` must raise these warnings:
22942296
def test_old_constant_nodes(self):
2295-
class Visitor(ast.NodeVisitor):
2297+
class Visitor(self.visitor_class):
22962298
def visit_Num(self, node):
22972299
log.append((node.lineno, 'Num', node.n))
22982300
def visit_Str(self, node):
@@ -2340,6 +2342,128 @@ def visit_Ellipsis(self, node):
23402342
])
23412343

23422344

2345+
class NodeVisitorTests(BaseNodeVisitorCases, unittest.TestCase):
2346+
visitor_class = ast.NodeVisitor
2347+
2348+
2349+
class NodeTransformerTests(ASTTestMixin, BaseNodeVisitorCases, unittest.TestCase):
2350+
visitor_class = ast.NodeTransformer
2351+
2352+
def assertASTTransformation(self, tranformer_class,
2353+
initial_code, expected_code):
2354+
initial_ast = ast.parse(dedent(initial_code))
2355+
expected_ast = ast.parse(dedent(expected_code))
2356+
2357+
tranformer = tranformer_class()
2358+
result_ast = ast.fix_missing_locations(tranformer.visit(initial_ast))
2359+
2360+
self.assertASTEqual(result_ast, expected_ast)
2361+
2362+
def test_node_remove_single(self):
2363+
code = 'def func(arg) -> SomeType: ...'
2364+
expected = 'def func(arg): ...'
2365+
2366+
# Since `FunctionDef.returns` is defined as a single value, we test
2367+
# the `if isinstance(old_value, AST):` branch here.
2368+
class SomeTypeRemover(ast.NodeTransformer):
2369+
def visit_Name(self, node: ast.Name):
2370+
self.generic_visit(node)
2371+
if node.id == 'SomeType':
2372+
return None
2373+
return node
2374+
2375+
self.assertASTTransformation(SomeTypeRemover, code, expected)
2376+
2377+
def test_node_remove_from_list(self):
2378+
code = """
2379+
def func(arg):
2380+
print(arg)
2381+
yield arg
2382+
"""
2383+
expected = """
2384+
def func(arg):
2385+
print(arg)
2386+
"""
2387+
2388+
# Since `FunctionDef.body` is defined as a list, we test
2389+
# the `if isinstance(old_value, list):` branch here.
2390+
class YieldRemover(ast.NodeTransformer):
2391+
def visit_Expr(self, node: ast.Expr):
2392+
self.generic_visit(node)
2393+
if isinstance(node.value, ast.Yield):
2394+
return None # Remove `yield` from a function
2395+
return node
2396+
2397+
self.assertASTTransformation(YieldRemover, code, expected)
2398+
2399+
def test_node_return_list(self):
2400+
code = """
2401+
class DSL(Base, kw1=True): ...
2402+
"""
2403+
expected = """
2404+
class DSL(Base, kw1=True, kw2=True, kw3=False): ...
2405+
"""
2406+
2407+
class ExtendKeywords(ast.NodeTransformer):
2408+
def visit_keyword(self, node: ast.keyword):
2409+
self.generic_visit(node)
2410+
if node.arg == 'kw1':
2411+
return [
2412+
node,
2413+
ast.keyword('kw2', ast.Constant(True)),
2414+
ast.keyword('kw3', ast.Constant(False)),
2415+
]
2416+
return node
2417+
2418+
self.assertASTTransformation(ExtendKeywords, code, expected)
2419+
2420+
def test_node_mutate(self):
2421+
code = """
2422+
def func(arg):
2423+
print(arg)
2424+
"""
2425+
expected = """
2426+
def func(arg):
2427+
log(arg)
2428+
"""
2429+
2430+
class PrintToLog(ast.NodeTransformer):
2431+
def visit_Call(self, node: ast.Call):
2432+
self.generic_visit(node)
2433+
if isinstance(node.func, ast.Name) and node.func.id == 'print':
2434+
node.func.id = 'log'
2435+
return node
2436+
2437+
self.assertASTTransformation(PrintToLog, code, expected)
2438+
2439+
def test_node_replace(self):
2440+
code = """
2441+
def func(arg):
2442+
print(arg)
2443+
"""
2444+
expected = """
2445+
def func(arg):
2446+
logger.log(arg, debug=True)
2447+
"""
2448+
2449+
class PrintToLog(ast.NodeTransformer):
2450+
def visit_Call(self, node: ast.Call):
2451+
self.generic_visit(node)
2452+
if isinstance(node.func, ast.Name) and node.func.id == 'print':
2453+
return ast.Call(
2454+
func=ast.Attribute(
2455+
ast.Name('logger', ctx=ast.Load()),
2456+
attr='log',
2457+
ctx=ast.Load(),
2458+
),
2459+
args=node.args,
2460+
keywords=[ast.keyword('debug', ast.Constant(True))],
2461+
)
2462+
return node
2463+
2464+
self.assertASTTransformation(PrintToLog, code, expected)
2465+
2466+
23432467
@support.cpython_only
23442468
class ModuleStateTests(unittest.TestCase):
23452469
# bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.

Lib/test/test_unparse.py

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import random
77
import tokenize
88
import ast
9+
from test.support.ast_helper import ASTTestMixin
910

1011

1112
def read_pyfile(filename):
@@ -128,46 +129,7 @@ class Foo: pass
128129
"async def foo():\n ",
129130
)
130131

131-
class ASTTestCase(unittest.TestCase):
132-
def assertASTEqual(self, ast1, ast2):
133-
# Ensure the comparisons start at an AST node
134-
self.assertIsInstance(ast1, ast.AST)
135-
self.assertIsInstance(ast2, ast.AST)
136-
137-
# An AST comparison routine modeled after ast.dump(), but
138-
# instead of string building, it traverses the two trees
139-
# in lock-step.
140-
def traverse_compare(a, b, missing=object()):
141-
if type(a) is not type(b):
142-
self.fail(f"{type(a)!r} is not {type(b)!r}")
143-
if isinstance(a, ast.AST):
144-
for field in a._fields:
145-
value1 = getattr(a, field, missing)
146-
value2 = getattr(b, field, missing)
147-
# Singletons are equal by definition, so further
148-
# testing can be skipped.
149-
if value1 is not value2:
150-
traverse_compare(value1, value2)
151-
elif isinstance(a, list):
152-
try:
153-
for node1, node2 in zip(a, b, strict=True):
154-
traverse_compare(node1, node2)
155-
except ValueError:
156-
# Attempt a "pretty" error ala assertSequenceEqual()
157-
len1 = len(a)
158-
len2 = len(b)
159-
if len1 > len2:
160-
what = "First"
161-
diff = len1 - len2
162-
else:
163-
what = "Second"
164-
diff = len2 - len1
165-
msg = f"{what} list contains {diff} additional elements."
166-
raise self.failureException(msg) from None
167-
elif a != b:
168-
self.fail(f"{a!r} != {b!r}")
169-
traverse_compare(ast1, ast2)
170-
132+
class ASTTestCase(ASTTestMixin, unittest.TestCase):
171133
def check_ast_roundtrip(self, code1, **kwargs):
172134
with self.subTest(code1=code1, ast_parse_kwargs=kwargs):
173135
ast1 = ast.parse(code1, **kwargs)

0 commit comments

Comments
 (0)