|
11 | 11 | from textwrap import dedent
|
12 | 12 |
|
13 | 13 | from test import support
|
| 14 | +from test.support.ast_helper import ASTTestMixin |
14 | 15 |
|
15 | 16 | def to_tuple(t):
|
16 | 17 | if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis:
|
@@ -2290,9 +2291,10 @@ def test_source_segment_missing_info(self):
|
2290 | 2291 | self.assertIsNone(ast.get_source_segment(s, x))
|
2291 | 2292 | self.assertIsNone(ast.get_source_segment(s, y))
|
2292 | 2293 |
|
2293 |
| -class NodeVisitorTests(unittest.TestCase): |
| 2294 | +class BaseNodeVisitorCases: |
| 2295 | + # Both `NodeVisitor` and `NodeTranformer` must raise these warnings: |
2294 | 2296 | def test_old_constant_nodes(self):
|
2295 |
| - class Visitor(ast.NodeVisitor): |
| 2297 | + class Visitor(self.visitor_class): |
2296 | 2298 | def visit_Num(self, node):
|
2297 | 2299 | log.append((node.lineno, 'Num', node.n))
|
2298 | 2300 | def visit_Str(self, node):
|
@@ -2340,6 +2342,128 @@ def visit_Ellipsis(self, node):
|
2340 | 2342 | ])
|
2341 | 2343 |
|
2342 | 2344 |
|
| 2345 | +class NodeVisitorTests(BaseNodeVisitorCases, unittest.TestCase): |
| 2346 | + visitor_class = ast.NodeVisitor |
| 2347 | + |
| 2348 | + |
| 2349 | +class NodeTransformerTests(ASTTestMixin, BaseNodeVisitorCases, unittest.TestCase): |
| 2350 | + visitor_class = ast.NodeTransformer |
| 2351 | + |
| 2352 | + def assertASTTransformation(self, tranformer_class, |
| 2353 | + initial_code, expected_code): |
| 2354 | + initial_ast = ast.parse(dedent(initial_code)) |
| 2355 | + expected_ast = ast.parse(dedent(expected_code)) |
| 2356 | + |
| 2357 | + tranformer = tranformer_class() |
| 2358 | + result_ast = ast.fix_missing_locations(tranformer.visit(initial_ast)) |
| 2359 | + |
| 2360 | + self.assertASTEqual(result_ast, expected_ast) |
| 2361 | + |
| 2362 | + def test_node_remove_single(self): |
| 2363 | + code = 'def func(arg) -> SomeType: ...' |
| 2364 | + expected = 'def func(arg): ...' |
| 2365 | + |
| 2366 | + # Since `FunctionDef.returns` is defined as a single value, we test |
| 2367 | + # the `if isinstance(old_value, AST):` branch here. |
| 2368 | + class SomeTypeRemover(ast.NodeTransformer): |
| 2369 | + def visit_Name(self, node: ast.Name): |
| 2370 | + self.generic_visit(node) |
| 2371 | + if node.id == 'SomeType': |
| 2372 | + return None |
| 2373 | + return node |
| 2374 | + |
| 2375 | + self.assertASTTransformation(SomeTypeRemover, code, expected) |
| 2376 | + |
| 2377 | + def test_node_remove_from_list(self): |
| 2378 | + code = """ |
| 2379 | + def func(arg): |
| 2380 | + print(arg) |
| 2381 | + yield arg |
| 2382 | + """ |
| 2383 | + expected = """ |
| 2384 | + def func(arg): |
| 2385 | + print(arg) |
| 2386 | + """ |
| 2387 | + |
| 2388 | + # Since `FunctionDef.body` is defined as a list, we test |
| 2389 | + # the `if isinstance(old_value, list):` branch here. |
| 2390 | + class YieldRemover(ast.NodeTransformer): |
| 2391 | + def visit_Expr(self, node: ast.Expr): |
| 2392 | + self.generic_visit(node) |
| 2393 | + if isinstance(node.value, ast.Yield): |
| 2394 | + return None # Remove `yield` from a function |
| 2395 | + return node |
| 2396 | + |
| 2397 | + self.assertASTTransformation(YieldRemover, code, expected) |
| 2398 | + |
| 2399 | + def test_node_return_list(self): |
| 2400 | + code = """ |
| 2401 | + class DSL(Base, kw1=True): ... |
| 2402 | + """ |
| 2403 | + expected = """ |
| 2404 | + class DSL(Base, kw1=True, kw2=True, kw3=False): ... |
| 2405 | + """ |
| 2406 | + |
| 2407 | + class ExtendKeywords(ast.NodeTransformer): |
| 2408 | + def visit_keyword(self, node: ast.keyword): |
| 2409 | + self.generic_visit(node) |
| 2410 | + if node.arg == 'kw1': |
| 2411 | + return [ |
| 2412 | + node, |
| 2413 | + ast.keyword('kw2', ast.Constant(True)), |
| 2414 | + ast.keyword('kw3', ast.Constant(False)), |
| 2415 | + ] |
| 2416 | + return node |
| 2417 | + |
| 2418 | + self.assertASTTransformation(ExtendKeywords, code, expected) |
| 2419 | + |
| 2420 | + def test_node_mutate(self): |
| 2421 | + code = """ |
| 2422 | + def func(arg): |
| 2423 | + print(arg) |
| 2424 | + """ |
| 2425 | + expected = """ |
| 2426 | + def func(arg): |
| 2427 | + log(arg) |
| 2428 | + """ |
| 2429 | + |
| 2430 | + class PrintToLog(ast.NodeTransformer): |
| 2431 | + def visit_Call(self, node: ast.Call): |
| 2432 | + self.generic_visit(node) |
| 2433 | + if isinstance(node.func, ast.Name) and node.func.id == 'print': |
| 2434 | + node.func.id = 'log' |
| 2435 | + return node |
| 2436 | + |
| 2437 | + self.assertASTTransformation(PrintToLog, code, expected) |
| 2438 | + |
| 2439 | + def test_node_replace(self): |
| 2440 | + code = """ |
| 2441 | + def func(arg): |
| 2442 | + print(arg) |
| 2443 | + """ |
| 2444 | + expected = """ |
| 2445 | + def func(arg): |
| 2446 | + logger.log(arg, debug=True) |
| 2447 | + """ |
| 2448 | + |
| 2449 | + class PrintToLog(ast.NodeTransformer): |
| 2450 | + def visit_Call(self, node: ast.Call): |
| 2451 | + self.generic_visit(node) |
| 2452 | + if isinstance(node.func, ast.Name) and node.func.id == 'print': |
| 2453 | + return ast.Call( |
| 2454 | + func=ast.Attribute( |
| 2455 | + ast.Name('logger', ctx=ast.Load()), |
| 2456 | + attr='log', |
| 2457 | + ctx=ast.Load(), |
| 2458 | + ), |
| 2459 | + args=node.args, |
| 2460 | + keywords=[ast.keyword('debug', ast.Constant(True))], |
| 2461 | + ) |
| 2462 | + return node |
| 2463 | + |
| 2464 | + self.assertASTTransformation(PrintToLog, code, expected) |
| 2465 | + |
| 2466 | + |
2343 | 2467 | @support.cpython_only
|
2344 | 2468 | class ModuleStateTests(unittest.TestCase):
|
2345 | 2469 | # bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.
|
|
0 commit comments