Conversation

@mitkotak (Member) commented Dec 12, 2024

Currently, to enable full torch.compile capability we need to call e3nn.set_optimization_defaults(jit_script_fx=False). I am proposing to turn this flag off by default.

Update: 01/03/2025

  • New interface will be "jit_mode", with values "inductor", "script", or "eager"

  • Will retain backward compatibility by mapping "jit_script_fx"=False to "jit_mode": "eager" and "jit_script_fx"=True to "jit_mode": "script" (only when "jit_mode" is not already set to "inductor" or "script"); see the usage sketch below.
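
For reference, a usage sketch of the proposed interface, assuming the mapping above lands as described (the jit_mode values come from this PR; the exact runtime behavior is the implementation's):

import e3nn

# New-style configuration: pick the compilation backend explicitly.
e3nn.set_optimization_defaults(jit_mode="inductor")

# Legacy flag, still accepted via the backward-compatibility mapping:
e3nn.set_optimization_defaults(jit_script_fx=False)  # treated as jit_mode="eager"
e3nn.set_optimization_defaults(jit_script_fx=True)   # treated as jit_mode="script"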

@mitkotak (Member, Author) commented Dec 12, 2024

One annoying thing I am noticing: if we do

jitted = assert_auto_jitable(model)  # torch.jit.script

pt2 = torch.compile(model, fullgraph=True)

then torch.compile tries to compile a scripted module, whereas the reverse order

pt2 = torch.compile(model, fullgraph=True)

jitted = assert_auto_jitable(model)  # torch.jit.script

runs without errors.
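
A self-contained sketch of what appears to be the underlying issue: handing torch.compile a module that has already been scripted. Toy is a hypothetical stand-in for an e3nn model whose submodules assert_auto_jitable scripts in place:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

x = torch.randn(4)

# Works: compile the eager module first, then script a separate instance.
pt2 = torch.compile(Toy(), fullgraph=True)
pt2(x)
scripted = torch.jit.script(Toy())

# Scripting first hands torch.compile a ScriptModule, which Dynamo cannot
# trace with fullgraph=True:
scripted_first = torch.jit.script(Toy())
pt2_scripted = torch.compile(scripted_first, fullgraph=True)
# pt2_scripted(x)  # may raise here, when fullgraph tracing hits the scripted code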

@mitkotak force-pushed the script_fx_default_off branch 3 times, most recently from 3208824 to 604ac13 on December 12, 2024 at 21:42
@mitkotak force-pushed the script_fx_default_off branch from 604ac13 to 3463d3e on December 12, 2024 at 21:43
if k not in _OPT_DEFAULTS:
    raise ValueError(f"Unknown optimization option: {k}")
_OPT_DEFAULTS[k] = v

Contributor commented:
There should probably be a check in here that jit_mode actually takes a valid value?

@mitkotak (Member, Author) replied:

Thanks, done!
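
One plausible shape for that check, with _VALID_JIT_MODES as a hypothetical constant name and a stand-in defaults dict (the merged code may differ):

_OPT_DEFAULTS = {"jit_mode": "inductor"}            # stand-in for e3nn's defaults dict
_VALID_JIT_MODES = ("inductor", "script", "eager")  # hypothetical constant name

def set_optimization_defaults(**kwargs) -> None:
    for k, v in kwargs.items():
        if k not in _OPT_DEFAULTS:
            raise ValueError(f"Unknown optimization option: {k}")
        # Validate the value, not just the key, as the reviewer suggested:
        if k == "jit_mode" and v not in _VALID_JIT_MODES:
            raise ValueError(f"jit_mode must be one of {_VALID_JIT_MODES}, got {v!r}")
        _OPT_DEFAULTS[k] = v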

Comment on lines 47 to 48
elif opt_defaults["jit_mode"] == "compile":
    scriptmod = torch.compile(graphmod, fullgraph=True)
Contributor commented:

Are torch.compile'd modules picklable? Otherwise this branch will make CodeGenMixin classes not picklable anymore (which we solved in general for the plain fx.GraphModule and torch.jit.script cases). Not important to any particular case I can think of, just raising it.

@mitkotak (Member, Author) replied:

Got rid of the branch.

inp = irreps_in.randn(13, -1)
out = a(inp)

a_pt2 = torch.compile(a, fullgraph=True)
Contributor commented:

(Nitpick) You could write an assert_torch_compile_works function, analogous to assert_auto_jitable, to use throughout the tests: it would do the compile, check that the output is close, etc. Not necessary.

@mitkotak (Member, Author) replied:

Thanks, done!
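
A sketch of what such a helper might look like; the function name comes from the reviewer's suggestion, while the signature and tolerance are assumptions:

import torch

def assert_torch_compile_works(module: torch.nn.Module, *inputs, atol: float = 1e-6):
    # Compile with fullgraph=True and check the compiled output matches eager.
    expected = module(*inputs)
    compiled = torch.compile(module, fullgraph=True)
    actual = compiled(*inputs)
    assert torch.allclose(expected, actual, atol=atol), "compiled output diverged from eager"
    return compiled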

buffer_type: str
buffer: bytes
- if isinstance(smod, fx.GraphModule):
+ if isinstance(smod, (fx.GraphModule, torch._dynamo.OptimizedModule)):
Contributor commented:

Are these torch._dynamo.OptimizedModules actually picklable? Or only within the same process?

@mitkotak (Member, Author) replied:

Yup, was able to confirm with a reproducer from Claude:

import torch
import pickle
import io

# First, let's create a simple PyTorch model
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)
    
    def forward(self, x):
        return self.linear(x)

# Create an instance of the model
model = SimpleModel()

# Now, let's optimize the model with torch.compile
# This should wrap it in an OptimizedModule
optimized_model = torch.compile(model)

# Check the type of the optimized model
print(f"Type of optimized model: {type(optimized_model)}")
print(f"Is it an OptimizedModule? {isinstance(optimized_model, torch._dynamo.OptimizedModule)}")

# Try to pickle the optimized model
try:
    # Use a BytesIO object to pickle in memory
    buffer = io.BytesIO()
    pickle.dump(optimized_model, buffer)
    buffer.seek(0)
    
    # Try to unpickle
    unpickled_model = pickle.load(buffer)
    
    print("Successfully pickled and unpickled the OptimizedModule!")
    print(f"Type of unpickled model: {type(unpickled_model)}")
    
    # Check if the unpickled model works
    test_input = torch.randn(2, 10)
    original_output = optimized_model(test_input)
    unpickled_output = unpickled_model(test_input)
    
    # Compare outputs
    print(f"Outputs match: {torch.allclose(original_output, unpickled_output)}")
    
except Exception as e:
    print(f"Failed to pickle/unpickle: {e}")
    print("torch._dynamo.OptimizedModule is NOT picklable")

e3nn/__init__.py (outdated)

# Handle the legacy jit_script_fx flag
if k == "jit_script_fx":
    _OPT_DEFAULTS[k] = v
Contributor commented:

Should this happen after the validation?

More generally, is the rationale for this to make sure that older code that looks at this flag can still find it in the defaults dict? Maybe worth documenting that with a comment if so

@mitkotak (Member, Author) replied:

Thanks! Moved the dict assignment after the validation and added a comment documenting the backward-compatibility rationale.
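
A sketch of the reordered logic with the documenting comment, illustrative of the shape rather than the exact merged code:

_OPT_DEFAULTS = {"jit_mode": "inductor", "jit_script_fx": True}  # stand-in defaults

def set_optimization_defaults(**kwargs) -> None:
    for k, v in kwargs.items():
        if k not in _OPT_DEFAULTS:
            raise ValueError(f"Unknown optimization option: {k}")
        # ... jit_mode value validation as sketched earlier ...
        _OPT_DEFAULTS[k] = v
        # Backward compatibility: keep the legacy "jit_script_fx" flag visible
        # in the defaults dict for older code that reads it, and mirror it onto
        # the new "jit_mode".
        if k == "jit_script_fx":
            _OPT_DEFAULTS["jit_mode"] = "script" if v else "eager"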

@mitkotak merged commit 0729ee5 into e3nn:main on Apr 27, 2025
1 of 3 checks passed
@mitkotak changed the title from "[WIP] Compile on by default" to "Compile on by default" on Aug 26, 2025