Conversation

anijain2305 (Contributor) commented Apr 12, 2025

Stack from ghstack (oldest at bottom):

Fixes #150994

We should cherry-pick this to the 2.7 branch if possible, because it breaks torch.compile on some HF models. Look at the issue referenced here.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

Will add a test in a follow-up PR, because this PR might need to be cherry-picked to 2.7, so keeping the PR very simple to cherry-pick.

[ghstack-poisoned]

pytorch-bot bot commented Apr 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151154

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bf0eb70 with merge base 2653498:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

anijain2305 (Contributor, Author) commented:

Asked for review too early. CI failures say that something is really wrong. Converting to draft.

@anijain2305 anijain2305 removed the request for review from jansel April 12, 2025 04:47
anijain2305 added a commit that referenced this pull request Apr 12, 2025
Will add test in a follow up PR, because this PR might need to be
cherry-picked to 2.7. So keeping the PR very simple to cherry-pick

ghstack-source-id: 7d03594
Pull Request resolved: #151154
@anijain2305 anijain2305 added the topic: not user facing and keep-going labels Apr 12, 2025


-LayoutLMForSequenceClassification,pass,5
+LayoutLMForSequenceClassification,pass,6
anijain2305 (Contributor, Author) commented:

Both GPT2 and LayoutLM are correct increases; there is a legitimate recompilation. We were not guarding on self.config.problem_type, which is None in the first run and is initialized there.
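
To make the pattern concrete, here is a minimal sketch (illustrative names, not the actual HF code) of a config attribute that is None on the first compiled call and initialized inside forward; once Dynamo guards on it, the second call sees a different value and recompiles:

import torch

class Config:
    def __init__(self):
        # None on the first run; set lazily inside forward below.
        self.problem_type = None

class ToyClassifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.config = Config()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        # Mirrors the HF *ForSequenceClassification pattern: the attribute
        # is initialized during the first forward pass.
        if self.config.problem_type is None:
            self.config.problem_type = "single_label_classification"
        return self.linear(x)

compiled = torch.compile(ToyClassifier())
compiled(torch.randn(2, 4))  # first call: problem_type is None
compiled(torch.randn(2, 4))  # second call: the guard on problem_type no longer
                             # matches, triggering the (correct) recompilation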

anijain2305 added a commit that referenced this pull request Apr 12, 2025
Will add test in a follow up PR, because this PR might need to be
cherry-picked to 2.7. So keeping the PR very simple to cherry-pick

ghstack-source-id: 4a519e6
Pull Request resolved: #151154
@anijain2305 anijain2305 marked this pull request as ready for review April 12, 2025 17:37
@anijain2305 anijain2305 added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 12, 2025
anijain2305 (Contributor, Author) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

anijain2305 (Contributor, Author) commented:

@pytorchbot cherry-pick --onto release/2.7 -c critical

pytorchbot (Collaborator) commented:

Cherry picking #151154

Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x 7b1a2373e87f5a2a3582cb4e37e47a7c544edf7a returned non-zero exit code 1

Auto-merging test/dynamo/test_misc.py
Auto-merging torch/_dynamo/guards.py
Auto-merging torch/_dynamo/source.py
Auto-merging torch/_dynamo/variables/misc.py
CONFLICT (content): Merge conflict in torch/_dynamo/variables/misc.py
Auto-merging torch/csrc/dynamo/guards.cpp
error: could not apply 7b1a2373e87... [dynamo][super variable] Fix bug to use correct source (#151154)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team: raised by workflow job.

ydshieh commented Apr 14, 2025

Hi @anijain2305, thanks a lot for the fix. If I want to try it, what would be the best approach? A wheel for the torch 2.7 RC with this fix included, or checking out the release/2.7 branch/tag and building torch from there? I can't see this fix applied on release/2.7, even though I saw you comment cherry-pick --onto release/2.7 -c critical.

timocafe pushed a commit to timocafe/pytorch that referenced this pull request Apr 16, 2025
Fixes pytorch#150994

We should cherry-pick to 2.7 branch if possible, because this breaks torch.compile on some HF models. Look at the issue referenced here.

Pull Request resolved: pytorch#151154
Approved by: https://github.com/jansel
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
Fixes pytorch#150994

We should cherry-pick to 2.7 branch if possible, because this breaks torch.compile on some HF models. Look at the issue referenced here.

Pull Request resolved: pytorch#151154
Approved by: https://github.com/jansel
gante commented Apr 24, 2025

👋 I have the same question as @ydshieh -- what's the best way to try the fix?

We noticed that torch.compile on HF VLMs/LLMs, with torch 2.7, results in large slowdowns, so we're pinning the max version to torch 2.6 for now: huggingface/transformers#37760

@atalman atalman added this to the 2.7.1 milestone Apr 25, 2025
atalman (Contributor) commented Apr 25, 2025

adding to 2.7.1

anijain2305 (Contributor, Author) commented:

@gante @ydshieh We apologize for the regression with 2.7. Please allow us till Monday to check if this PR indeed resolves the regression. We can figure out how to proceed then.

ydshieh commented Apr 29, 2025

Hi, sharing an observation. Today I tried to update our Docker images to use torch 2.7 (+cpu), and found that the same issue happens for 6 tests; see this run.

For example: running

python -m pytest -v tests/models/gemma3/test_modeling_gemma3.py::Gemma3Vision2TextModelTest::test_generate_compilation_all_outputs

gives

FAILED tests/models/gemma3/test_modeling_gemma3.py::Gemma3Vision2TextModelTest::test_generate_compilation_all_outputs - torch._dynamo.exc.Unsupported: Unexpected type in sourceless builder transformers.models.gemma3.configuration_gemma3.Gemma3TextConfig

from user code:
   File "/usr/local/lib/python3.9/site-packages/transformers/utils/generic.py", line 969, in wrapper
    output = func(self, *args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1320, in forward
    causal_mask = self._update_causal_mask(
  File "/usr/local/lib/python3.9/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1115, in _update_causal_mask
    if self.config.text_config._attn_implementation == "flash_attention_2":
  File "/usr/local/lib/python3.9/site-packages/transformers/configuration_utils.py", line 211, in __getattribute__
    return super().__getattribute__(key)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
FAILED tests/models/gemma3/test_modeling_gemma3.py::Gemma3Vision2TextModelTest::test_generate_compile_model_forward - torch._dynamo.exc.Unsupported: Unexpected type in sourceless builder transformers.models.gemma3.configuration_gemma3.Gemma3TextConfig

from user code:
   File "/usr/local/lib/python3.9/site-packages/transformers/utils/generic.py", line 969, in wrapper
    output = func(self, *args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1320, in forward
    causal_mask = self._update_causal_mask(
  File "/usr/local/lib/python3.9/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1115, in _update_causal_mask
    if self.config.text_config._attn_implementation == "flash_attention_2":
  File "/usr/local/lib/python3.9/site-packages/transformers/configuration_utils.py", line 211, in __getattribute__
    return super().__getattribute__(key)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
===== 2 failed, 1267 passed, 406 skipped, 46 warnings in 143.60s (0:02:23) =====
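
For reference, the access pattern in the traceback above corresponds roughly to the following minimal sketch (hypothetical names; it mirrors the shape of the transformers code, with a nested config reached through a custom __getattribute__, but is not guaranteed to reproduce the exact error):

import torch

class PretrainedConfigLike:
    # Stand-in for transformers' config base class: attribute lookup goes
    # through a custom __getattribute__ that defers to super(), as in
    # transformers/configuration_utils.py.
    def __getattribute__(self, key):
        return super().__getattribute__(key)

class TextConfig(PretrainedConfigLike):
    _attn_implementation = "eager"  # assumed default for the sketch

class VisionLanguageConfig(PretrainedConfigLike):
    def __init__(self):
        self.text_config = TextConfig()

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.config = VisionLanguageConfig()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # The branch Dynamo has to trace through in _update_causal_mask.
        if self.config.text_config._attn_implementation == "flash_attention_2":
            x = x + 1
        return self.linear(x)

compiled = torch.compile(ToyModel(), fullgraph=True)
compiled(torch.randn(2, 4))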

anijain2305 (Contributor, Author) commented:

@ydshieh @gante We are cherry-picking this PR (along with the other cudagraph bug PR) for 2.7.1 - #152774

I have tested that the above test passes with the PR built on top of 2.7 (and verified that it fails on 2.7). Once the PR is merged to the release branch, it would be really helpful to do some extra testing from your side.


Labels

ciflow/inductor, ciflow/trunk, keep-going, Merged, module: dynamo, topic: not user facing
