Divide FLOPs by two due to causal masks #1988
77be52c to
3f17fed
Compare
RissyRan
left a comment
Could you run some quick tests, similar to this PR, so we have some idea of the FLOPs impact?
7e3706e to
30e11f2
Compare
RissyRan
left a comment
LGTM! @hengtaoguo @aireenmei @gagika you will need to add a config to opt out of this calculation for the bi-directional mask in Llama4.
MaxText/maxtext_utils.py
Outdated
Shall we keep this as one line for easy copy/paste?
I tried, but it won't pass the linter test. Maybe use a short URL? Does Google have this function?
ya I'd keep as one line, just add `pylint: disable=line-too-long` to the previous line
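A minimal sketch of the suggestion above, assuming a standard pylint setup. The constant name `REFERENCE_URL` and the URL itself are hypothetical placeholders, not the link from this PR:

```python
# A pylint disable comment on its own line applies from that line to the
# end of the enclosing block, so the long line below is not reported.
# pylint: disable=line-too-long
# Hypothetical placeholder URL, kept on one line for easy copy/paste.
REFERENCE_URL = "https://example.com/docs/a/deliberately/long/placeholder/path/that/would/exceed/a/typical/line/length/limit"
# Re-enable the check so later lines are still linted.
# pylint: enable=line-too-long
```

Because the own-line directive covers the rest of the block, re-enabling the check right after the long line keeps the exemption narrow.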
gobbleturk
left a comment
Thanks noujin!
30e11f2 to
a4c0172
Compare
Description
As pointed out by this issue, the current FLOPs calculation does not account for causal masks, which skip roughly half of the attention computation. This PR divides the attention FLOPs by two to correct for that.
FIXES: b/431892390
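The idea behind the change can be sketched as follows. This is an illustrative estimate only, not the code in `MaxText/maxtext_utils.py`: the function name, its parameters, and the 3x forward-plus-backward factor are assumptions made for the example.

```python
def attention_train_tflops(batch, seq_len, num_heads, head_dim, num_layers, causal=True):
  """Rough attention-only training TFLOPs per step (hypothetical helper)."""
  # Forward pass: Q @ K^T and scores @ V each cost about
  # 2 * batch * seq_len^2 * num_heads * head_dim FLOPs per layer.
  forward = 4 * batch * seq_len**2 * num_heads * head_dim * num_layers
  # Assumption: the backward pass costs about twice the forward pass.
  train = 3 * forward
  if causal:
    # A causal mask keeps only the lower triangle of the score matrix,
    # which is about half of the seq_len x seq_len entries.
    train = train / 2
  return train / 1e12


full = attention_train_tflops(1, 2048, 32, 128, 32, causal=False)
masked = attention_train_tflops(1, 2048, 32, 128, 32, causal=True)
print(f"full: {full:.3f} TFLOPs, causal: {masked:.3f} TFLOPs")
```

The factor of two is an approximation: the lower triangle including the diagonal has `seq_len * (seq_len + 1) / 2` entries, which approaches half of `seq_len**2` for long sequences. A bi-directional mask would call this with `causal=False`.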
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):