Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions test/fx/test_fx_traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from torch.testing._internal.common_utils import TestCase


CREATE_STR = NodeSourceAction.CREATE.name.lower()


class TestFXNodeSource(TestCase):
def test_node_source(self):
node_source = NodeSource(
Expand All @@ -20,7 +23,7 @@ def test_node_source(self):
"name": "",
"target": "",
"pass_name": "test_pass",
"action": NodeSourceAction.CREATE,
"action": CREATE_STR,
"graph_id": -1,
"from_node": [],
}
Expand Down Expand Up @@ -56,7 +59,7 @@ def test_node_source(self):
"name": "add",
"target": "aten.add.Tensor",
"pass_name": "test_pass",
"action": NodeSourceAction.CREATE,
"action": CREATE_STR,
"graph_id": graph_id,
"from_node": [dummy_source_dict],
},
Expand Down Expand Up @@ -108,7 +111,7 @@ def forward(self, x):
key_provenance,
"x",
"Interpreter_PropagateUnbackedSymInts",
NodeSourceAction.CREATE,
CREATE_STR,
)

# Check node "x" is then created from another node "x" in FlattenInputOutputSignature
Expand All @@ -117,7 +120,7 @@ def forward(self, x):
key_provenance,
"x",
"Interpreter_FlattenInputOutputSignature",
NodeSourceAction.CREATE,
CREATE_STR,
)

gm, graph_signature = aot_export_module(
Expand Down Expand Up @@ -147,7 +150,7 @@ def forward(self, x):
key_provenance,
"linear",
"Interpreter_PropagateUnbackedSymInts",
NodeSourceAction.CREATE,
CREATE_STR,
)

# Check node "linear" is then created from node "x" in PropagateUnbackedSymInts
Expand All @@ -156,7 +159,7 @@ def forward(self, x):
key_provenance,
"x",
"Interpreter_PropagateUnbackedSymInts",
NodeSourceAction.CREATE,
CREATE_STR,
)

# Check node "x" is then created from another node "x" in FlattenInputOutputSignature
Expand All @@ -165,5 +168,5 @@ def forward(self, x):
key_provenance,
"x",
"Interpreter_FlattenInputOutputSignature",
NodeSourceAction.CREATE,
CREATE_STR,
)
126 changes: 126 additions & 0 deletions test/fx/test_fx_xform_observer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Owner(s): ["module: fx"]

import copy
import os
import tempfile

import torch
from torch.fx import subgraph_rewriter, symbolic_trace
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.traceback import NodeSourceAction
from torch.testing._internal.common_utils import TestCase


Expand Down Expand Up @@ -60,3 +62,127 @@ def replacement(x):
)
)
)

@torch._inductor.config.patch("trace.enabled", True)
def test_graph_transform_observer_node_tracking(self):
class M(torch.nn.Module):
def forward(self, x):
val = torch.neg(x)
return torch.add(val, val)

def pattern(x):
return torch.neg(x)

def replacement(x):
return torch.relu(x)

def replacement2(x):
return torch.cos(x)

traced = symbolic_trace(M())

def check_node_source(node_source, node_name, target, id, pass_name, action):
self.assertEqual(node_source.name, node_name)
self.assertEqual(node_source.target, target)
self.assertEqual(node_source.pass_name, pass_name)
self.assertEqual(node_source.graph_id, id)
self.assertEqual(node_source.action, action)

with GraphTransformObserver(traced, "replace_neg_with_relu") as ob:
subgraph_rewriter.replace_pattern(traced, pattern, replacement)

self.assertTrue("relu" in ob.created_nodes)
self.assertTrue("neg" in ob.erased_nodes)

self.assertEqual(len(traced._replace_hooks), 0)
self.assertEqual(len(traced._create_node_hooks), 0)
self.assertEqual(len(traced._erase_node_hooks), 0)
self.assertEqual(len(traced._deepcopy_hooks), 0)

for node in traced.graph.nodes:
if node.name == "relu":
from_node = node.meta["from_node"]
self.assertTrue(len(from_node) == 1)
check_node_source(
from_node[0],
"neg",
str(torch.neg),
id(traced.graph),
"replace_neg_with_relu",
[NodeSourceAction.REPLACE, NodeSourceAction.CREATE],
)

with GraphTransformObserver(traced, "replace_relu_with_cos") as ob:
subgraph_rewriter.replace_pattern(traced, replacement, replacement2)

self.assertTrue("cos" in ob.created_nodes)
self.assertTrue("relu" in ob.erased_nodes)

for node in traced.graph.nodes:
if node.name == "cos":
from_node = node.meta["from_node"]
self.assertTrue(len(from_node) == 1)
check_node_source(
from_node[0],
"relu",
str(torch.relu),
id(traced.graph),
"replace_relu_with_cos",
[NodeSourceAction.REPLACE, NodeSourceAction.CREATE],
)
check_node_source(
from_node[0].from_node[0],
"neg",
str(torch.neg),
id(traced.graph),
"replace_neg_with_relu",
[NodeSourceAction.REPLACE, NodeSourceAction.CREATE],
)

class SimpleLinearModel(torch.nn.Module):
def forward(self, x):
return torch.neg(x)

model = SimpleLinearModel()
gm = torch.export.export(model, (torch.rand(10),)).module()

with GraphTransformObserver(gm, "test"):
add_node = gm.graph.call_function(torch.ops.aten.add.default, (1, 1))
neg_node = next(
iter([node for node in gm.graph.nodes if node.name == "neg"])
)
neg_node.replace_all_uses_with(replace_with=add_node)

from_node = add_node.meta["from_node"]
self.assertTrue(len(from_node) == 1)
check_node_source(
from_node[0],
"neg",
str(torch.ops.aten.neg.default),
id(gm.graph),
"test",
[NodeSourceAction.REPLACE, NodeSourceAction.CREATE],
)

@torch._inductor.config.patch("trace.enabled", True)
def test_graph_transform_observer_deepcopy(self):
class SimpleLinearModel(torch.nn.Module):
def forward(self, x):
return torch.neg(x)

model = SimpleLinearModel()
gm = torch.export.export(model, (torch.rand(10),)).module()

with GraphTransformObserver(gm, "test"):
gm2 = copy.deepcopy(gm)

nodes = [node.name for node in gm.graph.nodes]
nodes2 = [node.name for node in gm2.graph.nodes]
self.assertEqual(nodes, nodes2)

# deepcopied graph modules should not have hooks after exiting
# the context
self.assertEqual(len(gm2._replace_hooks), 0)
self.assertEqual(len(gm2._create_node_hooks), 0)
self.assertEqual(len(gm2._erase_node_hooks), 0)
self.assertEqual(len(gm2._deepcopy_hooks), 0)
16 changes: 15 additions & 1 deletion torch/_inductor/compile_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,18 @@ def log_graph_runnable() -> str:
print_output=False, include_stride=True, include_device=True
),
)
if config.trace.enabled:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a discussion on deprecating TORCH_COMPILE_DEBUG. Ok for now, but we may want to switch to TORCH_LOGS in future.

provenance_tracking_json = (
torch.fx.traceback.get_graph_provenance_json(gm.graph)
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "inductor_post_to_pre_grad_nodes",
"encoding": "json",
},
payload_fn=lambda: provenance_tracking_json,
)
if config.is_fbcode():
log_optimus_to_scuba(
extra_logging={"pt2_configs": str(get_patched_config_dict())}
Expand Down Expand Up @@ -1636,7 +1648,9 @@ def compile_fx(

with _use_lazy_graph_module(
dynamo_config.use_lazy_graph_module
), enable_python_dispatcher():
), enable_python_dispatcher(), torch.fx.traceback.preserve_node_meta(
config.trace.enabled
):
# Pre-grad passes cannot be run if we weren't given a GraphModule.
# Dynamo will always produce a GraphModule, but this handles cases
# where a user directly passes a plain Module with the intention of
Expand Down
40 changes: 33 additions & 7 deletions torch/_inductor/pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,30 @@ def __init__(self) -> None:
MULTIPLE = Multiple()


def _transfer_meta(new_meta: Dict[str, Any], old_meta: Dict[str, Any]) -> None:
def _transfer_meta(
new_meta: Dict[str, Any], old_node: torch.fx.Node, pass_name: str = ""
) -> None:
from torch.fx.traceback import NodeSource, NodeSourceAction

# transfer metadata after pattern matching occurs.
# skip "val" and "tensor_meta" because this info is too specific; it's unlikely
# to remain accurate after pattern matching has occurred.
new_meta.update(
(k, v) for k, v in old_meta.items() if k in torch.fx.proxy._COPY_META_FIELDS
)
if config.trace.enabled:
# We handle "from_node" field of the node meta specially to record that the new node comes from the old_node.
new_from_node = new_meta.get("from_node", []).copy()
new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE))
new_meta.update(
(k, v)
for k, v in old_node.meta.items()
if k in torch.fx.proxy._COPY_META_FIELDS
)
new_meta["from_node"] = new_from_node
else:
new_meta.update(
(k, v)
for k, v in old_node.meta.items()
if k in torch.fx.proxy._COPY_META_FIELDS
)


class Match:
Expand Down Expand Up @@ -263,7 +280,11 @@ def replace_by_example(
)
if len(self.nodes) == 1:
for n in replacement.graph.nodes:
_transfer_meta(new_meta=n.meta, old_meta=self.nodes[0].meta)
_transfer_meta(
new_meta=n.meta,
old_node=self.nodes[0],
pass_name="replace_by_example",
)

ReplacementPatternEntry.replace_with_graph(
self,
Expand Down Expand Up @@ -1069,7 +1090,11 @@ def run_node(self, node: torch.fx.Node) -> Any:
target = node.target
args, kwargs = self.fetch_args_kwargs_from_env(node)
result = graph.call_function(target, args, kwargs) # type: ignore[arg-type]
_transfer_meta(new_meta=result.meta, old_meta=node.meta)
_transfer_meta(
new_meta=result.meta,
old_node=node,
pass_name="Interpreter_Replacer",
)
if "val" in node.meta and "val" not in result.meta:
result.meta["val"] = node.meta["val"]
if isinstance(node.meta["val"], torch.Tensor):
Expand Down Expand Up @@ -1401,7 +1426,8 @@ def search_fn_new(*args_new: Any) -> Any:
for n in match.replacement_graph.graph.nodes:
_transfer_meta(
new_meta=n.meta,
old_meta=match.nodes[0].meta,
old_node=match.nodes[0],
pass_name="replacement",
)
return True
return False
Expand Down
21 changes: 21 additions & 0 deletions torch/fx/graph_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,8 @@ def __init__(
self._replace_hooks: List[Callable] = []
self._create_node_hooks: List[Callable] = []
self._erase_node_hooks: List[Callable] = []
# Used to remove hooks from deepcopied graph modules within a context manager.
self._deepcopy_hooks: List[Callable] = []

# TorchScript breaks trying to compile the graph setter because of the
# continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
Expand Down Expand Up @@ -888,6 +890,7 @@ def __deepcopy__(self, memo):
"_replace_hooks",
"_create_node_hooks",
"_erase_node_hooks",
"_deepcopy_hooks",
]
for attr in extra_preserved_attrs:
if attr in self.__dict__:
Expand All @@ -896,6 +899,8 @@ def __deepcopy__(self, memo):
if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
setattr(res, attr_name, attr)
for hook in self._deepcopy_hooks:
hook(res)
return res

def __copy__(self):
Expand Down Expand Up @@ -1002,6 +1007,22 @@ def _unregister_erase_node_hook(self, f):
assert callable(f), "erase_node hook must be a callable."
self._erase_node_hooks.remove(f)

def _register_deepcopy_hook(self, f):
"""
Takes a callable which will be called when we deepcopy this graph module. The
callable takes the resulting deepcopied graph module.
"""
assert callable(f), "deepcopy hook must be a callable."
self._deepcopy_hooks.append(f)

def _unregister_deepcopy_hook(self, f):
"""
Takes a callable which was previously registered to be called after deepcopy.
This function will unregister that callable so it is no longer invoked on deepcopy.
"""
assert callable(f), "deepcopy hook must be a callable."
self._deepcopy_hooks.remove(f)


# workarounds for issues in __torch_function__

Expand Down
Loading