more fixes for post-training llama4 #37329

Merged: 2 commits merged into huggingface:main from the llama4-fixes-part2 branch on Apr 7, 2025

Conversation

@winglian (Contributor) commented Apr 7, 2025

What does this PR do?

Along with the changes in #37319, these are also needed to train llama4.

The guard on the past_key_values cache is needed because it will be None during training.
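
As a rough sketch of what that guard amounts to (the helper name resolve_key_length is made up for illustration; this is not the merged implementation):

    def resolve_key_length(past_key_values, sequence_length):
        # With a cache object (generation), use its maximum allocated length;
        # without one (training, or use_cache=False), fall back to the query length.
        if past_key_values is not None:
            return past_key_values.get_max_cache_shape()
        return sequence_length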

This is a LoRA training run on Scout:
[Screenshot 2025-04-06 at 11 59 57 PM]

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @SunMarc
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions (bot) marked this pull request as draft on April 7, 2025 03:49

github-actions bot commented Apr 7, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@winglian marked this pull request as ready for review on April 7, 2025 03:57
@@ -729,6 +729,7 @@ def forward(
        )
        return output if return_dict else output.to_tuple()

+    @torch.compiler.disable  # the operations in this method are not compilable

Collaborator:
Indeed 😢, but it should be done outside the model forward.
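
For context, a small sketch of two ways to keep non-compilable mask construction out of torch.compile; the functions here are illustrative stand-ins, not the transformers code, and "option 2" is only one possible reading of the suggestion above:

    import torch

    def build_causal_mask(seq_len):
        # data-dependent mask construction like this tends to break compilation
        return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

    # Option 1, what the diff above does: mark the callable so dynamo skips it.
    build_causal_mask_skipped = torch.compiler.disable(build_causal_mask)

    def forward(x):
        mask = build_causal_mask_skipped(x.shape[-1])
        return x.masked_fill(~mask, float("-inf"))

    # Option 2, "outside model forward": leave the model untouched and build the
    # mask eagerly before entering the compiled region, passing it in as input.
    def forward_with_mask(x, mask):
        return x.masked_fill(~mask, float("-inf"))

    mask = build_causal_mask(4)
    out = torch.compile(forward_with_mask)(torch.zeros(2, 4, 4), mask)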

@@ -779,7 +780,7 @@ def _update_causal_mask(
            attention_mask = make_flex_block_causal_mask(
                attention_mask,
                query_length=sequence_length,
-               key_length=past_key_values.get_max_cache_shape(),
+               key_length=past_key_values.get_max_cache_shape() if past_key_values else sequence_length,

Collaborator:

Suggested change:
-    key_length=past_key_values.get_max_cache_shape() if past_key_values else sequence_length,
+    key_length=past_key_values.get_max_cache_shape(),

There was #37327, which should have already fixed this part!

Contributor:

Without the guard, there could still be issues, e.g. when passing use_cache=False (as past_key_values gets None).

I don't think sequence_length is correct here though, especially when decoding (even if it's possibly broken at the moment, we don't need to make it even more broken :D). Wouldn't it make more sense to use target_length here? Then we could also avoid the guard. This would probably need fixing too:

    if past_key_values is not None and past_key_values.is_compileable:
-        target_length = past_key_values.get_max_cache_shape
+        target_length = past_key_values.get_max_cache_shape()
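
As a quick aside on why the parentheses in that snippet matter (stand-in class for illustration only, not the transformers Cache API):

    class FakeStaticCache:
        def get_max_cache_shape(self):
            return 2048

    cache = FakeStaticCache()
    bound = cache.get_max_cache_shape      # a bound method object: truthy, but not a length
    called = cache.get_max_cache_shape()   # 2048, the value the mask code actually needs
    print(type(bound).__name__, called)    # -> method 2048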

Collaborator:

Ah, actually for shorter contexts this might be better.

Collaborator:

because hybrid cache would be smaller than max chunk

@winglian (Contributor, Author) commented Apr 7, 2025:

We can't call .get_max_cache_shape() on past_key_values if it's None though; perhaps this instead?

key_length=past_key_values.get_max_cache_shape() if past_key_values else None,

Collaborator:

Yeah, flex attention forward. It's kinda expected, because target_length is not aware of the cache's max allocated length, basically.

Contributor (author):

yeah, my use case is training. We should probably put this PR on hold for the minute until I can do some more verification of the correct fix

Contributor:

Hmm, what default value would you suggest in case past_key_values == None then?

Collaborator:

For training it's just gonna be the sequence length.

Contributor:

Funnily enough, this issue only arises on torch==2.6.0. On 2.5.1, I couldn't reproduce #37329 (comment)

@winglian force-pushed the llama4-fixes-part2 branch from 2881abe to dd06245 on April 7, 2025 10:45
@vasqu mentioned this pull request on Apr 7, 2025
@ArthurZucker (Collaborator) left a comment:

Will debug, but this is much needed for the patch, thanks!

@ArthurZucker (Collaborator):

I'll fix the last issue, it should be alright.

@ArthurZucker merged commit b54c2f4 into huggingface:main on Apr 7, 2025
16 of 18 checks passed
@ArthurZucker added the "for patch" label (Tag issues / labels that should be included in the next patch) on Apr 7, 2025
@ArthurZucker (Collaborator):

The last issue is because, when you are training, you should set use_cache=False if you don't want to use the cache. If you have it set to True, it will create a dynamic cache, which then has no max length, which means the key length is the same as the input sequence length.
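
A rough sketch of that behaviour, assuming a dynamic cache reports no maximum length as described above (the fallback mirrors the guard discussed in this thread; the model call is only indicated in a comment):

    from transformers import DynamicCache

    cache = DynamicCache()
    print(cache.get_max_cache_shape())   # None: a dynamic cache has no pre-allocated maximum

    sequence_length = 128                # illustrative value
    max_len = cache.get_max_cache_shape()
    key_length = max_len if max_len is not None else sequence_length
    print(key_length)                    # 128, i.e. the input sequence length

    # For training, pass use_cache=False so no cache is created in the first place:
    # outputs = model(input_ids, labels=input_ids, use_cache=False)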

@ArthurZucker (Collaborator):

The issue was in the chunked_mask, @winglian, not in the attention mask now.

@vasqu (Contributor) commented Apr 7, 2025:

I think the cache isn't the issue in this case. It's the chunked attention (as you also point out), which is why I opened #37351 originally.

I added some debugging code and this guard (which follows eager, for example):
https://github.com/vasqu/transformers/blob/5f9b6589fe852a14887d9108dffd3426616cff57/src/transformers/models/llama4/modeling_llama4.py#L786

(which doesn't solve the issue imo, but is good enough for now, I guess)
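
For reference, a rough eager-style sketch of the kind of chunked causal mask being discussed; the chunk size and function name are illustrative, not the Llama 4 implementation:

    import torch

    def chunked_causal_mask(seq_len, chunk_size):
        pos = torch.arange(seq_len)
        causal = pos[None, :] <= pos[:, None]                         # standard causal constraint
        same_chunk = (pos[None, :] // chunk_size) == (pos[:, None] // chunk_size)
        return causal & same_chunk                                    # attend only within the local chunk

    print(chunked_causal_mask(8, 4).int())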

ArthurZucker pushed a commit that referenced this pull request Apr 7, 2025
* more fixes for post-training llama4

* use target_length instead of guarded past_key_values
cyr0930 pushed a commit to cyr0930/transformers that referenced this pull request Apr 18, 2025
* more fixes for post-training llama4

* use target_length instead of guarded past_key_values
Labels: for patch (Tag issues / labels that should be included in the next patch)
3 participants