Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[inductor] [compile async] Don't compile in eager #152507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ChuanqiXu9
Copy link
Contributor

@ChuanqiXu9 ChuanqiXu9 commented Apr 30, 2025

Previously we will compile in eager mode.

This looks not intentional according to the test. There is a check to check the number of compilations (in current process) to be 0. But maybe due to an oversight, the number it checks is always a zero.

In _InProcessFxCompile and _SerializedFxCompile, we increment the number of codegen_and_compile by self, which is a member variable attached to the instance. But in test, we check the number of codegen_and_compile by the class. I think we should increment the number of codegen_and_compile by the class. Then the test will fail now.

See torch/_inductor/compile_fx_async.py for the fix.

CC @aorenste

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Copy link

pytorch-bot bot commented Apr 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152507

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 22 New Failures

As of commit c8a3815 with merge base 9c7b902 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ChuanqiXu9
Copy link
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Apr 30, 2025
@HDCharles HDCharles requested review from jansel and jamesjwu May 2, 2025 03:50
@HDCharles HDCharles added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 2, 2025
@jansel
Copy link
Contributor

jansel commented May 4, 2025

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/152507/head returned non-zero exit code 1

Rebasing (1/1)
Auto-merging torch/_inductor/compile_fx.py
CONFLICT (content): Merge conflict in torch/_inductor/compile_fx.py
error: could not apply c8a3815c6f4... [inductor] [compile] [async] Don't compile in eager
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply c8a3815c6f4... [inductor] [compile] [async] Don't compile in eager

Raised by https://github.com/pytorch/pytorch/actions/runs/14825161606

@aorenste
Copy link
Contributor

aorenste commented May 5, 2025

I'm not sure if I like this change. My original intention for this dict was so we could tell the difference between the stats for different types of compile modes. With this change we can no longer tell the difference between the types/modes of compile.

@jansel jansel requested review from aorenste and jansel and removed request for jansel May 5, 2025 00:43
Copy link
Contributor

@jansel jansel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like tests are failing, you should also address @aorenste's concerns.

@ChuanqiXu9
Copy link
Contributor Author

I'm not sure if I like this change. My original intention for this dict was so we could tell the difference between the stats for different types of compile modes. With this change we can no longer tell the difference between the types/modes of compile.

Then maybe I misunderstood your idea. I thought it can be a method to decrease the latency of torch.compile.

In my local test, it can reduce the latency of torch.compile by 30%-70% (end to end). I feel this is worthy. How do you feel about to add this idea to be a different compile modes?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: inductor open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants