diff --git a/src/robot/parsing/model/visitor.py b/src/robot/parsing/model/visitor.py index fb176459fb5..8d4b8e25d34 100644 --- a/src/robot/parsing/model/visitor.py +++ b/src/robot/parsing/model/visitor.py @@ -13,30 +13,48 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ast +from abc import ABC, abstractmethod +from ast import AST, NodeTransformer, NodeVisitor from .statements import Node -class VisitorFinder: +class VisitorFinder(ABC): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.__visitor_finder_cache = {} - def _find_visitor(self, cls): - if cls is ast.AST: + @classmethod + def _find_visitor_class_method(cls, node_cls): + if node_cls is AST: return None - method = 'visit_' + cls.__name__ - if hasattr(self, method): - return getattr(self, method) - # Forward-compatibility. - if method == 'visit_Return' and hasattr(self, 'visit_ReturnSetting'): - return getattr(self, 'visit_ReturnSetting') - for base in cls.__bases__: - visitor = self._find_visitor(base) - if visitor: - return visitor + method_name = "visit_" + node_cls.__name__ + method = getattr(cls, method_name, None) + if callable(method): + return method + if method_name == "visit_Return": + method = getattr(cls, "visit_ReturnSetting", None) + if callable(method): + return method + for base in node_cls.__bases__: + method = cls._find_visitor_class_method(base) + if method: + return method return None + @classmethod + def _find_visitor(cls, node_cls): + if node_cls in cls.__visitor_finder_cache: + return cls.__visitor_finder_cache[node_cls] + result = cls.__visitor_finder_cache[node_cls] = cls._find_visitor_class_method(node_cls) or cls.generic_visit + return result -class ModelVisitor(ast.NodeVisitor, VisitorFinder): + @abstractmethod + def generic_visit(self, node): + ... + + +class ModelVisitor(NodeVisitor, VisitorFinder): """NodeVisitor that supports matching nodes based on their base classes. In other ways identical to the standard `ast.NodeVisitor @@ -49,12 +67,12 @@ def visit_Statement(self, node): ... """ - def visit(self, node: Node): - visitor = self._find_visitor(type(node)) or self.generic_visit - visitor(node) + def visit(self, node: Node) -> None: + visitor = self._find_visitor(type(node)) + visitor(self, node) -class ModelTransformer(ast.NodeTransformer, VisitorFinder): +class ModelTransformer(NodeTransformer, VisitorFinder): """NodeTransformer that supports matching nodes based on their base classes. See :class:`ModelVisitor` for explanation how this is different compared @@ -62,6 +80,6 @@ class ModelTransformer(ast.NodeTransformer, VisitorFinder): `__. """ - def visit(self, node: Node): - visitor = self._find_visitor(type(node)) or self.generic_visit - return visitor(node) + def visit(self, node: Node) -> "Node|list[Node]|None": + visitor = self._find_visitor(type(node)) + return visitor(self, node)