diff --git a/refactor/actions.py b/refactor/actions.py index 13e8d76..173a0e4 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -17,7 +17,9 @@ __all__ = [ "BaseAction", "InsertAfter", + "InsertBefore", "LazyInsertAfter", + "LazyInsertBefore", "LazyReplace", "Replace", "Erase", @@ -195,11 +197,51 @@ def _stack_effect(self) -> tuple[ast.AST, int]: return (self.node, 1) +@dataclass +class LazyInsertBefore(_LazyActionMixin[ast.stmt, ast.stmt]): + """Inserts the re-synthesized version :py:meth:`LazyInsertBefore.build`'s + output right before the given `node`. + + .. note:: + Subclasses of :py:class:`LazyInsertBefore` must override + :py:meth:`LazyInsertBefore.build`. + + .. note:: + This action requires both the `node` and the built target to be statements. + """ + + def apply(self, context: Context, source: str) -> str: + lines = split_lines(source, encoding=context.file_info.get_encoding()) + indentation, start_prefix = find_indent( + lines[self.node.lineno - 1][: self.node.col_offset] + ) + + replacement = split_lines(context.unparse(self.build())) + replacement.apply_indentation(indentation, start_prefix=start_prefix) + replacement[-1] += lines._newline_type + + original_node_start = cast(int, self.node.lineno) + for line in reversed(replacement): + lines.insert(original_node_start - 1, line) + + return lines.join() + + def _stack_effect(self) -> tuple[ast.AST, int]: + # Adding a statement right before the node will need to be reflected + # in the block. + return (self.node, -1) + + @dataclass class NewStatementAction(LazyInsertAfter, _DeprecatedAliasMixin): ... +@dataclass +class NewStatementBeforeAction(LazyInsertBefore, _DeprecatedAliasMixin): + ... + + @_hint("deprecated_alias", "TargetedNewStatementAction") @dataclass class InsertAfter(LazyInsertAfter): @@ -216,11 +258,31 @@ def build(self) -> ast.stmt: return self.target +@dataclass +class InsertBefore(LazyInsertBefore): + """Inserts the re-synthesized version of given `target` right after + the given `node`. + + .. note:: + This action requires both the `node` and `target` to be a statements. + """ + + target: ast.stmt + + def build(self) -> ast.stmt: + return self.target + + @dataclass class TargetedNewStatementAction(InsertAfter, _DeprecatedAliasMixin): ... +@dataclass +class TargetedNewStatementBeforeAction(InsertBefore, _DeprecatedAliasMixin): + ... + + @dataclass class _Rename(Replace): identifier_span: PositionType @@ -272,7 +334,7 @@ def _resynthesize(self, context: Context) -> str: return "" def _stack_effect(self) -> tuple[ast.AST, int]: - # Erasing a single node mean positions of all the followinng statements will + # Erasing a single node mean positions of all the following statements will # need to reduced by 1. return (self.node, -1) diff --git a/refactor/internal/graph_access.py b/refactor/internal/graph_access.py index 0543fe3..fec7dba 100644 --- a/refactor/internal/graph_access.py +++ b/refactor/internal/graph_access.py @@ -135,14 +135,14 @@ def shift(self, shifts: list[tuple[GraphPath, int]]) -> GraphPath: assert isinstance(target_access, IndexAccess) # This change might affect the future nodes in this path - # but not us. - if shifter.index >= target_access.index: + # When the requested shift_offset is negative, it does affect us + # by the amount of offset + if shifter.index + shift_offset >= target_access.index: continue parts[parts.index(target_access)] = target_access.replace( - index=target_access.index + shift_offset + index=target_access.index + abs(shift_offset) ) - return GraphPath(parts) def execute(self, node: ast.AST) -> ast.AST: diff --git a/tests/test_actions.py b/tests/test_actions.py index 7d39a43..a8040be 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -1,11 +1,17 @@ from __future__ import annotations import ast +import textwrap +from pathlib import Path +from typing import Iterator, cast import pytest +from refactor.ast import DEFAULT_ENCODING -from refactor.actions import Erase, InvalidActionError +from refactor import Session, common +from refactor.actions import Erase, InvalidActionError, InsertAfter, Replace, InsertBefore from refactor.context import Context +from refactor.core import Rule INVALID_ERASES = """ def foo(): @@ -42,6 +48,433 @@ def foo(): INVALID_ERASES_TREE = ast.parse(INVALID_ERASES) +class TestInsertAfterBottom(Rule): + INPUT_SOURCE = """ + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + await async_test()""" + + def match(self, node: ast.AST) -> Iterator[InsertAfter]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + await_st = ast.parse("await async_test()") + yield InsertAfter(node, cast(ast.stmt, await_st)) + new_try = common.clone(node) + new_try.body = [node.body[0]] + yield Replace(node, cast(ast.AST, new_try)) + + +class TestInsertBeforeTop(Rule): + INPUT_SOURCE = """ + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + await async_test() + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + def match(self, node: ast.AST) -> Iterator[InsertBefore]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + await_st = ast.parse("await async_test()") + yield InsertBefore(node, cast(ast.stmt, await_st)) + new_try = common.clone(node) + new_try.body = [node.body[0]] + yield Replace(node, cast(ast.AST, new_try)) + + +class TestInsertAfter(Rule): + INPUT_SOURCE = """ + def generate_index(base_path, active_path): + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + + print('processing ', module_name) + try: + base_tree = get_tree(base_file, module_name) + active_tree = get_tree(active_file, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + def generate_index(base_path, active_path): + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + second_tree = get_tree(second_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + + print('processing ', module_name) + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + active_tree = get_tree(active_file, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + def match(self, node: ast.AST) -> Iterator[InsertAfter]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + new_trys = [] + for stmt in node.body: + new_try = common.clone(node) + new_try.body = [stmt] + new_trys.append(new_try) + + first_try, *remaining_trys = new_trys + yield Replace(node, first_try) + for remaining_try in reversed(remaining_trys): + yield InsertAfter(node, remaining_try) + + +class TestInsertBefore(Rule): + INPUT_SOURCE = """ + def generate_index(base_path, active_path): + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + + print('processing ', module_name) + try: + base_tree = get_tree(base_file, module_name) + active_tree = get_tree(active_file, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + def generate_index(base_path, active_path): + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + second_tree = get_tree(second_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + + print('processing ', module_name) + try: + active_tree = get_tree(active_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + def match(self, node: ast.AST) -> Iterator[InsertBefore]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + new_trys = [] + for stmt in node.body: + new_try = common.clone(node) + new_try.body = [stmt] + new_trys.append(new_try) + + first_try, *remaining_trys = new_trys + yield Replace(node, first_try) + for remaining_try in remaining_trys: + yield InsertBefore(node, remaining_try) + + +class TestInsertAfterThenBefore(Rule): + INPUT_SOURCE = """ + def generate_index(base_path, active_path): + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + + print('processing ', module_name) + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + def generate_index(base_path, active_path): + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + second_tree = get_tree(second_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + second_tree = get_tree(second_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + + print('processing ', module_name) + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter | InsertBefore]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + new_trys = [] + for stmt in node.body: + new_try = common.clone(node) + new_try.body = [stmt] + new_trys.append(new_try) + + first_try, *remaining_trys = new_trys + yield Replace(node, first_try) + for remaining_try in reversed(remaining_trys): + yield InsertAfter(node, remaining_try) + for remaining_try in remaining_trys: + yield InsertBefore(node, remaining_try) + + +class TestInsertBeforeThenAfterBothReversed(Rule): + INPUT_SOURCE = """ + def generate_index(base_path, active_path): + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + + print('processing ', module_name) + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + def generate_index(base_path, active_path): + try: + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + second_tree = get_tree(second_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + second_tree = get_tree(second_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + + print('processing ', module_name) + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter | InsertBefore]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + new_trys = [] + for stmt in node.body: + new_try = common.clone(node) + new_try.body = [stmt] + new_trys.append(new_try) + + first_try, *remaining_trys = new_trys + yield Replace(node, first_try) + for remaining_try in reversed(remaining_trys): + yield InsertBefore(node, remaining_try) + for remaining_try in reversed(remaining_trys): + yield InsertAfter(node, remaining_try) + + +class TestInsertAfterBeforeRepeat(Rule): + INPUT_SOURCE = """ + def generate_index(base_path, active_path): + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + second_tree = get_tree(second_tree, module_name) + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + + print('processing ', module_name) + try: + base_tree = get_tree(base_file, module_name) + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + EXPECTED_SOURCE = """ + def generate_index(base_path, active_path): + try: + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + second_tree = get_tree(second_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + second_tree = get_tree(second_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + third_tree = get_tree(third_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + + print('processing ', module_name) + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + base_tree = get_tree(base_file, module_name) + except (SyntaxError, FileNotFoundError): + continue + try: + first_tree = get_tree(first_tree, module_name) + except (SyntaxError, FileNotFoundError): + continue""" + + def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter | InsertBefore]: + assert isinstance(node, ast.Try) + assert len(node.body) >= 2 + + new_trys = [] + for stmt in node.body: + new_try = common.clone(node) + new_try.body = [stmt] + new_trys.append(new_try) + + first_try, *remaining_trys = new_trys + yield Replace(node, first_try) + # It is important to note that we reversed the changes + # This can have a counter-intuitive expectation of the results + # Possibly, a less confusing testcase could be implemented ;P + for remaining_try in reversed(remaining_trys[:3]): + yield InsertBefore(node, remaining_try) + yield InsertAfter(node, remaining_try) + + @pytest.mark.parametrize( "invalid_node", [node for node in ast.walk(INVALID_ERASES_TREE) if isinstance(node, ast.Assert)], @@ -50,3 +483,38 @@ def test_erase_invalid(invalid_node): context = Context(INVALID_ERASES, INVALID_ERASES_TREE) with pytest.raises(InvalidActionError): Erase(invalid_node).apply(context, INVALID_ERASES) + + +@pytest.mark.parametrize( + "rule", + [ + TestInsertAfterBottom, + TestInsertBeforeTop, + TestInsertAfter, + TestInsertBefore, + TestInsertAfterThenBefore, + TestInsertBeforeThenAfterBothReversed, + TestInsertAfterBeforeRepeat, + ], +) +def test_rules(rule, tmp_path): + session = Session([rule]) + + source_code = textwrap.dedent(rule.INPUT_SOURCE) + try: + ast.parse(source_code) + except SyntaxError: + pytest.fail("Input source is not valid Python code") + + assert session.run(source_code) == textwrap.dedent(rule.EXPECTED_SOURCE) + + src_file_path = Path(tmp_path / rule.__name__.lower()).with_suffix(".py") + src_file_path.write_text(source_code, encoding=DEFAULT_ENCODING) + + change = session.run_file(src_file_path) + assert change is not None + + change.apply_diff() + assert src_file_path.read_text(encoding=DEFAULT_ENCODING) == textwrap.dedent( + rule.EXPECTED_SOURCE + ) diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index b46c14e..dfd4984 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -14,6 +14,8 @@ from refactor.actions import ( Erase, EraseOrReplace, + InsertBefore, + LazyInsertBefore, InsertAfter, LazyInsertAfter, LazyReplace, @@ -183,6 +185,19 @@ def build(self): ) +@dataclass +class AddNewImportBefore(LazyInsertBefore): + module: str + names: list[str] + + def build(self): + return ast.ImportFrom( + level=0, + module=self.module, + names=[ast.alias(name) for name in self.names], + ) + + @dataclass class ModifyExistingImport(LazyReplace): name: str @@ -253,6 +268,66 @@ def match(self, node): return ModifyExistingImport(closest_import, node.id) +class TypingAutoImporterBefore(Rule): + INPUT_SOURCE = """ + import lol + from something import another + + def foo(items: List[Optional[str]]) -> Dict[str, List[Tuple[int, ...]]]: + class Something: + no: Iterable[int] + + def bar(self, context: Dict[str, int]) -> List[int]: + print(1) + """ + + EXPECTED_SOURCE = """ + import lol + from typing import Dict, List, Iterable, Optional, Tuple + from something import another + + def foo(items: List[Optional[str]]) -> Dict[str, List[Tuple[int, ...]]]: + class Something: + no: Iterable[int] + + def bar(self, context: Dict[str, int]) -> List[int]: + print(1) + """ + + context_providers = (ImportFinder, context.Scope) + + def find_last_import(self, tree): + assert isinstance(tree, ast.Module) + for index, node in enumerate(tree.body, -1): + if isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant): + continue + elif isinstance(node, (ast.Import, ast.ImportFrom)): + continue + else: + break + + return tree.body[index] + + def match(self, node): + assert isinstance(node, ast.Name) + assert isinstance(node.ctx, ast.Load) + assert node.id in typing.__all__ + assert not node.id.startswith("__") + + scope = self.context["scope"].resolve(node) + typing_imports = self.context["import_finder"].collect("typing", scope=scope) + + if len(typing_imports) == 0: + last_import = self.find_last_import(self.context.tree) + return AddNewImportBefore(last_import, "typing", [node.id]) + + assert len(typing_imports) >= 1 + assert node.id not in typing_imports + + closest_import = common.find_closest(node, *typing_imports.values()) + return ModifyExistingImport(closest_import, node.id) + + class AsyncifierAction(LazyReplace): def build(self): new_node = self.branch() @@ -948,6 +1023,7 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: ReplacePlaceholders, PropagateConstants, TypingAutoImporter, + TypingAutoImporterBefore, MakeFunctionAsync, OnlyKeywordArgumentDefaultNotSetCheckRule, InternalizeFunctions,