Update on "[WIP][FX] Add Interpreter and Transformer"
1. Codify the interpreter pattern that was used, for example, in [shape prop](https://github.com/pytorch/pytorch/blob/c3b4b2062748ed70e689f9f9ffe670e6fa20a071/torch/fx/experimental/shape_prop.py#L7). This is a pattern users will likely find repeatedly useful, so it is worth making it a first-class part of the API. As the test cases show, it makes simple analyses and transforms, such as swapping out individual nodes, very easy to write (see the sketch after this list).

2. Add a small `Transformer` class: an `Interpreter` that also provides a `transform` method, which interprets the code symbolically and yields a new `GraphModule`.
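
As a rough illustration of point 1, here is a minimal sketch of an analysis pass written against the `Interpreter` API from this PR. It is not part of the change itself: the `OpCounter` class, the `fn` example, and the import path `torch.fx.interpreter` are assumptions for illustration; the `call_function(self, n : Node)` override signature mirrors the docstring example below.

```python
import torch
import torch.fx
from torch.fx.node import Node
# NOTE: import path assumed from this PR's file layout (torch/fx/interpreter.py);
# the class may be re-exported elsewhere once the PR lands.
from torch.fx.interpreter import Interpreter


class OpCounter(Interpreter):
    """Toy analysis pass: count how often each call_function target executes."""
    def __init__(self, module):
        super().__init__(module)
        self.counts = {}

    def call_function(self, n : Node):
        # Tally this node's target, then defer to the default behavior so the
        # program still runs and downstream nodes receive real values.
        self.counts[n.target] = self.counts.get(n.target, 0) + 1
        return super().call_function(n)


def fn(x):
    return torch.relu(x) + torch.relu(x)

gm = torch.fx.symbolic_trace(fn)
counter = OpCounter(gm)
counter.run(torch.randn(3, 4))
print(counter.counts)   # roughly: {torch.relu: 2, operator.add: 1}
```

The same skeleton extends to shape propagation (record the shape of each node's result) or, via the `Transformer` subclass from point 2, to graph rewrites like the `NegSigmSwapXformer` exercised in the test cases.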

TODO:

- [ ] Write docstrings
- [ ] Write docs
- [ ] Look for more places this can be used

Differential Revision: [D25880330](https://our.internmc.facebook.com/intern/diff/D25880330)

[ghstack-poisoned]
James Reed committed Jan 26, 2021
commit c5fa07bfa55f36e1aba410902be2df3e1765c09e
2 changes: 0 additions & 2 deletions test/test_fx.py
@@ -1065,8 +1065,6 @@ def call_method(self, n : Node) -> Any:
return super().call_method(n)

transformed = NegSigmSwapXformer(gm).transform()
print(gm.graph)
print(transformed.graph)
input = torch.randn(3, 4)
self.assertEqual(transformed(input), torch.neg(input).sigmoid())

72 changes: 64 additions & 8 deletions torch/fx/interpreter.py
@@ -48,7 +48,7 @@ def fn(x):
gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
assert torch.testing.assert_allclose(result, torch.neg(input).sigmoid())
torch.testing.assert_allclose(result, torch.neg(input).sigmoid())

Args:
module (GraphModule): The module to be executed
@@ -263,30 +263,86 @@ def load_arg(n_arg : Node) -> Any:
return self.env[n_arg]
return map_arg(args, load_arg)

class TransformerTracer(Tracer):
def __init__(self, graph: Graph):
super().__init__()
self.graph = graph
class Transformer(Interpreter):
"""
``Transformer`` is a special type of interpreter that produces a
new ``Module``. It exposes a ``transform()`` method that returns
the transformed ``Module``. Unlike ``Interpreter``, ``Transformer`` does
not require any input arguments to run, since it operates entirely
symbolically.

def is_leaf_module(self, _, __) -> bool:
return True
Example:
Suppose we want to swap all instances of ``torch.neg`` with
``torch.sigmoid`` and vice versa (including their ``Tensor``
method equivalents). We could subclass ``Transformer`` like so::

class Transformer(Interpreter):
class NegSigmSwapXformer(Transformer):
def call_function(self, n : Node) -> Any:
if n.target == torch.sigmoid:
n = copy.copy(n)
n.target = torch.neg
return super().call_function(n)

def call_method(self, n : Node) -> Any:
if n.target == 'neg':
n = copy.copy(n)
n.target = 'sigmoid'
return super().call_method(n)

def fn(x):
return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)

transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_allclose(transformed(input), torch.neg(input).sigmoid())

Args:
module (GraphModule): The ``Module`` to be transformed.
"""
def __init__(self, module):
super().__init__(module)
self.new_graph = Graph()
class TransformerTracer(Tracer):
def __init__(self, graph: Graph):
super().__init__()
self.graph = graph

def is_leaf_module(self, _, __) -> bool:
return True
self.tracer = TransformerTracer(self.new_graph)
self.tracer.root = module

def placeholder(self, n : Node) -> Proxy:
"""
Execute a ``placeholder`` node. In ``Transformer``, this is
overridden to insert a new ``placeholder`` into the output
graph.

Args:
n (Node): The placeholder node to execute
"""
assert isinstance(n.target, str)
return Proxy(self.new_graph.placeholder(n.target), self.tracer)

def get_attr(self, n : Node) -> Proxy:
"""
Execute a ``get_attr`` node. In ``Transformer``, this is
overridden to insert a new ``get_attr`` node into the output
graph.

Args:
n (Node): The get_attr node to execute
"""
assert isinstance(n.target, str)
return Proxy(self.new_graph.get_attr(n.target), self.tracer)

def transform(self) -> GraphModule:
"""
Transform ``self.module`` and return the transformed
``GraphModule``.
"""
result = super().run()
if result is not None:
assert isinstance(result, Proxy)