Compile on by default #497
Conversation
One annoying thing I am noticing is that if we do

```python
jitted = assert_auto_jitable(model)  # torch.jit.script
pt2 = torch.compile(model, fullgraph=True)
```

then it tries to compile a scripted module, whereas

```python
pt2 = torch.compile(model, fullgraph=True)
jitted = assert_auto_jitable(model)  # torch.jit.script
```

runs without errors.
```python
if k not in _OPT_DEFAULTS:
    raise ValueError(f"Unknown optimization option: {k}")
_OPT_DEFAULTS[k] = v
```
There should probably be a check in here that `jit_mode` actually takes a valid value?
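(For reference, a hypothetical sketch of such a check, mirroring the `k`/`v` loop in the diff above; the name `_VALID_JIT_MODES` and its values are assumptions based on the PR description:)

```python
# Hypothetical sketch: validate jit_mode inside the loop over the
# (k, v) pairs shown above. Valid modes assumed from the PR description.
_VALID_JIT_MODES = ("inductor", "script", "eager")

if k == "jit_mode" and v not in _VALID_JIT_MODES:
    raise ValueError(f"jit_mode must be one of {_VALID_JIT_MODES}, got {v!r}")
```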
Thanks, done!
e3nn/util/codegen/_mixin.py (outdated)

```python
elif opt_defaults["jit_mode"] == "compile":
    scriptmod = torch.compile(graphmod, fullgraph=True)
```
Are `torch.compile`d modules picklable? Otherwise this branch will make `CodeGenMixin` classes not picklable anymore (which in general we solved for the plain `fx.GraphModule` and `torch.jit.script` cases). Not important to any particular case I can think of, just raising it.
Got rid of the branch
tests/nn/activation_test.py (outdated)

```python
inp = irreps_in.randn(13, -1)
out = a(inp)

a_pt2 = torch.compile(a, fullgraph=True)
```
(Nitpick) you could always write an `assert_torch_compile_works` function, like `assert_auto_jitable`, to use throughout the tests that does the compile, checks that the outputs are close, etc. Not necessary.
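A minimal sketch of what such a helper might look like (the name `assert_torch_compile_works`, its signature, and the tolerance are assumptions, not e3nn API):

```python
import torch

def assert_torch_compile_works(module, *inputs, atol=1e-6):
    """Compile ``module`` with fullgraph=True and check that the compiled
    output matches the eager output on the given inputs."""
    eager_out = module(*inputs)
    compiled = torch.compile(module, fullgraph=True)
    compiled_out = compiled(*inputs)
    assert torch.allclose(eager_out, compiled_out, atol=atol), (
        "torch.compile changed the module's output"
    )
    return compiled
```

The test above could then reduce to something like `a_pt2 = assert_torch_compile_works(a, inp)`.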
Thanks, done!
```diff
  buffer_type: str
  buffer: bytes
- if isinstance(smod, fx.GraphModule):
+ if isinstance(smod, (fx.GraphModule, torch._dynamo.OptimizedModule)):
```
Are these `torch._dynamo.OptimizedModule`s actually picklable? Only within the same process?
Yup, was able to confirm with a reproducer from Claude:
```python
import torch
import pickle
import io

# First, let's create a simple PyTorch model
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

# Create an instance of the model
model = SimpleModel()

# Now, let's optimize the model with torch.compile
# This should wrap it in an OptimizedModule
optimized_model = torch.compile(model)

# Check the type of the optimized model
print(f"Type of optimized model: {type(optimized_model)}")
print(f"Is it an OptimizedModule? {isinstance(optimized_model, torch._dynamo.OptimizedModule)}")

# Try to pickle the optimized model
try:
    # Use a BytesIO object to pickle in memory
    buffer = io.BytesIO()
    pickle.dump(optimized_model, buffer)
    buffer.seek(0)

    # Try to unpickle
    unpickled_model = pickle.load(buffer)
    print("Successfully pickled and unpickled the OptimizedModule!")
    print(f"Type of unpickled model: {type(unpickled_model)}")

    # Check if the unpickled model works
    test_input = torch.randn(2, 10)
    original_output = optimized_model(test_input)
    unpickled_output = unpickled_model(test_input)

    # Compare outputs
    print(f"Outputs match: {torch.allclose(original_output, unpickled_output)}")
except Exception as e:
    print(f"Failed to pickle/unpickle: {e}")
    print("torch._dynamo.OptimizedModule is NOT picklable")
```
e3nn/__init__.py (outdated)

```python
# Handle the legacy jit_script_fx flag
if k == "jit_script_fx":
    _OPT_DEFAULTS[k] = v
```
Should this happen after the validation?
More generally, is the rationale for this to make sure that older code that looks at this flag can still find it in the defaults dict? Maybe worth documenting that with a comment if so.
Thanks! Moved the dict assignment after validation and added docs.
Currently, to enable full `torch.compile` capability we need to do `e3nn.set_optimization_defaults(jit_script_fx=False)`. I am proposing to turn this off by default.

Update: 01/03/2024

The new interface will have `"jit_mode": "inductor" / "script" / "eager"`. Backward compatibility will be retained by mapping `"jit_script_fx"=False` to `"jit_mode": "eager"` and `"jit_script_fx"=True` to `"jit_mode": "script"` (this will only be done if `"jit_mode"` is not already set to `"inductor"` or `"script"`).