Conversation

Collaborator

@NuojCheng NuojCheng commented Jul 18, 2025

Description

As pointed out in this issue, the current FLOPs calculation does not account for the effect of causal masks. This PR fixes that.

FIXES: b/431892390
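For reference, a minimal sketch of the accounting change (an illustrative helper with assumed names and FLOPs convention, not the actual MaxText code): with a full bi-directional mask, the QK^T and attention-weighted-value matmuls each cost roughly 2 * batch * heads * seq_len^2 * head_dim FLOPs per layer; a causal mask only computes the lower triangle of the seq_len x seq_len score matrix, so those terms are halved while the learnable-weight FLOPs stay the same.

```python
def attention_tflops(batch, seq_len, num_heads, head_dim, num_layers, causal=True):
  """Illustrative attention-FLOPs estimate (not the actual MaxText helper).

  Two matmuls dominate attention: scores = Q @ K^T and out = probs @ V.
  Each takes ~2 * batch * num_heads * seq_len**2 * head_dim multiply-add
  FLOPs per layer under a full mask; a causal mask touches only the lower
  triangle of the score matrix, so these terms are divided by two.
  """
  flops = 2 * (2 * batch * num_heads * seq_len**2 * head_dim) * num_layers
  if causal:
    flops /= 2  # causal mask: only ~half of the score matrix is computed
  return flops / 1e12  # TFLOPs
```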

Tests

| Model | Max Target Length | TFLOPs Type | Before | After |
|---|---|---|---|---|
| llama2-7b | 8192 | Total | 1721.22 | 1510.11 |
| | | Learnable Weight | 1299.00 | 1299.00 |
| | | Attention | 422.21 | 211.11 |
| deepseek2-16b | 8192 | Total | 704.60 | 593.27 |
| | | Learnable Weight | 481.95 | 481.95 |
| | | Attention | 222.65 | 111.33 |
| llama3-8b | 8192 | Total | 1897.69 | 1686.58 |
| | | Learnable Weight | 1475.48 | 1475.48 |
| | | Attention | 422.21 | 211.11 |
| llama4-17b-16e | 8192 | Total | 3964.49 | 3568.67 |
| | | Learnable Weight | 3172.84 | 3172.84 |
| | | Attention | 791.65 | 395.82 |
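As a sanity check on the llama2-7b row: the learnable-weight TFLOPs are unchanged and the attention TFLOPs are halved, so the new total is 1299.00 + 422.21 / 2 ≈ 1299.00 + 211.11 = 1510.11, which matches the table.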

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@NuojCheng NuojCheng force-pushed the chengnuojin/mask_flops branch from 77be52c to 3f17fed on July 18, 2025 20:17
@NuojCheng NuojCheng marked this pull request as ready for review July 18, 2025 20:39
@NuojCheng NuojCheng changed the title from "Divide FLOPs by due to causal masks" to "Divide FLOPs by two due to causal masks" Jul 18, 2025
Collaborator

@RissyRan RissyRan left a comment


Could you run some quick tests similar to this PR, so we can get an idea of the FLOPs impact?

@NuojCheng NuojCheng force-pushed the chengnuojin/mask_flops branch 2 times, most recently from 7e3706e to 30e11f2 on July 21, 2025 06:20
@NuojCheng NuojCheng requested review from RissyRan and gobbleturk July 21, 2025 07:00
Collaborator

@RissyRan RissyRan left a comment


LGTM! @hengtaoguo @aireenmei @gagika you will need to add a config to opt out of this calculation for the bi-directional mask in Llama4.
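One possible shape for that opt-out (hypothetical flag name and wiring, nothing that exists in this PR):

```python
def maybe_discount_attention_flops(full_mask_attention_flops, config):
  """Hypothetical opt-out: skip the causal halving for bi-directional masks."""
  causal = getattr(config, "causal_attention_flops", True)  # assumed flag name
  return full_mask_attention_flops / 2 if causal else full_mask_attention_flops
```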

Collaborator

@RissyRan RissyRan Jul 21, 2025


Shall we keep this as one line for easy copy/paste?

Collaborator Author


I tried, but it won't pass the linter check. Maybe use a short URL? Does Google have that feature?

Collaborator


Ya, I'd keep it as one line; just add `pylint: disable=line-too-long` to the previous line.
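For illustration, the suggested pattern would look something like this (placeholder comment and URL, not the actual line in this PR):

```python
# pylint: disable=line-too-long
# Reference: https://example.com/a/very/long/url/that/would/otherwise/exceed/the/line-length/limit
```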

Collaborator

@gobbleturk gobbleturk left a comment


Thanks, Nuojin!

@NuojCheng NuojCheng force-pushed the chengnuojin/mask_flops branch from 30e11f2 to a4c0172 on July 22, 2025 00:02
@copybara-service copybara-service bot merged commit 2adc3ba into main Jul 22, 2025
18 checks passed
@copybara-service copybara-service bot deleted the chengnuojin/mask_flops branch July 22, 2025 17:07
@gobbleturk gobbleturk mentioned this pull request Jul 27, 2025