[AOTAutograd] tweak min-cut partitioner to avoid saving softmax output #126348


Open
wants to merge 4 commits into
base: gh/shunting314/145/base

Conversation

shunting314
Contributor

@shunting314 shunting314 commented May 15, 2024

Stack from ghstack (oldest at bottom):

Right now the linear + cross entropy loss operation (usually the last part of a transformer model) does the following:

  1. run matmul to get softmax_input
  2. load softmax_input to compute the max per row
  3. load softmax_input to compute the sum per row
  4. load softmax_input, normalize it, and save the result to softmax_output

Step 4 is inefficient since
a. in the fwd pass, only a small slice of the softmax_output tensor is needed to compute NLLLoss; materializing the whole tensor is overkill
b. in the backward pass, we need the whole softmax_output, but it can be recomputed from softmax_input
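
For concreteness, here is a minimal eager-mode sketch of steps 1-4 (the variable names mirror the description above; the shapes are deliberately small for illustration, while llm.c uses batch_size=32, sequence_length=1024, vocab_size=50257):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
B, T, C, V = 4, 128, 768, 50257   # small illustrative shapes
x = torch.randn(B * T, C, dtype=torch.bfloat16, device=device)
w = torch.randn(V, C, dtype=torch.bfloat16, device=device)
targets = torch.randint(0, V, (B * T,), device=device)

# 1. matmul produces softmax_input (the logits)
softmax_input = x @ w.t()
# 2. load softmax_input to compute the per-row max
row_max = softmax_input.amax(dim=-1, keepdim=True)
# 3. load softmax_input again to compute the per-row sum of exponentials
row_sum = (softmax_input - row_max).exp().sum(dim=-1, keepdim=True)
# 4. load softmax_input a third time, normalize, and materialize softmax_output
softmax_output = (softmax_input - row_max).exp() / row_sum
# The fwd loss only needs one element of softmax_output per row...
loss = -softmax_output[torch.arange(B * T, device=device), targets].log().mean()
# ...while the bwd pass could recompute softmax_output from softmax_input
# plus the small row_max/row_sum tensors instead of saving the whole thing.
```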

If we skip saving softmax_output, we get a perf win since this is the largest tensor in the network. For llm.c, its size is batch_size * sequence_length * vocab_size * item_size ~= 32 * 1024 * 50257 * 2 ~= 3GB. Simply reading or writing such a large tensor takes ~2ms on an A100. If we recompute softmax_output, we save one load of softmax_input and one store of softmax_output, which amounts to a ~4ms saving.
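
A quick back-of-the-envelope check of those numbers (the ~1.6 TB/s effective A100 HBM bandwidth figure below is an assumption on my part, not stated above):

```python
batch_size, seq_len, vocab, item_size = 32, 1024, 50257, 2   # bf16 = 2 bytes
nbytes = batch_size * seq_len * vocab * item_size
print(f"softmax_output size: {nbytes / 1e9:.1f} GB")          # ~3.3 GB

hbm_bw = 1.6e12  # assumed effective A100 HBM bandwidth, in bytes/s
print(f"one pass over it: {nbytes / hbm_bw * 1e3:.1f} ms")    # ~2.1 ms
# Skipping one load of softmax_input plus one store of softmax_output
# therefore saves roughly 2 * 2ms ~= 4ms per iteration.
```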

To avoid saving softmax_output, we need to make sure the min-cut partitioner decides to recompute it from softmax_input and the (small) max/sum tensors computed in steps 2 and 3. This does not happen currently because the min-cut partitioner over-estimates the cost of recomputation.

The fix, suggested by @Chillee, is to let dist_from_bw play a less important role.
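
One way to see what the partitioner decides to save (a sketch using the functorch.compile front end to AOTAutograd; this helper script is mine, not part of the PR) is to compile a linear + cross-entropy loss with the min-cut partitioner and print the forward graph, whose extra outputs are the tensors saved for backward:

```python
import torch
import torch.nn.functional as F
from functorch.compile import aot_function, min_cut_rematerialization_partition

def loss_fn(x, w, targets):
    # linear (matmul) followed by cross entropy, as in the description above
    return F.cross_entropy(x @ w.t(), targets)

def print_graph(gm, example_inputs):
    gm.graph.print_tabular()  # inspect which tensors flow from fwd to bwd
    return gm

compiled = aot_function(
    loss_fn,
    fw_compiler=print_graph,
    bw_compiler=print_graph,
    partition_fn=min_cut_rematerialization_partition,
)

x = torch.randn(512, 768, dtype=torch.bfloat16, device="cuda", requires_grad=True)
w = torch.randn(50257, 768, dtype=torch.bfloat16, device="cuda", requires_grad=True)
targets = torch.randint(0, 50257, (512,), device="cuda")
compiled(x, w, targets).backward()
```

If the partitioner behaves as described above, the saved tensors should be softmax_input plus the small per-row max/sum statistics rather than the full softmax_output.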


pytorch-bot bot commented May 15, 2024

shunting314 added a commit that referenced this pull request May 15, 2024
@shunting314 shunting314 changed the title from "[AOTAutograd] tweak min-cut partitioner to avoid save softmax output" to "[AOTAutograd] tweak min-cut partitioner to avoid saving softmax output" May 15, 2024
@shunting314 shunting314 requested a review from eellison May 15, 2024 23:10
@ezyang ezyang removed their request for review May 16, 2024 00:16
Collaborator

@Chillee Chillee left a comment


Please add a property-based test in test/inductor/test_perf.py. I'd prefer that to the test currently in the PR.

Also, obviously, do a perf run.

@shunting314
Contributor Author

With the new heuristics, the backward pass initially ran slower, and we ended up with roughly neutral perf overall for llm.c. The reason is that the kernel computing the gradient of the softmax input in the backward pass picks a sub-optimal Triton config. #126477 fixes that, and now we get the 4ms saving estimated in the summary for llm.c:

Latency: 194.20ms -> 190.15ms
Tokens/s: 168.7K -> 172.3K (2% improvement)

@shunting314
Contributor Author

shunting314 commented May 21, 2024

> Also, obviously, do a perf run.

The perf run shows a ~5 second compilation-time regression for TIMM (link). I'll need to debug where that comes from.

shunting314 added a commit that referenced this pull request May 21, 2024
@eellison
Contributor

re-request when ready

@eellison eellison removed their request for review June 13, 2024 21:12
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Aug 12, 2024
@github-actions github-actions bot closed this Sep 11, 2024
@shunting314 shunting314 removed the Stale label Sep 12, 2024
@shunting314 shunting314 reopened this Sep 12, 2024

C = 768
V = 50257

linear = nn.Linear(C, V, bias=False, dtype=torch.bfloat16).cuda()
Collaborator


Hi, may I suggest we replace "cuda" with GPU_TYPE in this case? XPU also meets the has_triton requirement. Alternatively, you could skip this case if not torch.cuda.is_available(). Thanks.
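
For reference, a sketch of that suggestion (assuming the GPU_TYPE constant from torch.testing._internal.inductor_utils, which the inductor tests use for this purpose):

```python
import torch
import torch.nn as nn
from torch.testing._internal.inductor_utils import GPU_TYPE  # "cuda" or "xpu"

C = 768
V = 50257

# Use the generic device string instead of hard-coding .cuda(),
# so the test also runs on XPU (which meets the has_triton requirement).
linear = nn.Linear(C, V, bias=False, dtype=torch.bfloat16).to(GPU_TYPE)
```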
