From 25b3760ff0a7a14a2015e0376adc693b068e9d10 Mon Sep 17 00:00:00 2001 From: Daniel Biehl Date: Tue, 24 Oct 2023 13:28:49 +0200 Subject: [PATCH 1/8] cache visit_ methods of a VisitorFinder on class level --- src/robot/parsing/model/visitor.py | 58 ++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/src/robot/parsing/model/visitor.py b/src/robot/parsing/model/visitor.py index fb176459fb5..037e6923b79 100644 --- a/src/robot/parsing/model/visitor.py +++ b/src/robot/parsing/model/visitor.py @@ -14,27 +14,49 @@ # limitations under the License. import ast +from typing import Any, Callable, Dict, Optional, Type, Union from .statements import Node +class _NotSet: + pass + + class VisitorFinder: + __NOT_SET = _NotSet() + __cls_cache: Dict[Type[Any], Union[Callable[..., Any], None, _NotSet]] - def _find_visitor(self, cls): - if cls is ast.AST: + def __new__(cls, *_args: Any, **_kwargs: Any): # type: ignore[no-untyped-def] + # create cache on class level to avoid creating it for each instance + cls.__cls_cache = {} + return super().__new__(cls) + + @classmethod + def __find_visitor(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]: + if node_cls is ast.AST: return None - method = 'visit_' + cls.__name__ - if hasattr(self, method): - return getattr(self, method) - # Forward-compatibility. - if method == 'visit_Return' and hasattr(self, 'visit_ReturnSetting'): - return getattr(self, 'visit_ReturnSetting') - for base in cls.__bases__: - visitor = self._find_visitor(base) - if visitor: - return visitor + method_name = "visit_" + node_cls.__name__ + method = getattr(cls, method_name, None) + if callable(method): + return method # type: ignore[no-any-return] + if method_name == "visit_Return": + method = getattr(cls, "visit_ReturnSetting", None) + if callable(method): + return method # type: ignore[no-any-return] + for base in node_cls.__bases__: + method = cls._find_visitor(base) + if method: + return method return None + @classmethod + def _find_visitor(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]: + result = cls.__cls_cache.get(node_cls, cls.__NOT_SET) + if result is cls.__NOT_SET: + result = cls.__cls_cache[node_cls] = cls.__find_visitor(node_cls) + return result # type: ignore[return-value] + class ModelVisitor(ast.NodeVisitor, VisitorFinder): """NodeVisitor that supports matching nodes based on their base classes. @@ -49,9 +71,9 @@ def visit_Statement(self, node): ... """ - def visit(self, node: Node): - visitor = self._find_visitor(type(node)) or self.generic_visit - visitor(node) + def visit(self, node: Node) -> None: + visitor = self._find_visitor(type(node)) or self.__class__.generic_visit + visitor(self, node) class ModelTransformer(ast.NodeTransformer, VisitorFinder): @@ -62,6 +84,6 @@ class ModelTransformer(ast.NodeTransformer, VisitorFinder): `__. """ - def visit(self, node: Node): - visitor = self._find_visitor(type(node)) or self.generic_visit - return visitor(node) + def visit(self, node: Node) -> Node: + visitor = self._find_visitor(type(node)) or self.__class__.generic_visit + return visitor(self, node) From bf5b3a2147307c65fb3b22e3871c9bd8280ebe63 Mon Sep 17 00:00:00 2001 From: Daniel Biehl Date: Tue, 24 Oct 2023 13:57:07 +0200 Subject: [PATCH 2/8] remove unneeded _field attribute in Statement The `_field` attribute is not needed in Statement because it should only contain child nodes, but `type` and `tokens` are'nt child nodes. I move the value to the _attributes because then they can be dumped. --- src/robot/parsing/model/statements.py | 2 +- utest/parsing/parsing_test_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/robot/parsing/model/statements.py b/src/robot/parsing/model/statements.py index 3a8b8c10748..d4e37850804 100644 --- a/src/robot/parsing/model/statements.py +++ b/src/robot/parsing/model/statements.py @@ -47,7 +47,7 @@ class Node(ast.AST, ABC): class Statement(Node, ABC): - _fields = ('type', 'tokens') + _attributes = (*Node._attributes, 'tokens', 'type') type: str handles_types: 'ClassVar[tuple[str, ...]]' = () statement_handlers: 'ClassVar[dict[str, Type[Statement]]]' = {} diff --git a/utest/parsing/parsing_test_utils.py b/utest/parsing/parsing_test_utils.py index 8dff1b8dc5d..84695f14374 100644 --- a/utest/parsing/parsing_test_utils.py +++ b/utest/parsing/parsing_test_utils.py @@ -47,13 +47,13 @@ def assert_block(model, expected, expected_attrs): def assert_statement(model, expected): - assert_equal(model._fields, ('type', 'tokens')) + assert_equal(model._fields, ()) assert_equal(model.type, expected.type) assert_equal(len(model.tokens), len(expected.tokens)) for m, e in zip(model.tokens, expected.tokens): assert_equal(m, e, formatter=repr) assert_equal(model._attributes, ('lineno', 'col_offset', 'end_lineno', - 'end_col_offset', 'errors')) + 'end_col_offset', 'errors', 'tokens', 'type')) assert_equal(model.lineno, expected.tokens[0].lineno) assert_equal(model.col_offset, expected.tokens[0].col_offset) assert_equal(model.end_lineno, expected.tokens[-1].lineno) From 1012143bb85f3e4f1f14a6c8966197e0d20e0964 Mon Sep 17 00:00:00 2001 From: Daniel Biehl Date: Tue, 24 Oct 2023 16:47:47 +0200 Subject: [PATCH 3/8] introduce own optimized versions of generic_visit in ModelVisitor and ModelTransformer The original generic_visit method checks if the `_field` values are of type ast.AST and then if the values are if type list. because the `isinstance` check is releativ slow in python and because the RF Model is correctly defined and always returns `Node` or `List[Nodes]` we only need to check if the field is of type `List` --- src/robot/parsing/model/visitor.py | 72 ++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 9 deletions(-) diff --git a/src/robot/parsing/model/visitor.py b/src/robot/parsing/model/visitor.py index 037e6923b79..fd061ed8416 100644 --- a/src/robot/parsing/model/visitor.py +++ b/src/robot/parsing/model/visitor.py @@ -14,7 +14,8 @@ # limitations under the License. import ast -from typing import Any, Callable, Dict, Optional, Type, Union +from collections import defaultdict +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union from .statements import Node @@ -25,11 +26,14 @@ class _NotSet: class VisitorFinder: __NOT_SET = _NotSet() - __cls_cache: Dict[Type[Any], Union[Callable[..., Any], None, _NotSet]] + __cls_finder_global_cache__: Dict[ + Type[Any], Dict[Type[Any], Union[Callable[..., Any], None, _NotSet]] + ] = defaultdict(dict) + __cls_finder_cache__: Dict[Type[Any], Union[Callable[..., Any], None, _NotSet]] def __new__(cls, *_args: Any, **_kwargs: Any): # type: ignore[no-untyped-def] - # create cache on class level to avoid creating it for each instance - cls.__cls_cache = {} + if not hasattr(cls, "__cls_finder_cache__"): + cls.__cls_finder_cache__ = cls.__cls_finder_global_cache__[cls] return super().__new__(cls) @classmethod @@ -52,13 +56,29 @@ def __find_visitor(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]: @classmethod def _find_visitor(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]: - result = cls.__cls_cache.get(node_cls, cls.__NOT_SET) + result = cls.__cls_finder_cache__.get(node_cls, cls.__NOT_SET) if result is cls.__NOT_SET: - result = cls.__cls_cache[node_cls] = cls.__find_visitor(node_cls) + result = cls.__cls_finder_cache__[node_cls] = cls.__find_visitor(node_cls) return result # type: ignore[return-value] -class ModelVisitor(ast.NodeVisitor, VisitorFinder): +def _iter_field_values(node: Node) -> Iterator[Union[Node, List[Node], None]]: + for field in node._fields: + try: + yield getattr(node, field) + except AttributeError: + pass + + +def iter_fields(node: Node) -> Iterator[Tuple[str, Union[Node, List[Node], None]]]: + for field in node._fields: + try: + yield field, getattr(node, field) + except AttributeError: + pass + + +class ModelVisitor(VisitorFinder): """NodeVisitor that supports matching nodes based on their base classes. In other ways identical to the standard `ast.NodeVisitor @@ -75,8 +95,18 @@ def visit(self, node: Node) -> None: visitor = self._find_visitor(type(node)) or self.__class__.generic_visit visitor(self, node) + def generic_visit(self, node: Node) -> None: + for value in _iter_field_values(node): + if value is None: + continue + if isinstance(value, list): + for item in value: + self.visit(item) + else: + self.visit(value) + -class ModelTransformer(ast.NodeTransformer, VisitorFinder): +class ModelTransformer(VisitorFinder): """NodeTransformer that supports matching nodes based on their base classes. See :class:`ModelVisitor` for explanation how this is different compared @@ -84,6 +114,30 @@ class ModelTransformer(ast.NodeTransformer, VisitorFinder): `__. """ - def visit(self, node: Node) -> Node: + def visit(self, node: Node) -> Union[Node, List[Node], None]: visitor = self._find_visitor(type(node)) or self.__class__.generic_visit return visitor(self, node) + + def generic_visit(self, node: Node) -> Union[Node, List[Node], None]: + for field, old_value in iter_fields(node): + if old_value is None: + continue + if isinstance(old_value, list): + new_values = [] + for value in old_value: + new_value = self.visit(value) + if new_value is None: + continue + if isinstance(new_value, list): + new_values.extend(new_value) + continue + new_values.append(new_value) + old_value[:] = new_values + else: + new_node = self.visit(old_value) + if new_node is None: + delattr(node, field) + else: + setattr(node, field, new_node) + + return node From 5940b350fb0ce2d2c755d9902502c2f0705673f7 Mon Sep 17 00:00:00 2001 From: Daniel Biehl Date: Tue, 24 Oct 2023 17:21:54 +0200 Subject: [PATCH 4/8] some cosmetic changes in new visitor implementation --- src/robot/parsing/model/visitor.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/src/robot/parsing/model/visitor.py b/src/robot/parsing/model/visitor.py index fd061ed8416..eb1e1cbf94e 100644 --- a/src/robot/parsing/model/visitor.py +++ b/src/robot/parsing/model/visitor.py @@ -14,26 +14,18 @@ # limitations under the License. import ast -from collections import defaultdict from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union +from ...utils.notset import NOT_SET, NotSet from .statements import Node -class _NotSet: - pass - - class VisitorFinder: - __NOT_SET = _NotSet() - __cls_finder_global_cache__: Dict[ - Type[Any], Dict[Type[Any], Union[Callable[..., Any], None, _NotSet]] - ] = defaultdict(dict) - __cls_finder_cache__: Dict[Type[Any], Union[Callable[..., Any], None, _NotSet]] + __cls_finder_cache__: Dict[Type[Any], Union[Callable[..., Any], None, NotSet]] def __new__(cls, *_args: Any, **_kwargs: Any): # type: ignore[no-untyped-def] if not hasattr(cls, "__cls_finder_cache__"): - cls.__cls_finder_cache__ = cls.__cls_finder_global_cache__[cls] + cls.__cls_finder_cache__ = {} return super().__new__(cls) @classmethod @@ -49,15 +41,15 @@ def __find_visitor(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]: if callable(method): return method # type: ignore[no-any-return] for base in node_cls.__bases__: - method = cls._find_visitor(base) + method = cls._find_visitor_cached(base) if method: return method return None @classmethod - def _find_visitor(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]: - result = cls.__cls_finder_cache__.get(node_cls, cls.__NOT_SET) - if result is cls.__NOT_SET: + def _find_visitor_cached(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]: + result = cls.__cls_finder_cache__.get(node_cls, NOT_SET) + if result is NOT_SET: result = cls.__cls_finder_cache__[node_cls] = cls.__find_visitor(node_cls) return result # type: ignore[return-value] @@ -92,7 +84,7 @@ def visit_Statement(self, node): """ def visit(self, node: Node) -> None: - visitor = self._find_visitor(type(node)) or self.__class__.generic_visit + visitor = self._find_visitor_cached(type(node)) or self.__class__.generic_visit visitor(self, node) def generic_visit(self, node: Node) -> None: @@ -115,7 +107,7 @@ class ModelTransformer(VisitorFinder): """ def visit(self, node: Node) -> Union[Node, List[Node], None]: - visitor = self._find_visitor(type(node)) or self.__class__.generic_visit + visitor = self._find_visitor_cached(type(node)) or self.__class__.generic_visit return visitor(self, node) def generic_visit(self, node: Node) -> Union[Node, List[Node], None]: From 264391452637a4a08faa5f04961a1b90fe50c2b2 Mon Sep 17 00:00:00 2001 From: Daniel Biehl Date: Tue, 24 Oct 2023 17:56:07 +0200 Subject: [PATCH 5/8] update type hints to new style in visitor classes --- src/robot/parsing/model/visitor.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/robot/parsing/model/visitor.py b/src/robot/parsing/model/visitor.py index eb1e1cbf94e..47deed7db21 100644 --- a/src/robot/parsing/model/visitor.py +++ b/src/robot/parsing/model/visitor.py @@ -14,14 +14,14 @@ # limitations under the License. import ast -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Iterator from ...utils.notset import NOT_SET, NotSet from .statements import Node class VisitorFinder: - __cls_finder_cache__: Dict[Type[Any], Union[Callable[..., Any], None, NotSet]] + __cls_finder_cache__: "dict[type[Any], Callable[..., Any]|None|NotSet]" def __new__(cls, *_args: Any, **_kwargs: Any): # type: ignore[no-untyped-def] if not hasattr(cls, "__cls_finder_cache__"): @@ -29,7 +29,7 @@ def __new__(cls, *_args: Any, **_kwargs: Any): # type: ignore[no-untyped-def] return super().__new__(cls) @classmethod - def __find_visitor(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]: + def __find_visitor(cls, node_cls: "type[Any]") -> "Callable[..., Any]|None": if node_cls is ast.AST: return None method_name = "visit_" + node_cls.__name__ @@ -47,14 +47,14 @@ def __find_visitor(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]: return None @classmethod - def _find_visitor_cached(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]: + def _find_visitor_cached(cls, node_cls: "type[Any]") -> "Callable[..., Any]|None": result = cls.__cls_finder_cache__.get(node_cls, NOT_SET) if result is NOT_SET: result = cls.__cls_finder_cache__[node_cls] = cls.__find_visitor(node_cls) return result # type: ignore[return-value] -def _iter_field_values(node: Node) -> Iterator[Union[Node, List[Node], None]]: +def _iter_field_values(node: Node) -> "Iterator[Node|list[Node]|None]": for field in node._fields: try: yield getattr(node, field) @@ -62,7 +62,7 @@ def _iter_field_values(node: Node) -> Iterator[Union[Node, List[Node], None]]: pass -def iter_fields(node: Node) -> Iterator[Tuple[str, Union[Node, List[Node], None]]]: +def iter_fields(node: Node) -> "Iterator[tuple[str, Node|list[Node]|None]]": for field in node._fields: try: yield field, getattr(node, field) @@ -106,11 +106,11 @@ class ModelTransformer(VisitorFinder): `__. """ - def visit(self, node: Node) -> Union[Node, List[Node], None]: + def visit(self, node: Node) -> "Node|list[Node]|None": visitor = self._find_visitor_cached(type(node)) or self.__class__.generic_visit return visitor(self, node) - def generic_visit(self, node: Node) -> Union[Node, List[Node], None]: + def generic_visit(self, node: Node) -> "Node|list[Node]|None": for field, old_value in iter_fields(node): if old_value is None: continue From c52b257066d77f924072afaf03267d90c3f0b620 Mon Sep 17 00:00:00 2001 From: Daniel Biehl Date: Tue, 24 Oct 2023 22:38:56 +0200 Subject: [PATCH 6/8] Cache initialization in VisitorFinder moved to __init_subclass__ and handle generic_visit in VisitorFinder, but be safe from Liskov substitution principle --- src/robot/parsing/model/visitor.py | 43 +++++++++++++++++------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/robot/parsing/model/visitor.py b/src/robot/parsing/model/visitor.py index 47deed7db21..aa12052ae50 100644 --- a/src/robot/parsing/model/visitor.py +++ b/src/robot/parsing/model/visitor.py @@ -13,24 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ast -from typing import Any, Callable, Iterator +from abc import ABC +from ast import AST +from collections import defaultdict +from typing import Any, Callable, Generic, Iterator, TypeVar from ...utils.notset import NOT_SET, NotSet from .statements import Node +TVisitorResult = TypeVar("TVisitorResult") -class VisitorFinder: - __cls_finder_cache__: "dict[type[Any], Callable[..., Any]|None|NotSet]" - def __new__(cls, *_args: Any, **_kwargs: Any): # type: ignore[no-untyped-def] - if not hasattr(cls, "__cls_finder_cache__"): - cls.__cls_finder_cache__ = {} - return super().__new__(cls) +class VisitorFinder(ABC, Generic[TVisitorResult]): + __visitor_finder_cache: "dict[type[Any], Callable[[Node], Any]|None|NotSet]" + __default_visitor_method: "Callable[..., TVisitorResult]" + + def __init_subclass__(cls, default_visitor_name: str = "generic_visit", **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + cls.__visitor_finder_cache = defaultdict(lambda: NOT_SET) + cls.__default_visitor_method = getattr(cls, default_visitor_name) @classmethod - def __find_visitor(cls, node_cls: "type[Any]") -> "Callable[..., Any]|None": - if node_cls is ast.AST: + def _find_visitor_class_method(cls, node_cls: "type[Any]") -> "Callable[..., TVisitorResult]|None": + if node_cls is AST: return None method_name = "visit_" + node_cls.__name__ method = getattr(cls, method_name, None) @@ -41,17 +46,17 @@ def __find_visitor(cls, node_cls: "type[Any]") -> "Callable[..., Any]|None": if callable(method): return method # type: ignore[no-any-return] for base in node_cls.__bases__: - method = cls._find_visitor_cached(base) + method = cls._find_visitor_class_method(base) if method: return method return None @classmethod - def _find_visitor_cached(cls, node_cls: "type[Any]") -> "Callable[..., Any]|None": - result = cls.__cls_finder_cache__.get(node_cls, NOT_SET) + def _find_visitor(cls, node_cls: "type[Any]") -> Callable[..., TVisitorResult]: + result = cls.__visitor_finder_cache[node_cls] if result is NOT_SET: - result = cls.__cls_finder_cache__[node_cls] = cls.__find_visitor(node_cls) - return result # type: ignore[return-value] + result = cls.__visitor_finder_cache[node_cls] = cls._find_visitor_class_method(node_cls) + return result or cls.__default_visitor_method # type: ignore[return-value] def _iter_field_values(node: Node) -> "Iterator[Node|list[Node]|None]": @@ -70,7 +75,7 @@ def iter_fields(node: Node) -> "Iterator[tuple[str, Node|list[Node]|None]]": pass -class ModelVisitor(VisitorFinder): +class ModelVisitor(VisitorFinder[None]): """NodeVisitor that supports matching nodes based on their base classes. In other ways identical to the standard `ast.NodeVisitor @@ -84,7 +89,7 @@ def visit_Statement(self, node): """ def visit(self, node: Node) -> None: - visitor = self._find_visitor_cached(type(node)) or self.__class__.generic_visit + visitor = self._find_visitor(type(node)) visitor(self, node) def generic_visit(self, node: Node) -> None: @@ -98,7 +103,7 @@ def generic_visit(self, node: Node) -> None: self.visit(value) -class ModelTransformer(VisitorFinder): +class ModelTransformer(VisitorFinder["Node|list[Node]|None"]): """NodeTransformer that supports matching nodes based on their base classes. See :class:`ModelVisitor` for explanation how this is different compared @@ -107,7 +112,7 @@ class ModelTransformer(VisitorFinder): """ def visit(self, node: Node) -> "Node|list[Node]|None": - visitor = self._find_visitor_cached(type(node)) or self.__class__.generic_visit + visitor = self._find_visitor(type(node)) return visitor(self, node) def generic_visit(self, node: Node) -> "Node|list[Node]|None": From ebc847b350ef583b79c1265ee559c84f2bea9ae7 Mon Sep 17 00:00:00 2001 From: Daniel Biehl Date: Wed, 25 Oct 2023 14:58:51 +0200 Subject: [PATCH 7/8] Revert "remove unneeded _field attribute in Statement" This reverts commit bf5b3a2147307c65fb3b22e3871c9bd8280ebe63. --- src/robot/parsing/model/statements.py | 2 +- utest/parsing/parsing_test_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/robot/parsing/model/statements.py b/src/robot/parsing/model/statements.py index d4e37850804..3a8b8c10748 100644 --- a/src/robot/parsing/model/statements.py +++ b/src/robot/parsing/model/statements.py @@ -47,7 +47,7 @@ class Node(ast.AST, ABC): class Statement(Node, ABC): - _attributes = (*Node._attributes, 'tokens', 'type') + _fields = ('type', 'tokens') type: str handles_types: 'ClassVar[tuple[str, ...]]' = () statement_handlers: 'ClassVar[dict[str, Type[Statement]]]' = {} diff --git a/utest/parsing/parsing_test_utils.py b/utest/parsing/parsing_test_utils.py index 84695f14374..8dff1b8dc5d 100644 --- a/utest/parsing/parsing_test_utils.py +++ b/utest/parsing/parsing_test_utils.py @@ -47,13 +47,13 @@ def assert_block(model, expected, expected_attrs): def assert_statement(model, expected): - assert_equal(model._fields, ()) + assert_equal(model._fields, ('type', 'tokens')) assert_equal(model.type, expected.type) assert_equal(len(model.tokens), len(expected.tokens)) for m, e in zip(model.tokens, expected.tokens): assert_equal(m, e, formatter=repr) assert_equal(model._attributes, ('lineno', 'col_offset', 'end_lineno', - 'end_col_offset', 'errors', 'tokens', 'type')) + 'end_col_offset', 'errors')) assert_equal(model.lineno, expected.tokens[0].lineno) assert_equal(model.col_offset, expected.tokens[0].col_offset) assert_equal(model.end_lineno, expected.tokens[-1].lineno) From 469dfe2473a52677416bc61a2701a9b17d886fa0 Mon Sep 17 00:00:00 2001 From: Daniel Biehl Date: Wed, 25 Oct 2023 20:58:14 +0200 Subject: [PATCH 8/8] cleanup, remove unwanted typehints and some small cosmetic changes in VisitorFinder --- src/robot/parsing/model/visitor.py | 91 ++++++------------------------ 1 file changed, 18 insertions(+), 73 deletions(-) diff --git a/src/robot/parsing/model/visitor.py b/src/robot/parsing/model/visitor.py index aa12052ae50..8d4b8e25d34 100644 --- a/src/robot/parsing/model/visitor.py +++ b/src/robot/parsing/model/visitor.py @@ -13,38 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC -from ast import AST -from collections import defaultdict -from typing import Any, Callable, Generic, Iterator, TypeVar +from abc import ABC, abstractmethod +from ast import AST, NodeTransformer, NodeVisitor -from ...utils.notset import NOT_SET, NotSet from .statements import Node -TVisitorResult = TypeVar("TVisitorResult") - -class VisitorFinder(ABC, Generic[TVisitorResult]): - __visitor_finder_cache: "dict[type[Any], Callable[[Node], Any]|None|NotSet]" - __default_visitor_method: "Callable[..., TVisitorResult]" - - def __init_subclass__(cls, default_visitor_name: str = "generic_visit", **kwargs: Any) -> None: +class VisitorFinder(ABC): + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - cls.__visitor_finder_cache = defaultdict(lambda: NOT_SET) - cls.__default_visitor_method = getattr(cls, default_visitor_name) + cls.__visitor_finder_cache = {} @classmethod - def _find_visitor_class_method(cls, node_cls: "type[Any]") -> "Callable[..., TVisitorResult]|None": + def _find_visitor_class_method(cls, node_cls): if node_cls is AST: return None method_name = "visit_" + node_cls.__name__ method = getattr(cls, method_name, None) if callable(method): - return method # type: ignore[no-any-return] + return method if method_name == "visit_Return": method = getattr(cls, "visit_ReturnSetting", None) if callable(method): - return method # type: ignore[no-any-return] + return method for base in node_cls.__bases__: method = cls._find_visitor_class_method(base) if method: @@ -52,30 +43,18 @@ def _find_visitor_class_method(cls, node_cls: "type[Any]") -> "Callable[..., TVi return None @classmethod - def _find_visitor(cls, node_cls: "type[Any]") -> Callable[..., TVisitorResult]: - result = cls.__visitor_finder_cache[node_cls] - if result is NOT_SET: - result = cls.__visitor_finder_cache[node_cls] = cls._find_visitor_class_method(node_cls) - return result or cls.__default_visitor_method # type: ignore[return-value] - + def _find_visitor(cls, node_cls): + if node_cls in cls.__visitor_finder_cache: + return cls.__visitor_finder_cache[node_cls] + result = cls.__visitor_finder_cache[node_cls] = cls._find_visitor_class_method(node_cls) or cls.generic_visit + return result -def _iter_field_values(node: Node) -> "Iterator[Node|list[Node]|None]": - for field in node._fields: - try: - yield getattr(node, field) - except AttributeError: - pass + @abstractmethod + def generic_visit(self, node): + ... -def iter_fields(node: Node) -> "Iterator[tuple[str, Node|list[Node]|None]]": - for field in node._fields: - try: - yield field, getattr(node, field) - except AttributeError: - pass - - -class ModelVisitor(VisitorFinder[None]): +class ModelVisitor(NodeVisitor, VisitorFinder): """NodeVisitor that supports matching nodes based on their base classes. In other ways identical to the standard `ast.NodeVisitor @@ -92,18 +71,8 @@ def visit(self, node: Node) -> None: visitor = self._find_visitor(type(node)) visitor(self, node) - def generic_visit(self, node: Node) -> None: - for value in _iter_field_values(node): - if value is None: - continue - if isinstance(value, list): - for item in value: - self.visit(item) - else: - self.visit(value) - -class ModelTransformer(VisitorFinder["Node|list[Node]|None"]): +class ModelTransformer(NodeTransformer, VisitorFinder): """NodeTransformer that supports matching nodes based on their base classes. See :class:`ModelVisitor` for explanation how this is different compared @@ -114,27 +83,3 @@ class ModelTransformer(VisitorFinder["Node|list[Node]|None"]): def visit(self, node: Node) -> "Node|list[Node]|None": visitor = self._find_visitor(type(node)) return visitor(self, node) - - def generic_visit(self, node: Node) -> "Node|list[Node]|None": - for field, old_value in iter_fields(node): - if old_value is None: - continue - if isinstance(old_value, list): - new_values = [] - for value in old_value: - new_value = self.visit(value) - if new_value is None: - continue - if isinstance(new_value, list): - new_values.extend(new_value) - continue - new_values.append(new_value) - old_value[:] = new_values - else: - new_node = self.visit(old_value) - if new_node is None: - delattr(node, field) - else: - setattr(node, field, new_node) - - return node