Thanks to visit codestin.com
Credit goes to github.com

Skip to content

performance optimization in visiting models #4911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 7, 2023
62 changes: 40 additions & 22 deletions src/robot/parsing/model/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
from abc import ABC, abstractmethod
from ast import AST, NodeTransformer, NodeVisitor

from .statements import Node


class VisitorFinder:
class VisitorFinder(ABC):
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.__visitor_finder_cache = {}

def _find_visitor(self, cls):
if cls is ast.AST:
@classmethod
def _find_visitor_class_method(cls, node_cls):
if node_cls is AST:
return None
method = 'visit_' + cls.__name__
if hasattr(self, method):
return getattr(self, method)
# Forward-compatibility.
if method == 'visit_Return' and hasattr(self, 'visit_ReturnSetting'):
return getattr(self, 'visit_ReturnSetting')
for base in cls.__bases__:
visitor = self._find_visitor(base)
if visitor:
return visitor
method_name = "visit_" + node_cls.__name__
method = getattr(cls, method_name, None)
if callable(method):
return method
if method_name == "visit_Return":
method = getattr(cls, "visit_ReturnSetting", None)
if callable(method):
return method
for base in node_cls.__bases__:
method = cls._find_visitor_class_method(base)
if method:
return method
return None

@classmethod
def _find_visitor(cls, node_cls):
if node_cls in cls.__visitor_finder_cache:
return cls.__visitor_finder_cache[node_cls]
result = cls.__visitor_finder_cache[node_cls] = cls._find_visitor_class_method(node_cls) or cls.generic_visit
return result

class ModelVisitor(ast.NodeVisitor, VisitorFinder):
@abstractmethod
def generic_visit(self, node):
...


class ModelVisitor(NodeVisitor, VisitorFinder):
"""NodeVisitor that supports matching nodes based on their base classes.

In other ways identical to the standard `ast.NodeVisitor
Expand All @@ -49,19 +67,19 @@ def visit_Statement(self, node):
...
"""

def visit(self, node: Node):
visitor = self._find_visitor(type(node)) or self.generic_visit
visitor(node)
def visit(self, node: Node) -> None:
visitor = self._find_visitor(type(node))
visitor(self, node)


class ModelTransformer(ast.NodeTransformer, VisitorFinder):
class ModelTransformer(NodeTransformer, VisitorFinder):
"""NodeTransformer that supports matching nodes based on their base classes.

See :class:`ModelVisitor` for explanation how this is different compared
to the standard `ast.NodeTransformer
<https://docs.python.org/library/ast.html#ast.NodeTransformer>`__.
"""

def visit(self, node: Node):
visitor = self._find_visitor(type(node)) or self.generic_visit
return visitor(node)
def visit(self, node: Node) -> "Node|list[Node]|None":
visitor = self._find_visitor(type(node))
return visitor(self, node)