[AOTAutograd] tweak min-cut partitioner to avoid saving softmax output #126348
Conversation
Please add a property-based test in test/inductor/test_perf.py. I'd prefer that to the test currently in the PR. Also, obviously, do a perf run.
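For reference, a hypothetical sketch of the kind of test being asked for. The function, shapes, and inputs below are illustrative, and `count_numel_train` is assumed to be the memory-traffic helper that other tests in test/inductor/test_perf.py use; this is not the test that landed.

```python
# Hypothetical sketch of a memory-traffic test for linear + cross entropy.
# Assumes a CUDA device and the count_numel_train helper from test_perf.py.
import torch
import torch.nn.functional as F

def f(x, weight, targets):
    logits = torch.mm(x, weight.t())     # softmax_input
    return F.cross_entropy(logits, targets)

B, C, V = 64, 768, 50257
inp = (
    torch.randn(B, C, device="cuda", requires_grad=True),
    torch.randn(V, C, device="cuda", requires_grad=True),
    torch.randint(0, V, (B,), device="cuda"),
)
# The assertion would pin the total elements read/written, so a regression
# that starts saving softmax_output again shows up as a changed count:
# self.assertExpectedInline(count_numel_train(f, *inp), """<expected count>""")
```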
With the new heuristics, the backward pass runs slower and we end up with roughly neutral perf overall for llm.c. The reason is that the kernel computing the gradient of the softmax input in the backward pass picks a sub-optimal triton config. #126477 fixes that, and now we get the ~4ms saving for llm.c estimated in the summary.
Perf run shows a 5 second compilation time regression for TIMM (link). I'll need to debug where that comes from.
re-request when ready
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
C = 768
V = 50257

linear = nn.Linear(C, V, bias=False, dtype=torch.bfloat16).cuda()
Hi, may I suggest we replace "cuda" with GPU_TYPE in this case? XPU also meets the has_triton requirement. Alternatively, you may skip the test if not torch.cuda.is_available(). Thanks.
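A minimal sketch of the reviewer's suggestion, assuming GPU_TYPE is importable from torch.testing._internal.inductor_utils as in other inductor tests (this is an illustration, not the change that was committed):

```python
# Device-agnostic version of the snippet above: GPU_TYPE resolves to "cuda"
# or "xpu" depending on the available accelerator, so XPU runs the test too.
import torch
import torch.nn as nn
from torch.testing._internal.inductor_utils import GPU_TYPE

C, V = 768, 50257
linear = nn.Linear(C, V, bias=False, dtype=torch.bfloat16).to(GPU_TYPE)
```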
Stack from ghstack (oldest at bottom):
Right now the linear + cross entropy loss operation (usually the last part of a transformer model) does the following:
1. run matmul to get softmax_input
2. load softmax_input to compute the max per row
3. load softmax_input to compute the sum per row
4. load softmax_input, normalize it, and save the result to softmax_output
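For illustration only (not code from the PR), the four steps correspond roughly to the following decomposition; the shapes and helper name are hypothetical:

```python
# Decomposed linear + cross entropy, mirroring steps 1-4 above.
import torch
import torch.nn.functional as F

def decomposed_linear_cross_entropy(x, weight, targets):
    softmax_input = x @ weight.t()                                       # 1. matmul
    row_max = softmax_input.amax(dim=-1, keepdim=True)                   # 2. max per row
    row_sum = torch.exp(softmax_input - row_max).sum(-1, keepdim=True)   # 3. sum per row
    softmax_output = torch.exp(softmax_input - row_max) / row_sum        # 4. normalize + materialize
    # NLLLoss only gathers one probability per row out of softmax_output.
    picked = softmax_output.gather(-1, targets.unsqueeze(-1))
    return -torch.log(picked).mean()

# Sanity check against the fused op (fp32, small shapes):
x = torch.randn(8, 16)
w = torch.randn(32, 16)
t = torch.randint(0, 32, (8,))
ref = F.cross_entropy(x @ w.t(), t)
assert torch.allclose(decomposed_linear_cross_entropy(x, w, t), ref, atol=1e-5)
```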
Step 4 is inefficient since
a. in the fwd pass, only a small slice of the softmax_output tensor is needed to compute NLLLoss. Materializing the whole tensor is overkill.
b. in the backward pass, we need the whole softmax_output, but it can be recomputed from softmax_input.
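A hedged sketch of why the recomputation is cheap (this is not the partitioner's actual recompute graph): given softmax_input plus the per-row max and sum already produced in steps 2 and 3, softmax_output is a single elementwise pass.

```python
import torch

def recompute_softmax(softmax_input, row_max, row_sum):
    # row_max and row_sum have shape (N, 1); softmax_input has shape (N, V).
    return torch.exp(softmax_input - row_max) / row_sum

# Sanity check against torch.softmax:
x = torch.randn(4, 8)
m = x.amax(dim=-1, keepdim=True)
s = torch.exp(x - m).sum(dim=-1, keepdim=True)
assert torch.allclose(recompute_softmax(x, m, s), torch.softmax(x, dim=-1))
```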
If we skip saving softmax_output, we get a perf win since this is the largest tensor in the network. For llm.c, the size is batch_size * sequence_length * vocab_size * item_size ~= 32 * 1024 * 50257 * 2 ~= 3GB. Simply reading or writing such a large tensor takes ~2ms on an A100. If we recompute softmax_output, we save one load of softmax_input and one store of softmax_output, which results in ~4ms of savings.
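A back-of-the-envelope check of those numbers (the ~1.5 TB/s effective A100 bandwidth below is an assumption; the exact figure depends on the kernel):

```python
B, T, V, itemsize = 32, 1024, 50257, 2        # bf16
bytes_total = B * T * V * itemsize            # ~3.3e9 bytes, i.e. ~3 GB
ms_per_pass = bytes_total / 1.5e12 * 1e3      # ~2.2 ms to read or write it once
print(bytes_total / 1e9, ms_per_pass)         # skipping 1 load + 1 store -> ~4 ms
```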
To avoid saving softmax_output, we need to make sure the min-cut partitioner decides to recompute it from softmax_input and the max/sum tensors (which are small) computed in steps 2 and 3. This is not happening currently since the min-cut partitioner over-estimates the cost of recomputation.
The fix, suggested by @Chillee, is to let `dist_from_bw` play a less important role.
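The actual change lives in torch/_functorch/partitioners.py; the sketch below is only an illustration of the kind of weighting involved, with a hypothetical formula and names, not the real implementation:

```python
def node_weight(mem_size, dist_from_bw, dist_cap=100):
    # Hypothetical heuristic: recomputation cost grows exponentially with a
    # node's distance from the backward graph, capped at dist_cap.
    return mem_size * (1.1 ** max(min(dist_from_bw, dist_cap), 1))

# With a large cap, a ~3GB softmax_output far from the backward graph looks
# prohibitively expensive to recompute, so the cut saves it. Shrinking the
# influence of dist_from_bw (e.g. dist_cap=3) lets mem_size dominate, and the
# partitioner prefers recomputing the huge tensor from the small saved
# max/sum tensors.
print(node_weight(3e9, 50, dist_cap=100))   # ~3.5e11: recompute looks costly
print(node_weight(3e9, 50, dist_cap=3))     # ~4.0e9:  recompute looks cheap
```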