@@ -63,7 +63,7 @@ std::vector<torch::Tensor> fwd_cuda(
6363 auto mask_options = act_options.dtype (torch::kUInt8 );
6464
6565 torch::Tensor input_lin_results = torch::empty ({q_seq_len, sequences, output_lin_dim}, act_options);
66- torch::Tensor softmax_results = torch::empty ({attn_batches, q_seq_len, k_seq_len}, act_options);
66+ torch::Tensor bmm1_results = torch::empty ({attn_batches, q_seq_len, k_seq_len}, act_options);
6767 torch::Tensor dropout_results = torch::empty ({attn_batches, q_seq_len, k_seq_len}, act_options);
6868 torch::Tensor dropout_mask = torch::empty ({attn_batches, q_seq_len, k_seq_len}, mask_options);
6969 torch::Tensor matmul2_results = torch::empty ({q_seq_len, attn_batches, head_dim}, act_options);
@@ -75,7 +75,8 @@ std::vector<torch::Tensor> fwd_cuda(
7575 void * v_lin_results_ptr = static_cast <void *>(static_cast <half*>(input_lin_results.data_ptr ()) + 2 *head_dim);
7676
7777 // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
78- void * softmax_results_ptr = static_cast <void *>(softmax_results.data_ptr ());
78+ void * bmm1_results_ptr = static_cast <void *>(bmm1_results.data_ptr ());
79+ void * dropout_results_ptr = static_cast <void *>(dropout_results.data_ptr ());
7980
8081 char a_layout_t {' t' };
8182 char a_layout_n{' n' };
@@ -119,38 +120,36 @@ std::vector<torch::Tensor> fwd_cuda(
119120 lead_dim,
120121 batch_stride,
121122 beta_zero,
122- static_cast <half*>(softmax_results_ptr ),
123+ static_cast <half*>(bmm1_results_ptr ),
123124 k_seq_len,
124125 k_seq_len*q_seq_len,
125126 attn_batches);
126127 // Padded Softmax
127128 bool softmax_success = false ;
128- if (pad_mask == nullptr ) {
129- softmax_success = dispatch_softmax<half, half, float >(
130- reinterpret_cast <half*>(softmax_results_ptr),
131- reinterpret_cast <const half*>(softmax_results_ptr),
132- k_seq_len,
133- k_seq_len,
134- attn_batches*q_seq_len);
129+ if (is_training) {
130+ softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float >(
131+ reinterpret_cast <half*>(dropout_results_ptr),
132+ (is_training) ? reinterpret_cast <uint8_t *>(dropout_mask.data_ptr <uint8_t >()) : nullptr ,
133+ reinterpret_cast <const half*>(bmm1_results_ptr),
134+ pad_mask,
135+ attn_batches*q_seq_len*q_seq_len,
136+ k_seq_len,
137+ k_seq_len,
138+ attn_batches*q_seq_len,
139+ attn_batches*q_seq_len/sequences,
140+ 1 .0f -dropout_prob,
141+ stream);
135142 } else {
136143 softmax_success = dispatch_additive_masked_softmax<half, half, float >(
137- reinterpret_cast <half*>(softmax_results_ptr),
138- reinterpret_cast <const half*>(softmax_results_ptr ),
144+ reinterpret_cast <half*>(dropout_results_ptr), // this is actually softmax results, but making it consistent for the next function
145+ reinterpret_cast <const half*>(bmm1_results_ptr ),
139146 pad_mask,
140147 k_seq_len,
141148 k_seq_len,
142149 attn_batches*q_seq_len,
143150 attn_batches*q_seq_len/sequences);
144151 }
145152
146-
147- if (is_training) {
148- // use at:: function so that C++ version generates the same random mask as python version
149- auto dropout_tuple = at::_fused_dropout (softmax_results, 1 .0f -dropout_prob);
150- dropout_results = std::get<0 >(dropout_tuple);
151- dropout_mask = std::get<1 >(dropout_tuple);
152- }
153-
154153 // Matmul2
155154 gemm_switch_fp32accum ( state,
156155 a_layout_n,
@@ -162,7 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
162161 static_cast <const half*>(v_lin_results_ptr),
163162 lead_dim,
164163 batch_stride,
165- (is_training) ? static_cast <const half*>(dropout_results.data_ptr ()) : static_cast < const half*>(softmax_results. data_ptr ()) ,
164+ static_cast <const half*>(dropout_results.data_ptr ()),
166165 k_seq_len,
167166 k_seq_len*q_seq_len,
168167 beta_zero,
@@ -199,7 +198,7 @@ std::vector<torch::Tensor> fwd_cuda(
199198
200199 return {
201200 input_lin_results,
202- softmax_results ,
201+ bmm1_results ,
203202 dropout_results,
204203 dropout_mask,
205204 matmul2_results,
@@ -212,7 +211,8 @@ std::vector<torch::Tensor> bwd_cuda(
212211 torch::Tensor const & output_grads,
213212 torch::Tensor const & matmul2_results,
214213 torch::Tensor const & dropout_results,
215- torch::Tensor const & softmax_results,
214+ torch::Tensor const & bmm1_results,
215+ torch::Tensor const & pad_mask,
216216 torch::Tensor const & input_lin_results,
217217 torch::Tensor const & inputs,
218218 torch::Tensor const & input_weights,
@@ -350,15 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
350350
351351 // Apply Dropout Mask and Scale by Dropout Probability
352352 // Softmax Grad
353- dispatch_masked_scale_softmax_backward <half, half, float ,false >(
353+ dispatch_masked_scale_softmax_backward_recompute <half, half, float , false >(
354354 static_cast <half*>(matmul2_grads.data_ptr ()),
355- static_cast <half*>(matmul2_grads.data_ptr ()),
356- reinterpret_cast <half const *>(softmax_results.data_ptr ()),
355+ static_cast <half* const >(matmul2_grads.data_ptr ()),
356+ reinterpret_cast <half const *>(bmm1_results.data_ptr ()),
357+ reinterpret_cast <half const *>(pad_mask.data_ptr ()),
357358 static_cast <uint8_t const *>(dropout_mask.data_ptr ()),
358359 1.0 /(1.0 -dropout_prob),
359360 k_seq_len,
360361 k_seq_len,
361- attn_batches*q_seq_len);
362+ attn_batches*q_seq_len/sequences,
363+ attn_batches*q_seq_len,
364+ stream);
362365
363366 // Matmul1 Dgrad1
364367 gemm_switch_fp32accum ( state,
0 commit comments