3030
3131#include " fmha.h"
3232
33- void run_fmha_fp16_128_64_sm80 (const Fused_multihead_attention_fprop_params ¶ms,
34- bool is_training,
35- cudaStream_t stream);
36- void run_fmha_fp16_256_64_sm80 (const Fused_multihead_attention_fprop_params ¶ms,
37- bool is_training,
38- cudaStream_t stream);
39- void run_fmha_fp16_384_64_sm80 (const Fused_multihead_attention_fprop_params ¶ms,
40- bool is_training,
41- cudaStream_t stream);
42- void run_fmha_fp16_512_64_sm80 (const Fused_multihead_attention_fprop_params ¶ms,
43- bool is_training,
44- cudaStream_t stream);
45-
46- void run_fmha_dgrad_fp16_128_64_sm80 (const Fused_multihead_attention_fprop_params ¶ms,
47- cudaStream_t stream);
48- void run_fmha_dgrad_fp16_256_64_sm80 (const Fused_multihead_attention_fprop_params ¶ms,
49- cudaStream_t stream);
50- void run_fmha_dgrad_fp16_384_64_sm80 (const Fused_multihead_attention_fprop_params ¶ms,
51- cudaStream_t stream);
52- void run_fmha_dgrad_fp16_512_64_sm80 (const Fused_multihead_attention_fprop_params ¶ms,
53- cudaStream_t stream);
54-
5533void set_params (Fused_multihead_attention_fprop_params ¶ms,
5634 // sizes
5735 const size_t b,
@@ -61,7 +39,6 @@ void set_params(Fused_multihead_attention_fprop_params ¶ms,
6139 // device pointers
6240 void *qkv_packed_d,
6341 void *cu_seqlens_d,
64- void *seqlens_d,
6542 void *o_packed_d,
6643 void *s_d,
6744 float p_dropout) {
@@ -79,7 +56,6 @@ void set_params(Fused_multihead_attention_fprop_params ¶ms,
7956 params.o_stride_in_bytes = get_size_in_bytes (h * d, data_type);
8057
8158 params.cu_seqlens = static_cast <int *>(cu_seqlens_d);
82- params.seqlens = static_cast <int *>(seqlens_d);
8359
8460 // S = softmax(P)
8561 params.s_ptr = s_d;
@@ -107,13 +83,9 @@ void set_params(Fused_multihead_attention_fprop_params ¶ms,
10783 set_alpha (params.scale_dropout , params.rp_dropout , data_type);
10884}
10985
110- constexpr uint32_t NUM_HEADS_DIM = 2 ;
111- constexpr uint32_t THREE_DIM = 1 ;
112-
11386std::vector<at::Tensor>
11487mha_fwd (const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
11588 const at::Tensor &cu_seqlens, // b+1
116- const at::Tensor &seqlens, // b
11789 const float p_dropout,
11890 const int max_seq_len,
11991 const bool is_training,
@@ -149,28 +121,24 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
149121
150122 TORCH_CHECK (qkv.dtype () == torch::kFloat16 );
151123 TORCH_CHECK (cu_seqlens.dtype () == torch::kInt32 );
152- TORCH_CHECK (seqlens.dtype () == torch::kInt32 );
153124
154125 TORCH_CHECK (qkv.is_cuda ())
155126 TORCH_CHECK (cu_seqlens.is_cuda ())
156127
157128 TORCH_CHECK (qkv.is_contiguous ())
158129 TORCH_CHECK (cu_seqlens.is_contiguous ())
159- TORCH_CHECK (seqlens.is_contiguous ())
160130
161131 TORCH_CHECK (cu_seqlens.dim () == 1 );
162- TORCH_CHECK (seqlens.dim () == 1 );
163132 TORCH_CHECK (qkv.dim () == 4 );
164133
165134 const auto sizes = qkv.sizes ();
166135
167136 TORCH_CHECK (sizes[THREE_DIM] == 3 );
168137
169138 const int batch_size = cu_seqlens.numel () - 1 ;
170- TORCH_CHECK (seqlens.numel () == batch_size);
171- const int total = sizes[0 ];
172- const int num_heads = sizes[NUM_HEADS_DIM];
173- const int head_size = sizes[3 ];
139+ const int total = sizes[TOTAL_DIM];
140+ const int num_heads = sizes[H_DIM];
141+ const int head_size = sizes[D_DIM];
174142 TORCH_CHECK (batch_size > 0 );
175143 TORCH_CHECK (head_size == 64 );
176144 auto opts = qkv.options ();
@@ -191,7 +159,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
191159 head_size,
192160 qkv.data_ptr (),
193161 cu_seqlens.data_ptr (),
194- seqlens.data_ptr (),
195162 ctx.data_ptr (),
196163 s.data_ptr (),
197164 p_dropout);
@@ -217,7 +184,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
217184 const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
218185 at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
219186 const at::Tensor &cu_seqlens, // b+1
220- const at::Tensor &seqlens, // b
221187 const float p_dropout, // probability to drop
222188 const int max_seq_len // max sequence length to choose the kernel
223189) {
@@ -247,27 +213,23 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
247213 TORCH_CHECK (dout.dtype () == torch::kFloat16 );
248214 TORCH_CHECK (softmax.dtype () == torch::kFloat16 );
249215 TORCH_CHECK (cu_seqlens.dtype () == torch::kInt32 );
250- TORCH_CHECK (seqlens.dtype () == torch::kInt32 );
251216
252217 TORCH_CHECK (qkv.is_cuda ());
253218 TORCH_CHECK (cu_seqlens.is_cuda ());
254219
255220 TORCH_CHECK (qkv.is_contiguous ());
256221 TORCH_CHECK (cu_seqlens.is_contiguous ());
257- TORCH_CHECK (seqlens.is_contiguous ());
258222
259223 TORCH_CHECK (cu_seqlens.dim () == 1 );
260- TORCH_CHECK (seqlens.dim () == 1 );
261224 TORCH_CHECK (qkv.dim () == 4 );
262225
263226 const auto sizes = qkv.sizes ();
264227
265228 TORCH_CHECK (sizes[THREE_DIM] == 3 );
266229
267230 const int batch_size = cu_seqlens.numel () - 1 ;
268- TORCH_CHECK (seqlens.numel () == batch_size);
269- const int num_heads = sizes[NUM_HEADS_DIM];
270- const int head_size = sizes[3 ];
231+ const int num_heads = sizes[H_DIM];
232+ const int head_size = sizes[D_DIM];
271233 TORCH_CHECK (batch_size > 0 );
272234 TORCH_CHECK (head_size == 64 );
273235
@@ -282,12 +244,11 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
282244 head_size,
283245 qkv.data_ptr (),
284246 cu_seqlens.data_ptr (),
285- seqlens.data_ptr (),
286247 dout.data_ptr (), // we set o_ptr to dout
287248 softmax.data_ptr (), // softmax gets overwritten by dP!
288249 p_dropout);
289250
290- // we're re-using these scales scales
251+ // we're re-using these scales
291252 Data_type acc_type = DATA_TYPE_FP32;
292253 set_alpha (params.scale_bmm1 , 1 .f , acc_type);
293254 set_alpha (params.scale_softmax , 1 .f / sqrtf (head_size), acc_type);
@@ -298,8 +259,174 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
298259 return { dqkv, softmax };
299260}
300261
262+ std::vector<at::Tensor> mha_fwd_nl (const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
263+ const at::Tensor &cu_seqlens, // b+1
264+ const float p_dropout,
265+ const int max_seq_len,
266+ const bool is_training,
267+ c10::optional<at::Generator> gen_) {
268+ int seq_len = 512 ;
269+ auto launch = &run_fmha_fp16_512_64_sm80_nl;
270+ TORCH_CHECK (max_seq_len == seq_len);
271+
272+ constexpr int warps_m = 1 ;
273+ constexpr int warps_n = 4 ; // this leads to an upper bound
274+ const int mmas_m = seq_len / 16 / warps_m;
275+ const int mmas_n = seq_len / 16 / warps_n;
276+ // static_assert( mmas_m == 32 );
277+ // static_assert( mmas_n == 4 );
278+ const int elts_per_thread = 8 * mmas_m * mmas_n;
279+
280+ auto stream = at::cuda::getCurrentCUDAStream ().stream ();
281+
282+ TORCH_CHECK (qkv.is_cuda ())
283+ TORCH_CHECK (cu_seqlens.is_cuda ())
284+
285+ TORCH_CHECK (qkv.is_contiguous ())
286+ TORCH_CHECK (cu_seqlens.is_contiguous ())
287+
288+ TORCH_CHECK (cu_seqlens.dim () == 1 );
289+ TORCH_CHECK (qkv.dim () == 4 );
290+
291+ const auto sizes = qkv.sizes ();
292+
293+ TORCH_CHECK (sizes[THREE_DIM] == 3 );
294+
295+ const int batch_size = cu_seqlens.numel () - 1 ;
296+ const int total = sizes[TOTAL_DIM];
297+ const int num_heads = sizes[H_DIM];
298+ const int head_size = sizes[D_DIM];
299+ TORCH_CHECK (batch_size > 0 );
300+ TORCH_CHECK (head_size == 64 );
301+ auto opts = qkv.options ();
302+
303+ auto ctx = torch::empty ({ total, num_heads, head_size }, opts);
304+
305+ auto s = torch::empty ({ batch_size, num_heads, seq_len, seq_len }, opts);
306+
307+ auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator ());
308+
309+ Fused_multihead_attention_fprop_params params;
310+
311+ set_params (params,
312+ batch_size,
313+ seq_len,
314+ num_heads,
315+ head_size,
316+ qkv.data_ptr (),
317+ cu_seqlens.data_ptr (),
318+ ctx.data_ptr (),
319+ s.data_ptr (),
320+ p_dropout);
321+
322+ // number of times random will be generated per thread, to offset philox counter in thc random
323+ // state
324+ int64_t counter_offset = elts_per_thread;
325+ at::PhiloxCudaState rng_engine_inputs;
326+
327+ if ( is_training ) {
328+ // See Note [Acquire lock when using random generators]
329+ std::lock_guard<std::mutex> lock (gen->mutex_ );
330+ params.philox_args = gen->philox_cuda_state (counter_offset);
331+ }
332+ int num_chunks = 3 ;
333+ if (batch_size == 3 ) {
334+ num_chunks = 2 ;
335+ }
336+
337+ launch (params, is_training, num_chunks, stream);
338+
339+ return { ctx, s };
340+ }
341+
342+ std::vector<at::Tensor> mha_bwd_nl (const at::Tensor &dout, // total x num_heads, x head_size
343+ const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
344+ at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
345+ const at::Tensor &cu_seqlens, // b+1
346+ const float p_dropout, // probability to drop
347+ const int max_seq_len // max sequence length to choose the kernel
348+ ) {
349+
350+ auto stream = at::cuda::getCurrentCUDAStream ().stream ();
351+
352+ TORCH_CHECK (qkv.is_cuda ())
353+ TORCH_CHECK (cu_seqlens.is_cuda ())
354+
355+ TORCH_CHECK (qkv.is_contiguous ())
356+ TORCH_CHECK (cu_seqlens.is_contiguous ())
357+
358+ TORCH_CHECK (cu_seqlens.dim () == 1 );
359+
360+ TORCH_CHECK (qkv.dim () == 4 );
361+
362+ const auto sizes = qkv.sizes ();
363+
364+ TORCH_CHECK (sizes[THREE_DIM] == 3 );
365+
366+ const int batch_size = cu_seqlens.numel () - 1 ;
367+
368+ const int total = sizes[TOTAL_DIM];
369+ const int num_heads = sizes[H_DIM];
370+ const int head_size = sizes[D_DIM];
371+ TORCH_CHECK (batch_size > 0 );
372+ TORCH_CHECK (head_size == 64 );
373+
374+ int seq_len = 512 ;
375+ auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;
376+
377+ auto opts = qkv.options ();
378+
379+ auto dqkv = torch::empty_like (qkv);
380+
381+ int num_chunks = 2 ;
382+ if ( batch_size == 1 ) {
383+ num_chunks = 4 ;
384+ }else if ( batch_size == 2 ) {
385+ num_chunks = 3 ;
386+ }
387+ auto dkv = torch::empty ({total, num_chunks, 2 , num_heads, head_size}, opts);
388+
389+ Fused_multihead_attention_fprop_params params;
390+
391+ set_params (params,
392+ batch_size,
393+ seq_len,
394+ num_heads,
395+ head_size,
396+ qkv.data_ptr (),
397+ cu_seqlens.data_ptr (),
398+ dout.data_ptr (), // o_ptr = dout
399+ softmax.data_ptr (), // softmax gets overwritten by dP!
400+ p_dropout);
401+
402+ params.dkv_ptr = dkv.data_ptr ();
403+
404+ Data_type acc_type = DATA_TYPE_FP32;
405+ set_alpha (params.scale_bmm1 , 1 .f , acc_type);
406+ set_alpha (params.scale_softmax , 1 .f / sqrtf (head_size), acc_type);
407+ set_alpha (params.scale_bmm2 , 1 .f , DATA_TYPE_FP16);
408+ params.dqkv_ptr = dqkv.data_ptr ();
409+
410+ launch (params, num_chunks, stream);
411+
412+ // SPLIT-K reduction of num_chunks dK, dV parts
413+
414+ // The equivalent of the following Pytorch code:
415+ // using namespace torch::indexing;
416+ // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
417+ // torch::sum_out(view_out, dkv, 1);
418+
419+ const int hidden_size = num_heads * head_size;
420+ fmha_run_noloop_reduce (
421+ dqkv.data_ptr (), dkv.data_ptr (), cu_seqlens.data_ptr <int >(), hidden_size, batch_size, total, num_chunks, stream);
422+
423+ return { dqkv, softmax, dkv };
424+ }
425+
301426PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
302427 m.doc () = " Fused Multi-head Self-attention for BERT" ;
303428 m.def (" fwd" , &mha_fwd, " Forward pass" );
304429 m.def (" bwd" , &mha_bwd, " Backward pass" );
430+ m.def (" fwd_nl" , &mha_fwd_nl, " Forward pass (small-batch)" );
431+ m.def (" bwd_nl" , &mha_bwd_nl, " Backward pass (small-batch)" );
305432}
0 commit comments