# more fixes for post-training llama4 #37329
## Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
```diff
@@ -729,6 +729,7 @@ def forward(
         )
         return output if return_dict else output.to_tuple()

+    @torch.compiler.disable  # the operations in this method are not compilable
```
indeed 😢 but it should be done outside model forward
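As a side note, here is a minimal, self-contained sketch of what the `@torch.compiler.disable` decorator in the hunk above does (the function names are made up, not from modeling_llama4.py): the decorated function is excluded from `torch.compile` and always runs eagerly, which is why calling it inside the model forward forces a graph break there.

```python
import torch

@torch.compiler.disable  # excluded from compilation, always runs eagerly
def build_mask_eagerly(x):
    # stand-in for an op that torch.compile cannot trace
    return x > 0

@torch.compile
def compiled_forward(x):
    mask = build_mask_eagerly(x)  # graph break: this call runs outside the compiled graph
    return (x * mask).sum()

print(compiled_forward(torch.randn(8)))
```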
```diff
@@ -779,7 +780,7 @@ def _update_causal_mask(
             attention_mask = make_flex_block_causal_mask(
                 attention_mask,
                 query_length=sequence_length,
-                key_length=past_key_values.get_max_cache_shape(),
+                key_length=past_key_values.get_max_cache_shape() if past_key_values else sequence_length,
```
Suggested change:

```diff
-                key_length=past_key_values.get_max_cache_shape() if past_key_values else sequence_length,
+                key_length=past_key_values.get_max_cache_shape(),
```
there was #37327 that should have already fixed this part!
Without the guard, there could still be issues, e.g. when passing `use_cache=False` (as `past_kv` gets None). I don't think sequence length is correct here tho, especially when decoding (even if it's possibly broken atm, we don't need to make it even more broken :D). Wouldn't it make more sense to use `target_length` here? Then we could also avoid the guard. We'd probably need to fix this tho:
transformers/src/transformers/models/llama4/modeling_llama4.py, lines 769 to 770 at debfe90:

```python
if past_key_values is not None and past_key_values.is_compileable:
    target_length = past_key_values.get_max_cache_shape()
```
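To make the `target_length` idea discussed above concrete, here is a toy, self-contained sketch (not transformers code; `DummyCache` and `resolve_key_length` are hypothetical names) of deriving the key length up front so the mask call never touches the cache object when there is none:

```python
from typing import Optional

class DummyCache:
    """Stand-in for a pre-allocated (compileable) cache."""
    is_compileable = True

    def get_max_cache_shape(self) -> int:
        return 8192  # illustrative max allocated length

def resolve_key_length(past_key_values: Optional[DummyCache], sequence_length: int) -> int:
    if past_key_values is not None and past_key_values.is_compileable:
        # pre-allocated caches expose their full shape
        return past_key_values.get_max_cache_shape()
    # training / use_cache=False: no cache, the keys are just the current sequence
    return sequence_length

print(resolve_key_length(DummyCache(), 2048))  # 8192 (cache-allocated length)
print(resolve_key_length(None, 2048))          # 2048 (training fallback)
```

How decode-time (non-compileable cache) lengths should be handled is exactly the open question in this thread; the sketch only covers the two cases mentioned so far.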
ah actually for shorter context this might be better
because hybrid cache would be smaller than max chunk
we can't call `.get_max_cache_shape()` on `past_key_values` if it's None though, perhaps this instead?

`key_length=past_key_values.get_max_cache_shape() if past_key_values else None,`
Yeah, flex attention forward. It's kinda expected because target length is not aware of the cache max allocated length basically.
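A small, self-contained illustration of that point, using `StaticCache` as a stand-in for Llama4's hybrid cache (the tiny `LlamaConfig` values are arbitrary): the cache reports its full pre-allocated key length regardless of how many tokens have actually been processed, which is what a target length computed from the inputs alone would not know about.

```python
import torch
from transformers import LlamaConfig, StaticCache

# Tiny config just to build a cache object; the numbers are arbitrary.
config = LlamaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, num_key_value_heads=4)
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=4096, device="cpu", dtype=torch.float32)

# The allocated key length is the full max_cache_len, even before any tokens are cached.
print(cache.get_max_cache_shape())  # 4096
print(cache.get_seq_length())       # 0 tokens actually written so far
```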
yeah, my use case is training. We should probably put this PR on hold for the minute until I can do some more verification of the correct fix
Hmm, what default value would you suggest in case `past_key_values == None` then?
For training it's just gonna be the sequence length.
Funnily enough, this issue only arises on torch==2.6.0. On 2.5.1, I couldn't reproduce it: #37329 (comment)
Will debug, but this is much needed for the patch, thanks!
I'll fix the last issue, should be alright.

The last issue is because when you are training you should set …

The issue was in the …

I think the cache isn't the issue in this case. It's the chunked attention (as you also point out), which is why I opened #37351 originally. I added some debugging code and this guard (which follows eager, for example); it doesn't solve the issue imo, but it's good enough for now ig.
* more fixes for post-training llama4
* use target_length instead of guarded past_key_values
## What does this PR do?
Along with the changes in #37319, these are also needed to train llama4. The guard on the `past_key_values` cache is there because it will be None during training. Below is a LoRA training run on Scout (a rough sketch of that kind of setup follows the plot):

## Who can review?
@ArthurZucker @SunMarc
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.