diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index 57141b6062c65..c65720fdf6597 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -7,6 +7,9 @@ from torch.testing._internal.common_utils import TestCase +CREATE_STR = NodeSourceAction.CREATE.name.lower() + + class TestFXNodeSource(TestCase): def test_node_source(self): node_source = NodeSource( @@ -20,7 +23,7 @@ def test_node_source(self): "name": "", "target": "", "pass_name": "test_pass", - "action": NodeSourceAction.CREATE, + "action": CREATE_STR, "graph_id": -1, "from_node": [], } @@ -56,7 +59,7 @@ def test_node_source(self): "name": "add", "target": "aten.add.Tensor", "pass_name": "test_pass", - "action": NodeSourceAction.CREATE, + "action": CREATE_STR, "graph_id": graph_id, "from_node": [dummy_source_dict], }, @@ -108,7 +111,7 @@ def forward(self, x): key_provenance, "x", "Interpreter_PropagateUnbackedSymInts", - NodeSourceAction.CREATE, + CREATE_STR, ) # Check node "x" is then created from another node "x" in FlattenInputOutputSignature @@ -117,7 +120,7 @@ def forward(self, x): key_provenance, "x", "Interpreter_FlattenInputOutputSignature", - NodeSourceAction.CREATE, + CREATE_STR, ) gm, graph_signature = aot_export_module( @@ -147,7 +150,7 @@ def forward(self, x): key_provenance, "linear", "Interpreter_PropagateUnbackedSymInts", - NodeSourceAction.CREATE, + CREATE_STR, ) # Check node "linear" is then created from node "x" in PropagateUnbackedSymInts @@ -156,7 +159,7 @@ def forward(self, x): key_provenance, "x", "Interpreter_PropagateUnbackedSymInts", - NodeSourceAction.CREATE, + CREATE_STR, ) # Check node "x" is then created from another node "x" in FlattenInputOutputSignature @@ -165,5 +168,5 @@ def forward(self, x): key_provenance, "x", "Interpreter_FlattenInputOutputSignature", - NodeSourceAction.CREATE, + CREATE_STR, ) diff --git a/test/fx/test_fx_xform_observer.py b/test/fx/test_fx_xform_observer.py index 95b0ee74f698e..b272af9b17f1c 100644 --- a/test/fx/test_fx_xform_observer.py +++ b/test/fx/test_fx_xform_observer.py @@ -1,11 +1,13 @@ # Owner(s): ["module: fx"] +import copy import os import tempfile import torch from torch.fx import subgraph_rewriter, symbolic_trace from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.traceback import NodeSourceAction from torch.testing._internal.common_utils import TestCase @@ -60,3 +62,127 @@ def replacement(x): ) ) ) + + @torch._inductor.config.patch("trace.enabled", True) + def test_graph_transform_observer_node_tracking(self): + class M(torch.nn.Module): + def forward(self, x): + val = torch.neg(x) + return torch.add(val, val) + + def pattern(x): + return torch.neg(x) + + def replacement(x): + return torch.relu(x) + + def replacement2(x): + return torch.cos(x) + + traced = symbolic_trace(M()) + + def check_node_source(node_source, node_name, target, id, pass_name, action): + self.assertEqual(node_source.name, node_name) + self.assertEqual(node_source.target, target) + self.assertEqual(node_source.pass_name, pass_name) + self.assertEqual(node_source.graph_id, id) + self.assertEqual(node_source.action, action) + + with GraphTransformObserver(traced, "replace_neg_with_relu") as ob: + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + + self.assertTrue("relu" in ob.created_nodes) + self.assertTrue("neg" in ob.erased_nodes) + + self.assertEqual(len(traced._replace_hooks), 0) + self.assertEqual(len(traced._create_node_hooks), 0) + self.assertEqual(len(traced._erase_node_hooks), 0) + self.assertEqual(len(traced._deepcopy_hooks), 0) + + for node in traced.graph.nodes: + if node.name == "relu": + from_node = node.meta["from_node"] + self.assertTrue(len(from_node) == 1) + check_node_source( + from_node[0], + "neg", + str(torch.neg), + id(traced.graph), + "replace_neg_with_relu", + [NodeSourceAction.REPLACE, NodeSourceAction.CREATE], + ) + + with GraphTransformObserver(traced, "replace_relu_with_cos") as ob: + subgraph_rewriter.replace_pattern(traced, replacement, replacement2) + + self.assertTrue("cos" in ob.created_nodes) + self.assertTrue("relu" in ob.erased_nodes) + + for node in traced.graph.nodes: + if node.name == "cos": + from_node = node.meta["from_node"] + self.assertTrue(len(from_node) == 1) + check_node_source( + from_node[0], + "relu", + str(torch.relu), + id(traced.graph), + "replace_relu_with_cos", + [NodeSourceAction.REPLACE, NodeSourceAction.CREATE], + ) + check_node_source( + from_node[0].from_node[0], + "neg", + str(torch.neg), + id(traced.graph), + "replace_neg_with_relu", + [NodeSourceAction.REPLACE, NodeSourceAction.CREATE], + ) + + class SimpleLinearModel(torch.nn.Module): + def forward(self, x): + return torch.neg(x) + + model = SimpleLinearModel() + gm = torch.export.export(model, (torch.rand(10),)).module() + + with GraphTransformObserver(gm, "test"): + add_node = gm.graph.call_function(torch.ops.aten.add.default, (1, 1)) + neg_node = next( + iter([node for node in gm.graph.nodes if node.name == "neg"]) + ) + neg_node.replace_all_uses_with(replace_with=add_node) + + from_node = add_node.meta["from_node"] + self.assertTrue(len(from_node) == 1) + check_node_source( + from_node[0], + "neg", + str(torch.ops.aten.neg.default), + id(gm.graph), + "test", + [NodeSourceAction.REPLACE, NodeSourceAction.CREATE], + ) + + @torch._inductor.config.patch("trace.enabled", True) + def test_graph_transform_observer_deepcopy(self): + class SimpleLinearModel(torch.nn.Module): + def forward(self, x): + return torch.neg(x) + + model = SimpleLinearModel() + gm = torch.export.export(model, (torch.rand(10),)).module() + + with GraphTransformObserver(gm, "test"): + gm2 = copy.deepcopy(gm) + + nodes = [node.name for node in gm.graph.nodes] + nodes2 = [node.name for node in gm2.graph.nodes] + self.assertEqual(nodes, nodes2) + + # deepcopied graph modules should not have hooks after exiting + # the context + self.assertEqual(len(gm2._replace_hooks), 0) + self.assertEqual(len(gm2._create_node_hooks), 0) + self.assertEqual(len(gm2._erase_node_hooks), 0) + self.assertEqual(len(gm2._deepcopy_hooks), 0) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index ec0e83d0f525b..c435e57ac35ee 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -931,6 +931,18 @@ def log_graph_runnable() -> str: print_output=False, include_stride=True, include_device=True ), ) + if config.trace.enabled: + provenance_tracking_json = ( + torch.fx.traceback.get_graph_provenance_json(gm.graph) + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_post_to_pre_grad_nodes", + "encoding": "json", + }, + payload_fn=lambda: provenance_tracking_json, + ) if config.is_fbcode(): log_optimus_to_scuba( extra_logging={"pt2_configs": str(get_patched_config_dict())} @@ -1636,7 +1648,9 @@ def compile_fx( with _use_lazy_graph_module( dynamo_config.use_lazy_graph_module - ), enable_python_dispatcher(): + ), enable_python_dispatcher(), torch.fx.traceback.preserve_node_meta( + config.trace.enabled + ): # Pre-grad passes cannot be run if we weren't given a GraphModule. # Dynamo will always produce a GraphModule, but this handles cases # where a user directly passes a plain Module with the intention of diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 8161dc2618fd5..926a33f613f91 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -138,13 +138,30 @@ def __init__(self) -> None: MULTIPLE = Multiple() -def _transfer_meta(new_meta: Dict[str, Any], old_meta: Dict[str, Any]) -> None: +def _transfer_meta( + new_meta: Dict[str, Any], old_node: torch.fx.Node, pass_name: str = "" +) -> None: + from torch.fx.traceback import NodeSource, NodeSourceAction + # transfer metadata after pattern matching occurs. # skip "val" and "tensor_meta" because this info is too specific; it's unlikely # to remain accurate after pattern matching has occurred. - new_meta.update( - (k, v) for k, v in old_meta.items() if k in torch.fx.proxy._COPY_META_FIELDS - ) + if config.trace.enabled: + # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node. + new_from_node = new_meta.get("from_node", []).copy() + new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE)) + new_meta.update( + (k, v) + for k, v in old_node.meta.items() + if k in torch.fx.proxy._COPY_META_FIELDS + ) + new_meta["from_node"] = new_from_node + else: + new_meta.update( + (k, v) + for k, v in old_node.meta.items() + if k in torch.fx.proxy._COPY_META_FIELDS + ) class Match: @@ -263,7 +280,11 @@ def replace_by_example( ) if len(self.nodes) == 1: for n in replacement.graph.nodes: - _transfer_meta(new_meta=n.meta, old_meta=self.nodes[0].meta) + _transfer_meta( + new_meta=n.meta, + old_node=self.nodes[0], + pass_name="replace_by_example", + ) ReplacementPatternEntry.replace_with_graph( self, @@ -1069,7 +1090,11 @@ def run_node(self, node: torch.fx.Node) -> Any: target = node.target args, kwargs = self.fetch_args_kwargs_from_env(node) result = graph.call_function(target, args, kwargs) # type: ignore[arg-type] - _transfer_meta(new_meta=result.meta, old_meta=node.meta) + _transfer_meta( + new_meta=result.meta, + old_node=node, + pass_name="Interpreter_Replacer", + ) if "val" in node.meta and "val" not in result.meta: result.meta["val"] = node.meta["val"] if isinstance(node.meta["val"], torch.Tensor): @@ -1401,7 +1426,8 @@ def search_fn_new(*args_new: Any) -> Any: for n in match.replacement_graph.graph.nodes: _transfer_meta( new_meta=n.meta, - old_meta=match.nodes[0].meta, + old_node=match.nodes[0], + pass_name="replacement", ) return True return False diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 52c55586ab51f..6a1104c7f40bc 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -531,6 +531,8 @@ def __init__( self._replace_hooks: List[Callable] = [] self._create_node_hooks: List[Callable] = [] self._erase_node_hooks: List[Callable] = [] + # Used to remove hooks from deepcopied graph modules within a context manager. + self._deepcopy_hooks: List[Callable] = [] # TorchScript breaks trying to compile the graph setter because of the # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 @@ -888,6 +890,7 @@ def __deepcopy__(self, memo): "_replace_hooks", "_create_node_hooks", "_erase_node_hooks", + "_deepcopy_hooks", ] for attr in extra_preserved_attrs: if attr in self.__dict__: @@ -896,6 +899,8 @@ def __deepcopy__(self, memo): if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta: for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): setattr(res, attr_name, attr) + for hook in self._deepcopy_hooks: + hook(res) return res def __copy__(self): @@ -1002,6 +1007,22 @@ def _unregister_erase_node_hook(self, f): assert callable(f), "erase_node hook must be a callable." self._erase_node_hooks.remove(f) + def _register_deepcopy_hook(self, f): + """ + Takes a callable which will be called when we deepcopy this graph module. The + callable takes the resulting deepcopied graph module. + """ + assert callable(f), "deepcopy hook must be a callable." + self._deepcopy_hooks.append(f) + + def _unregister_deepcopy_hook(self, f): + """ + Takes a callable which was previously registered to be called after deepcopy. + This function will unregister that callable so it is no longer invoked on deepcopy. + """ + assert callable(f), "deepcopy hook must be a callable." + self._deepcopy_hooks.remove(f) + # workarounds for issues in __torch_function__ diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py index 2f27cf3c3866a..d72a7599f3499 100644 --- a/torch/fx/passes/graph_transform_observer.py +++ b/torch/fx/passes/graph_transform_observer.py @@ -1,10 +1,11 @@ # mypy: allow-untyped-defs import os -from typing import Callable, Optional, TypeVar +from typing import Callable, Dict, List, Optional, Set, TypeVar -from torch.fx import Graph +from torch.fx import Graph, Node from torch.fx._compatibility import compatibility from torch.fx.graph_module import GraphModule +from torch.fx.traceback import NodeSource, NodeSourceAction T = TypeVar("T") @@ -30,18 +31,32 @@ def __init__( """ log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified """ + from torch._inductor.config import trace self.gm = gm self.passname = passname self.subsystem = subsystem - # If log_url is None, we don't log anything if log_url is None: - from torch._inductor.config import trace - log_url = trace.log_url_for_graph_xform self.log_url = log_url + + self.active = trace.enabled or self.log_url is not None + + if self.active: + self.erased_nodes: Set[str] = set() + self.created_nodes: Set[str] = set() + self.name_to_node: Dict[str, Node] = {} + # record graph modules deepcopied from self.gm, so we can remove hoooks on them when exiting the context + self.copied_gms: List[GraphModule] = [] + + self._node_creation_hook = self.get_node_creation_hook() + self._node_erase_hook = self.get_node_erase_hook() + self._node_replace_hook = self.get_node_replace_hook() + self._deepcopy_hook = self.get_deepcopy_hook() + + # If log_url is None, we don't log anything if self.log_url is None: return GraphTransformObserver.__pass_count += 1 @@ -83,22 +98,34 @@ def _check_disable_pass(self): ) def __enter__(self): - if self.log_url is None or self.gm is None: + if not self.active: return self + self.gm._register_create_node_hook(self._node_creation_hook) + self.gm._register_erase_node_hook(self._node_erase_hook) + self.gm._register_replace_node_hook(self._node_replace_hook) + self.gm._register_deepcopy_hook(self._deepcopy_hook) + + self.erased_nodes.clear() + self.created_nodes.clear() + self.name_to_node.clear() + self.copied_gms.clear() - self.erased_nodes = set() - self.created_nodes = set() - self.gm._register_create_node_hook(self.on_node_creation) - self.gm._register_erase_node_hook(self.on_node_erase) + for node in self.gm.graph.nodes: + self.name_to_node[node.name] = node return self def __exit__(self, type, value, tb): - if self.log_url is None or self.gm is None: + if not self.active: return + for gm in self.copied_gms + [self.gm]: + gm._unregister_create_node_hook(self._node_creation_hook) + gm._unregister_erase_node_hook(self._node_erase_hook) + gm._unregister_replace_node_hook(self._node_replace_hook) + gm._unregister_deepcopy_hook(self._deepcopy_hook) - self.gm._unregister_create_node_hook(self.on_node_creation) - self.gm._unregister_erase_node_hook(self.on_node_erase) + if self.log_url is None: + return if len(self.created_nodes) > 0 or len(self.erased_nodes) > 0: for e in self.input_dot_graph.get_node_list(): @@ -106,6 +133,7 @@ def __exit__(self, type, value, tb): e.obj_dict["attributes"]["fillcolor"] = "yellow" else: e.obj_dict["attributes"]["fillcolor"] = "grey" + assert self.log_url is not None self.input_dot_graph.write( os.path.join( self.log_url, @@ -131,8 +159,61 @@ def __exit__(self, type, value, tb): ) ) - def on_node_creation(self, node): - self.created_nodes.add(node.name) + def get_node_creation_hook(self): + # We have to return a function instead of using a class method directly + # to avoid max recursion issue when deepcopy a graph module within the context manager. + def on_node_creation(node): + self.created_nodes.add(node.name) + self.name_to_node[node.name] = node + source = NodeSource(None, self.passname, NodeSourceAction.CREATE) + if "from_node" not in node.meta: + node.meta["from_node"] = [source] + else: + node.meta["from_node"].append(source) + + return on_node_creation + + def get_node_erase_hook(self): + def on_node_erase(node): + self.erased_nodes.add(node.name) + self.name_to_node.pop(node.name, None) + + return on_node_erase + + def get_node_replace_hook(self): + def on_node_replace(old: Node, new: str, user: Node): + # Update node meta when replacing old node with new node + new_node = self.name_to_node.get(new, None) + + if not new_node: + return + + assert isinstance(new_node, Node) + + action = [NodeSourceAction.REPLACE] + if new_node.name in self.created_nodes: + action.append(NodeSourceAction.CREATE) + + def created_this_pass(source): + return source.pass_name == self.passname and source.action == [ + NodeSourceAction.CREATE + ] + + # remove redundant source added on node creation + new_from_node = new_node.meta.get("from_node", []) + new_from_node = [ + source for source in new_from_node if not created_this_pass(source) + ] + + # add new source + new_node_source = NodeSource(old, self.passname, action) + new_from_node.append(new_node_source) + new_node.meta["from_node"] = new_from_node + + return on_node_replace + + def get_deepcopy_hook(self): + def on_deepcopy(gm): + self.copied_gms.append(gm) - def on_node_erase(self, node): - self.erased_nodes.add(node.name) + return on_deepcopy diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 2a4cccc419a29..88a8fc54fa5f7 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -4,7 +4,7 @@ import traceback from contextlib import contextmanager from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from ._compatibility import compatibility from .graph import Graph @@ -30,7 +30,7 @@ @compatibility(is_backward_compatible=False) -class NodeSourceAction(str, Enum): +class NodeSourceAction(Enum): CREATE = "create" REPLACE = "replace" @@ -49,7 +49,7 @@ def __init__(self, name: str, target: str, graph_id: int): self.graph_id = graph_id pass_name: str - action: Optional["NodeSourceAction"] + action: List["NodeSourceAction"] from_node: List["NodeSource"] node_info: Optional["NodeInfo"] @@ -57,9 +57,16 @@ def __init__( self, node: Optional[Node], pass_name: str = "", - action: Optional["NodeSourceAction"] = None, + action: Optional[Union["NodeSourceAction", List["NodeSourceAction"]]] = None, ): self.pass_name = pass_name + + if action is None: + action = [] + elif not isinstance(action, list): + action = [action] + for a in action: + assert isinstance(a, NodeSourceAction) self.action = action if node: self.node_info = self.NodeInfo( @@ -89,13 +96,17 @@ def graph_id(self) -> int: def __repr__(self): return self.print_readable() + def _get_action_string(self): + return "+".join([a.name.lower() for a in self.action]) + def print_readable(self, indent=0): if indent > 9: return "" result = "" + action_string = self._get_action_string() result += ( " " * indent * 4 - + f"(name={self.name}, pass_name={self.pass_name}, action={self.action}, graph_id={self.graph_id})\n" + + f"(name={self.name}, pass_name={self.pass_name}, action={action_string}, graph_id={self.graph_id})\n" ) for item in self.from_node: result += item.print_readable(indent + 1) @@ -103,31 +114,35 @@ def print_readable(self, indent=0): def to_dict(self) -> dict: # Convert the object to a dictionary + action_string = self._get_action_string() return { "name": self.name, "target": self.target, "graph_id": self.graph_id, "pass_name": self.pass_name, - "action": self.action, + "action": action_string, "from_node": [node.to_dict() for node in self.from_node], } @compatibility(is_backward_compatible=False) @contextmanager -def preserve_node_meta(): +def preserve_node_meta(enable=True): global should_preserve_node_meta global current_meta - - saved_should_preserve_node_meta = should_preserve_node_meta - # Shallow copy is OK since fields of current_meta are not mutated - saved_current_meta = current_meta.copy() - try: - should_preserve_node_meta = True + # If enable is False, this context manager is a no-op + if not enable: yield - finally: - should_preserve_node_meta = saved_should_preserve_node_meta - current_meta = saved_current_meta + else: + saved_should_preserve_node_meta = should_preserve_node_meta + # Shallow copy is OK since fields of current_meta are not mutated + saved_current_meta = current_meta.copy() + try: + should_preserve_node_meta = True + yield + finally: + should_preserve_node_meta = saved_should_preserve_node_meta + current_meta = saved_current_meta @compatibility(is_backward_compatible=False)