Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ed71996

Browse files
authored
Adds small-batch kernels (#1126)
1 parent c1378e6 commit ed71996

15 files changed

Lines changed: 1584 additions & 200 deletions

apex/contrib/csrc/fmha/fmha_api.cpp

Lines changed: 172 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,6 @@
3030

3131
#include "fmha.h"
3232

33-
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
34-
bool is_training,
35-
cudaStream_t stream);
36-
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
37-
bool is_training,
38-
cudaStream_t stream);
39-
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
40-
bool is_training,
41-
cudaStream_t stream);
42-
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
43-
bool is_training,
44-
cudaStream_t stream);
45-
46-
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
47-
cudaStream_t stream);
48-
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
49-
cudaStream_t stream);
50-
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
51-
cudaStream_t stream);
52-
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
53-
cudaStream_t stream);
54-
5533
void set_params(Fused_multihead_attention_fprop_params &params,
5634
// sizes
5735
const size_t b,
@@ -61,7 +39,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
6139
// device pointers
6240
void *qkv_packed_d,
6341
void *cu_seqlens_d,
64-
void *seqlens_d,
6542
void *o_packed_d,
6643
void *s_d,
6744
float p_dropout) {
@@ -79,7 +56,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
7956
params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
8057

8158
params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
82-
params.seqlens = static_cast<int *>(seqlens_d);
8359

8460
// S = softmax(P)
8561
params.s_ptr = s_d;
@@ -107,13 +83,9 @@ void set_params(Fused_multihead_attention_fprop_params &params,
10783
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
10884
}
10985

110-
constexpr uint32_t NUM_HEADS_DIM = 2;
111-
constexpr uint32_t THREE_DIM = 1;
112-
11386
std::vector<at::Tensor>
11487
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
11588
const at::Tensor &cu_seqlens, // b+1
116-
const at::Tensor &seqlens, // b
11789
const float p_dropout,
11890
const int max_seq_len,
11991
const bool is_training,
@@ -149,28 +121,24 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
149121

150122
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
151123
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
152-
TORCH_CHECK(seqlens.dtype() == torch::kInt32);
153124

154125
TORCH_CHECK(qkv.is_cuda())
155126
TORCH_CHECK(cu_seqlens.is_cuda())
156127

157128
TORCH_CHECK(qkv.is_contiguous())
158129
TORCH_CHECK(cu_seqlens.is_contiguous())
159-
TORCH_CHECK(seqlens.is_contiguous())
160130

161131
TORCH_CHECK(cu_seqlens.dim() == 1);
162-
TORCH_CHECK(seqlens.dim() == 1);
163132
TORCH_CHECK(qkv.dim() == 4);
164133

165134
const auto sizes = qkv.sizes();
166135

167136
TORCH_CHECK(sizes[THREE_DIM] == 3);
168137

169138
const int batch_size = cu_seqlens.numel() - 1;
170-
TORCH_CHECK(seqlens.numel() == batch_size);
171-
const int total = sizes[0];
172-
const int num_heads = sizes[NUM_HEADS_DIM];
173-
const int head_size = sizes[3];
139+
const int total = sizes[TOTAL_DIM];
140+
const int num_heads = sizes[H_DIM];
141+
const int head_size = sizes[D_DIM];
174142
TORCH_CHECK(batch_size > 0);
175143
TORCH_CHECK(head_size == 64);
176144
auto opts = qkv.options();
@@ -191,7 +159,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
191159
head_size,
192160
qkv.data_ptr(),
193161
cu_seqlens.data_ptr(),
194-
seqlens.data_ptr(),
195162
ctx.data_ptr(),
196163
s.data_ptr(),
197164
p_dropout);
@@ -217,7 +184,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
217184
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
218185
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
219186
const at::Tensor &cu_seqlens, // b+1
220-
const at::Tensor &seqlens, // b
221187
const float p_dropout, // probability to drop
222188
const int max_seq_len // max sequence length to choose the kernel
223189
) {
@@ -247,27 +213,23 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
247213
TORCH_CHECK(dout.dtype() == torch::kFloat16);
248214
TORCH_CHECK(softmax.dtype() == torch::kFloat16);
249215
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
250-
TORCH_CHECK(seqlens.dtype() == torch::kInt32);
251216

252217
TORCH_CHECK(qkv.is_cuda());
253218
TORCH_CHECK(cu_seqlens.is_cuda());
254219

255220
TORCH_CHECK(qkv.is_contiguous());
256221
TORCH_CHECK(cu_seqlens.is_contiguous());
257-
TORCH_CHECK(seqlens.is_contiguous());
258222

259223
TORCH_CHECK(cu_seqlens.dim() == 1);
260-
TORCH_CHECK(seqlens.dim() == 1);
261224
TORCH_CHECK(qkv.dim() == 4);
262225

263226
const auto sizes = qkv.sizes();
264227

265228
TORCH_CHECK(sizes[THREE_DIM] == 3);
266229

267230
const int batch_size = cu_seqlens.numel() - 1;
268-
TORCH_CHECK(seqlens.numel() == batch_size);
269-
const int num_heads = sizes[NUM_HEADS_DIM];
270-
const int head_size = sizes[3];
231+
const int num_heads = sizes[H_DIM];
232+
const int head_size = sizes[D_DIM];
271233
TORCH_CHECK(batch_size > 0);
272234
TORCH_CHECK(head_size == 64);
273235

@@ -282,12 +244,11 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
282244
head_size,
283245
qkv.data_ptr(),
284246
cu_seqlens.data_ptr(),
285-
seqlens.data_ptr(),
286247
dout.data_ptr(), // we set o_ptr to dout
287248
softmax.data_ptr(), // softmax gets overwritten by dP!
288249
p_dropout);
289250

290-
// we're re-using these scales scales
251+
// we're re-using these scales
291252
Data_type acc_type = DATA_TYPE_FP32;
292253
set_alpha(params.scale_bmm1, 1.f, acc_type);
293254
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
@@ -298,8 +259,174 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
298259
return { dqkv, softmax };
299260
}
300261

262+
std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
263+
const at::Tensor &cu_seqlens, // b+1
264+
const float p_dropout,
265+
const int max_seq_len,
266+
const bool is_training,
267+
c10::optional<at::Generator> gen_) {
268+
int seq_len = 512;
269+
auto launch = &run_fmha_fp16_512_64_sm80_nl;
270+
TORCH_CHECK(max_seq_len == seq_len);
271+
272+
constexpr int warps_m = 1;
273+
constexpr int warps_n = 4; // this leads to an upper bound
274+
const int mmas_m = seq_len / 16 / warps_m;
275+
const int mmas_n = seq_len / 16 / warps_n;
276+
// static_assert( mmas_m == 32 );
277+
// static_assert( mmas_n == 4 );
278+
const int elts_per_thread = 8 * mmas_m * mmas_n;
279+
280+
auto stream = at::cuda::getCurrentCUDAStream().stream();
281+
282+
TORCH_CHECK(qkv.is_cuda())
283+
TORCH_CHECK(cu_seqlens.is_cuda())
284+
285+
TORCH_CHECK(qkv.is_contiguous())
286+
TORCH_CHECK(cu_seqlens.is_contiguous())
287+
288+
TORCH_CHECK(cu_seqlens.dim() == 1);
289+
TORCH_CHECK(qkv.dim() == 4);
290+
291+
const auto sizes = qkv.sizes();
292+
293+
TORCH_CHECK(sizes[THREE_DIM] == 3);
294+
295+
const int batch_size = cu_seqlens.numel() - 1;
296+
const int total = sizes[TOTAL_DIM];
297+
const int num_heads = sizes[H_DIM];
298+
const int head_size = sizes[D_DIM];
299+
TORCH_CHECK(batch_size > 0);
300+
TORCH_CHECK(head_size == 64);
301+
auto opts = qkv.options();
302+
303+
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
304+
305+
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
306+
307+
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
308+
309+
Fused_multihead_attention_fprop_params params;
310+
311+
set_params(params,
312+
batch_size,
313+
seq_len,
314+
num_heads,
315+
head_size,
316+
qkv.data_ptr(),
317+
cu_seqlens.data_ptr(),
318+
ctx.data_ptr(),
319+
s.data_ptr(),
320+
p_dropout);
321+
322+
// number of times random will be generated per thread, to offset philox counter in thc random
323+
// state
324+
int64_t counter_offset = elts_per_thread;
325+
at::PhiloxCudaState rng_engine_inputs;
326+
327+
if( is_training ) {
328+
// See Note [Acquire lock when using random generators]
329+
std::lock_guard<std::mutex> lock(gen->mutex_);
330+
params.philox_args = gen->philox_cuda_state(counter_offset);
331+
}
332+
int num_chunks = 3;
333+
if(batch_size == 3) {
334+
num_chunks = 2;
335+
}
336+
337+
launch(params, is_training, num_chunks, stream);
338+
339+
return { ctx, s };
340+
}
341+
342+
std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads, x head_size
343+
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
344+
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
345+
const at::Tensor &cu_seqlens, // b+1
346+
const float p_dropout, // probability to drop
347+
const int max_seq_len // max sequence length to choose the kernel
348+
) {
349+
350+
auto stream = at::cuda::getCurrentCUDAStream().stream();
351+
352+
TORCH_CHECK(qkv.is_cuda())
353+
TORCH_CHECK(cu_seqlens.is_cuda())
354+
355+
TORCH_CHECK(qkv.is_contiguous())
356+
TORCH_CHECK(cu_seqlens.is_contiguous())
357+
358+
TORCH_CHECK(cu_seqlens.dim() == 1);
359+
360+
TORCH_CHECK(qkv.dim() == 4);
361+
362+
const auto sizes = qkv.sizes();
363+
364+
TORCH_CHECK(sizes[THREE_DIM] == 3);
365+
366+
const int batch_size = cu_seqlens.numel() - 1;
367+
368+
const int total = sizes[TOTAL_DIM];
369+
const int num_heads = sizes[H_DIM];
370+
const int head_size = sizes[D_DIM];
371+
TORCH_CHECK(batch_size > 0);
372+
TORCH_CHECK(head_size == 64);
373+
374+
int seq_len = 512;
375+
auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;
376+
377+
auto opts = qkv.options();
378+
379+
auto dqkv = torch::empty_like(qkv);
380+
381+
int num_chunks = 2;
382+
if( batch_size == 1 ) {
383+
num_chunks = 4;
384+
}else if( batch_size == 2 ) {
385+
num_chunks = 3;
386+
}
387+
auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);
388+
389+
Fused_multihead_attention_fprop_params params;
390+
391+
set_params(params,
392+
batch_size,
393+
seq_len,
394+
num_heads,
395+
head_size,
396+
qkv.data_ptr(),
397+
cu_seqlens.data_ptr(),
398+
dout.data_ptr(), // o_ptr = dout
399+
softmax.data_ptr(), // softmax gets overwritten by dP!
400+
p_dropout);
401+
402+
params.dkv_ptr = dkv.data_ptr();
403+
404+
Data_type acc_type = DATA_TYPE_FP32;
405+
set_alpha(params.scale_bmm1, 1.f, acc_type);
406+
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
407+
set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
408+
params.dqkv_ptr = dqkv.data_ptr();
409+
410+
launch(params, num_chunks, stream);
411+
412+
//SPLIT-K reduction of num_chunks dK, dV parts
413+
414+
// The equivalent of the following Pytorch code:
415+
// using namespace torch::indexing;
416+
// at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
417+
// torch::sum_out(view_out, dkv, 1);
418+
419+
const int hidden_size = num_heads * head_size;
420+
fmha_run_noloop_reduce(
421+
dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);
422+
423+
return { dqkv, softmax, dkv };
424+
}
425+
301426
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
302427
m.doc() = "Fused Multi-head Self-attention for BERT";
303428
m.def("fwd", &mha_fwd, "Forward pass");
304429
m.def("bwd", &mha_bwd, "Backward pass");
430+
m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
431+
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
305432
}

0 commit comments

Comments
 (0)