Thanks to visit codestin.com
Credit goes to github.com

Skip to content

gh-100518: Add tests for ast.NodeTransformer #100521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions Lib/test/support/ast_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import ast

class ASTTestMixin:
"""Test mixing to have basic assertions for AST nodes."""

def assertASTEqual(self, ast1, ast2):
# Ensure the comparisons start at an AST node
self.assertIsInstance(ast1, ast.AST)
self.assertIsInstance(ast2, ast.AST)

# An AST comparison routine modeled after ast.dump(), but
# instead of string building, it traverses the two trees
# in lock-step.
def traverse_compare(a, b, missing=object()):
if type(a) is not type(b):
self.fail(f"{type(a)!r} is not {type(b)!r}")
if isinstance(a, ast.AST):
for field in a._fields:
value1 = getattr(a, field, missing)
value2 = getattr(b, field, missing)
# Singletons are equal by definition, so further
# testing can be skipped.
if value1 is not value2:
traverse_compare(value1, value2)
elif isinstance(a, list):
try:
for node1, node2 in zip(a, b, strict=True):
traverse_compare(node1, node2)
except ValueError:
# Attempt a "pretty" error ala assertSequenceEqual()
len1 = len(a)
len2 = len(b)
if len1 > len2:
what = "First"
diff = len1 - len2
else:
what = "Second"
diff = len2 - len1
msg = f"{what} list contains {diff} additional elements."
raise self.failureException(msg) from None
elif a != b:
self.fail(f"{a!r} != {b!r}")
traverse_compare(ast1, ast2)
128 changes: 126 additions & 2 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from textwrap import dedent

from test import support
from test.support.ast_helper import ASTTestMixin

def to_tuple(t):
if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis:
Expand Down Expand Up @@ -2297,9 +2298,10 @@ def test_source_segment_missing_info(self):
self.assertIsNone(ast.get_source_segment(s, x))
self.assertIsNone(ast.get_source_segment(s, y))

class NodeVisitorTests(unittest.TestCase):
class BaseNodeVisitorCases:
# Both `NodeVisitor` and `NodeTranformer` must raise these warnings:
def test_old_constant_nodes(self):
class Visitor(ast.NodeVisitor):
class Visitor(self.visitor_class):
def visit_Num(self, node):
log.append((node.lineno, 'Num', node.n))
def visit_Str(self, node):
Expand Down Expand Up @@ -2347,6 +2349,128 @@ def visit_Ellipsis(self, node):
])


class NodeVisitorTests(BaseNodeVisitorCases, unittest.TestCase):
visitor_class = ast.NodeVisitor


class NodeTransformerTests(ASTTestMixin, BaseNodeVisitorCases, unittest.TestCase):
visitor_class = ast.NodeTransformer

def assertASTTransformation(self, tranformer_class,
initial_code, expected_code):
initial_ast = ast.parse(dedent(initial_code))
expected_ast = ast.parse(dedent(expected_code))

tranformer = tranformer_class()
result_ast = ast.fix_missing_locations(tranformer.visit(initial_ast))

self.assertASTEqual(result_ast, expected_ast)

def test_node_remove_single(self):
code = 'def func(arg) -> SomeType: ...'
expected = 'def func(arg): ...'

# Since `FunctionDef.returns` is defined as a single value, we test
# the `if isinstance(old_value, AST):` branch here.
class SomeTypeRemover(ast.NodeTransformer):
def visit_Name(self, node: ast.Name):
self.generic_visit(node)
if node.id == 'SomeType':
return None
return node

self.assertASTTransformation(SomeTypeRemover, code, expected)

def test_node_remove_from_list(self):
code = """
def func(arg):
print(arg)
yield arg
"""
expected = """
def func(arg):
print(arg)
"""

# Since `FunctionDef.body` is defined as a list, we test
# the `if isinstance(old_value, list):` branch here.
class YieldRemover(ast.NodeTransformer):
def visit_Expr(self, node: ast.Expr):
self.generic_visit(node)
if isinstance(node.value, ast.Yield):
return None # Remove `yield` from a function
return node

self.assertASTTransformation(YieldRemover, code, expected)

def test_node_return_list(self):
code = """
class DSL(Base, kw1=True): ...
"""
expected = """
class DSL(Base, kw1=True, kw2=True, kw3=False): ...
"""

class ExtendKeywords(ast.NodeTransformer):
def visit_keyword(self, node: ast.keyword):
self.generic_visit(node)
if node.arg == 'kw1':
return [
node,
ast.keyword('kw2', ast.Constant(True)),
ast.keyword('kw3', ast.Constant(False)),
]
return node

self.assertASTTransformation(ExtendKeywords, code, expected)

def test_node_mutate(self):
code = """
def func(arg):
print(arg)
"""
expected = """
def func(arg):
log(arg)
"""

class PrintToLog(ast.NodeTransformer):
def visit_Call(self, node: ast.Call):
self.generic_visit(node)
if isinstance(node.func, ast.Name) and node.func.id == 'print':
node.func.id = 'log'
return node

self.assertASTTransformation(PrintToLog, code, expected)

def test_node_replace(self):
code = """
def func(arg):
print(arg)
"""
expected = """
def func(arg):
logger.log(arg, debug=True)
"""

class PrintToLog(ast.NodeTransformer):
def visit_Call(self, node: ast.Call):
self.generic_visit(node)
if isinstance(node.func, ast.Name) and node.func.id == 'print':
return ast.Call(
func=ast.Attribute(
ast.Name('logger', ctx=ast.Load()),
attr='log',
ctx=ast.Load(),
),
args=node.args,
keywords=[ast.keyword('debug', ast.Constant(True))],
)
return node

self.assertASTTransformation(PrintToLog, code, expected)


@support.cpython_only
class ModuleStateTests(unittest.TestCase):
# bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.
Expand Down
42 changes: 2 additions & 40 deletions Lib/test/test_unparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
import tokenize
import ast
from test.support.ast_helper import ASTTestMixin


def read_pyfile(filename):
Expand Down Expand Up @@ -128,46 +129,7 @@ class Foo: pass
"async def foo():\n ",
)

class ASTTestCase(unittest.TestCase):
def assertASTEqual(self, ast1, ast2):
# Ensure the comparisons start at an AST node
self.assertIsInstance(ast1, ast.AST)
self.assertIsInstance(ast2, ast.AST)

# An AST comparison routine modeled after ast.dump(), but
# instead of string building, it traverses the two trees
# in lock-step.
def traverse_compare(a, b, missing=object()):
if type(a) is not type(b):
self.fail(f"{type(a)!r} is not {type(b)!r}")
if isinstance(a, ast.AST):
for field in a._fields:
value1 = getattr(a, field, missing)
value2 = getattr(b, field, missing)
# Singletons are equal by definition, so further
# testing can be skipped.
if value1 is not value2:
traverse_compare(value1, value2)
elif isinstance(a, list):
try:
for node1, node2 in zip(a, b, strict=True):
traverse_compare(node1, node2)
except ValueError:
# Attempt a "pretty" error ala assertSequenceEqual()
len1 = len(a)
len2 = len(b)
if len1 > len2:
what = "First"
diff = len1 - len2
else:
what = "Second"
diff = len2 - len1
msg = f"{what} list contains {diff} additional elements."
raise self.failureException(msg) from None
elif a != b:
self.fail(f"{a!r} != {b!r}")
traverse_compare(ast1, ast2)

class ASTTestCase(ASTTestMixin, unittest.TestCase):
def check_ast_roundtrip(self, code1, **kwargs):
with self.subTest(code1=code1, ast_parse_kwargs=kwargs):
ast1 = ast.parse(code1, **kwargs)
Expand Down