Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit abeca58

Browse files
authored
Update megatron fused softmax follow megatron-lm (#1539)
* Update megatron fused softmax follow megatron-lm Signed-off-by: Yu Yao <[email protected]> * Add mask=None support in scaled_masked_softmax Signed-off-by: Yu Yao <[email protected]> * Update setup.py for scaled_softmax_cuda Signed-off-by: Yu Yao <[email protected]> * Add tests for fused_scale_softmax (mask=None) Signed-off-by: Yu Yao <[email protected]> * Assert grad equal in fused softmax test Signed-off-by: Yu Yao <[email protected]> * Revert "Assert grad equal in fused softmax test" Signed-off-by: Yu Yao <[email protected]> Signed-off-by: Yu Yao <[email protected]> Co-authored-by: Yu Yao <[email protected]>
1 parent c216175 commit abeca58

7 files changed

Lines changed: 535 additions & 9 deletions

File tree

apex/transformer/functional/fused_softmax.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,14 @@ def backward(ctx, output_grads):
9393

9494
def scaled_masked_softmax(inputs, mask, scale):
9595
# input is 4D tensor (b, np, sq, sk)
96-
args = _cast_if_autocast_enabled(inputs, mask, scale)
97-
with torch.cuda.amp.autocast(enabled=False):
98-
return ScaledMaskedSoftmax.apply(*args)
96+
if mask is not None:
97+
args = _cast_if_autocast_enabled(inputs, mask, scale)
98+
with torch.cuda.amp.autocast(enabled=False):
99+
return ScaledMaskedSoftmax.apply(*args)
100+
else:
101+
args = _cast_if_autocast_enabled(inputs, scale)
102+
with torch.cuda.amp.autocast(enabled=False):
103+
return ScaledSoftmax.apply(*args)
99104

100105

101106
class GenericScaledMaskedSoftmax(torch.autograd.Function):
@@ -125,6 +130,37 @@ def generic_scaled_masked_softmax(inputs, mask, scale):
125130
return GenericScaledMaskedSoftmax.apply(*args)
126131

127132

133+
class ScaledSoftmax(torch.autograd.Function):
134+
"""
135+
Fused operation which performs following two operations in sequence
136+
1. Scale the tensor.
137+
2. Perform softmax.
138+
"""
139+
140+
@staticmethod
141+
def forward(ctx, inputs, scale):
142+
import scaled_softmax_cuda
143+
144+
scale_t = torch.tensor([scale])
145+
146+
softmax_results = scaled_softmax_cuda.forward(
147+
inputs, scale_t[0]
148+
)
149+
ctx.save_for_backward(softmax_results, scale_t)
150+
return softmax_results
151+
152+
@staticmethod
153+
def backward(ctx, output_grads):
154+
import scaled_softmax_cuda
155+
156+
softmax_results, scale_t = ctx.saved_tensors
157+
158+
input_grads = scaled_softmax_cuda.backward(
159+
output_grads, softmax_results, scale_t[0]
160+
)
161+
return input_grads, None, None
162+
163+
128164
class FusedScaleMaskSoftmax(torch.nn.Module):
129165
"""
130166
fused operation: scaling + mask + softmax
@@ -191,14 +227,14 @@ def is_kernel_available(self, mask, b, np, sq, sk):
191227
and self.input_in_float16 # input must be fp16
192228
and (
193229
self.attn_mask_type == AttnMaskType.causal
194-
or (self.attn_mask_type == AttnMaskType.padding and mask is not None)
230+
or self.attn_mask_type == AttnMaskType.padding
195231
)
196-
and 16 < sk <= 2048 # sk must be 16 ~ 2048
232+
and 16 < sk <= 4096 # sk must be 16 ~ 4096
197233
and sq % 4 == 0 # sq must be divisor of 4
198234
and sk % 4 == 0 # sk must be divisor of 4
199235
and attn_batches % 4 == 0 # np * b must be divisor of 4
200236
):
201-
if 0 <= sk <= 2048:
237+
if 0 <= sk <= 4096:
202238
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
203239

204240
if self.attn_mask_type == AttnMaskType.causal:

csrc/megatron/scaled_masked_softmax.h

Lines changed: 214 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,117 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
9090
}
9191
}
9292

93+
94+
/*
95+
* Extended softmax (from native aten pytorch) with following additional features
96+
* 1) input scaling
97+
*/
98+
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
99+
__global__ void scaled_softmax_warp_forward(
100+
output_t *dst,
101+
const input_t *src,
102+
const acc_t scale,
103+
int micro_batch_size,
104+
int element_count)
105+
{
106+
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
107+
// warp_size of method warp_softmax_forward_kernel.
108+
constexpr int next_power_of_two = 1 << log2_elements;
109+
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
110+
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
111+
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
112+
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
113+
114+
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
115+
// gridDim/blockIdx = (seq_len, attn_heads, batches)
116+
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
117+
118+
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
119+
// many batches have to computed within this WARP.
120+
int local_batches = micro_batch_size - first_batch;
121+
if (local_batches > WARP_BATCH)
122+
local_batches = WARP_BATCH;
123+
124+
// there might be multiple batches per warp. compute the index within the batch
125+
int local_idx = threadIdx.x;
126+
127+
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
128+
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
129+
130+
// load data from global memory
131+
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
132+
input_t temp_data[ELEMENTS_PER_LDG_STG];
133+
#pragma unroll
134+
for (int i = 0; i < WARP_BATCH; ++i) {
135+
int batch_element_count = (i >= local_batches) ? 0 : element_count;
136+
137+
#pragma unroll
138+
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
139+
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
140+
141+
if (element_index < batch_element_count) {
142+
int itr_idx = i*element_count+it*WARP_SIZE;
143+
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
144+
145+
#pragma unroll
146+
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
147+
elements[i][it + element] = (acc_t)temp_data[element] * scale;
148+
}
149+
} else {
150+
#pragma unroll
151+
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
152+
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
153+
}
154+
}
155+
}
156+
}
157+
158+
// compute max_value
159+
acc_t max_value[WARP_BATCH];
160+
#pragma unroll
161+
for (int i = 0; i < WARP_BATCH; ++i) {
162+
max_value[i] = elements[i][0];
163+
#pragma unroll
164+
for (int it = 1; it < WARP_ITERATIONS; ++it) {
165+
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
166+
}
167+
}
168+
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
169+
170+
acc_t sum[WARP_BATCH] { 0.0f };
171+
#pragma unroll
172+
for (int i = 0; i < WARP_BATCH; ++i) {
173+
#pragma unroll
174+
for (int it = 0; it < WARP_ITERATIONS; ++it) {
175+
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
176+
sum[i] += elements[i][it];
177+
}
178+
}
179+
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
180+
181+
// store result
182+
output_t out[ELEMENTS_PER_LDG_STG];
183+
#pragma unroll
184+
for (int i = 0; i < WARP_BATCH; ++i) {
185+
if (i >= local_batches)
186+
break;
187+
#pragma unroll
188+
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
189+
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
190+
if (element_index < element_count) {
191+
#pragma unroll
192+
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
193+
out[element] = elements[i][it + element] / sum[i];
194+
}
195+
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
196+
} else {
197+
break;
198+
}
199+
}
200+
}
201+
}
202+
203+
93204
/*
94205
* Extended softmax (from native aten pytorch) with following additional features
95206
* 1) input scaling
@@ -333,6 +444,98 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att
333444
return batches_per_block;
334445
}
335446

447+
template<typename input_t, typename output_t, typename acc_t>
448+
void dispatch_scaled_softmax_forward(
449+
output_t *dst,
450+
const input_t *src,
451+
const input_t scale,
452+
int query_seq_len,
453+
int key_seq_len,
454+
int batches,
455+
int attn_heads)
456+
{
457+
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
458+
if (key_seq_len == 0) {
459+
return;
460+
} else {
461+
int log2_elements = log2_ceil(key_seq_len);
462+
const int next_power_of_two = 1 << log2_elements;
463+
int batch_count = batches * attn_heads * query_seq_len;
464+
465+
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
466+
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
467+
468+
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
469+
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
470+
471+
// use 128 threads per block to maximimize gpu utilization
472+
constexpr int threads_per_block = 128;
473+
474+
int warps_per_block = (threads_per_block / warp_size);
475+
int batches_per_block = warps_per_block * batches_per_warp;
476+
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
477+
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
478+
dim3 threads(warp_size, warps_per_block, 1);
479+
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
480+
switch (log2_elements) {
481+
case 0: // 1
482+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
483+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
484+
break;
485+
case 1: // 2
486+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
487+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
488+
break;
489+
case 2: // 4
490+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
491+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
492+
break;
493+
case 3: // 8
494+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
495+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
496+
break;
497+
case 4: // 16
498+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
499+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
500+
break;
501+
case 5: // 32
502+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
503+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
504+
break;
505+
case 6: // 64
506+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
507+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
508+
break;
509+
case 7: // 128
510+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
511+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
512+
break;
513+
case 8: // 256
514+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
515+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
516+
break;
517+
case 9: // 512
518+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
519+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
520+
break;
521+
case 10: // 1024
522+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
523+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
524+
break;
525+
case 11: // 2048
526+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
527+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
528+
break;
529+
case 12: // 4096
530+
scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
531+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
532+
break;
533+
default:
534+
break;
535+
}
536+
}
537+
}
538+
336539
template<typename input_t, typename output_t, typename acc_t>
337540
void dispatch_scaled_masked_softmax_forward(
338541
output_t *dst,
@@ -345,7 +548,7 @@ void dispatch_scaled_masked_softmax_forward(
345548
int attn_heads,
346549
int pad_batches)
347550
{
348-
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
551+
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
349552
if (key_seq_len == 0) {
350553
return;
351554
} else {
@@ -417,6 +620,10 @@ void dispatch_scaled_masked_softmax_forward(
417620
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
418621
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
419622
break;
623+
case 12: // 4096
624+
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
625+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
626+
break;
420627
default:
421628
break;
422629
}
@@ -434,7 +641,7 @@ void dispatch_scaled_masked_softmax_backward(
434641
int batches,
435642
int attn_heads)
436643
{
437-
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
644+
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
438645
if (key_seq_len == 0) {
439646
return;
440647
} else {
@@ -505,6 +712,11 @@ void dispatch_scaled_masked_softmax_backward(
505712
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
506713
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
507714
break;
715+
case 12: // 4096
716+
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
717+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
718+
break;
719+
508720
default:
509721
break;
510722
}

csrc/megatron/scaled_masked_softmax_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ torch::Tensor fwd_cuda(
4444
const int attn_heads = input.size(1);
4545
const int query_seq_len = input.size(2);
4646
const int key_seq_len = input.size(3);
47-
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
47+
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
4848
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
4949
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
5050
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);

0 commit comments

Comments
 (0)