Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ecb672b

Browse files
committed
Only print dde partial fx graph for export
Lazos correctly pointed out this doesn't make sense for compile since we graph break in compile. This results in tons of unwanted user log spew. We do want this in export though since it's drastiaclly reduced the support load for DDEs. This PR does the refactor to keep it in export but remove it from compile ghstack-source-id: db17302 Pull Request resolved: #149831
1 parent 2b848ab commit ecb672b

File tree

3 files changed

+20
-17
lines changed

3 files changed

+20
-17
lines changed

torch/_dynamo/symbolic_convert.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,15 +1342,13 @@ def run(self):
13421342
raise
13431343
except RuntimeError as e:
13441344
if hasattr(e, "msg") and "Data-dependent" in e.msg:
1345-
print(
1346-
"\n"
1347-
+ torch.fx.GraphModule(
1348-
self.output.nn_modules, self.output.graph
1349-
).print_readable(
1350-
print_output=False, include_stride=True, include_device=True
1351-
),
1352-
file=sys.stderr,
1345+
readable_graph = torch.fx.GraphModule(
1346+
self.output.nn_modules, self.output.graph
1347+
).print_readable(
1348+
print_output=False, include_stride=True, include_device=True
13531349
)
1350+
e.partial_fx_graph = readable_graph # type: ignore[attr-defined]
1351+
raise
13541352

13551353
raise
13561354
except Exception as e:

torch/export/_trace.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import inspect
66
import logging
77
import re
8+
import sys
89
import time
910
import warnings
1011
from contextlib import contextmanager, nullcontext
@@ -1096,6 +1097,13 @@ def wrapper(*args, **kwargs):
10961097
message=str(e),
10971098
flags=_EXPORT_FLAGS,
10981099
)
1100+
1101+
if hasattr(e, "partial_fx_graph"):
1102+
print(
1103+
e.partial_fx_graph,
1104+
file=sys.stderr,
1105+
)
1106+
10991107
raise e
11001108
finally:
11011109
_EXPORT_FLAGS = None

torch/fx/_symbolic_trace.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import inspect
88
import math
99
import os
10-
import sys
1110
import warnings
1211
from itertools import chain
1312
from types import CodeType, FunctionType, ModuleType
@@ -843,14 +842,12 @@ def forward(*args, **kwargs):
843842
self.submodule_paths = None
844843
except RuntimeError as e:
845844
if isinstance(e.args[0], str) and "data-dependent" in e.args[0]:
846-
print(
847-
"\n"
848-
+ self.graph.python_code(
849-
root_module="self",
850-
verbose=True,
851-
).src,
852-
file=sys.stderr,
853-
)
845+
partial_fx_graph = self.graph.python_code(
846+
root_module="self",
847+
verbose=True,
848+
).src
849+
e.partial_fx_graph = partial_fx_graph # type: ignore[attr-defined]
850+
raise
854851

855852
raise
856853
finally:

0 commit comments

Comments
 (0)