Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 3fe10b5

Browse files
authored
Seryilmaz/fused dropout softmax (#985)
* fuse dropout into softmax in fprop for additive mask case
1 parent 6c186b3 commit 3fe10b5

9 files changed

Lines changed: 1019 additions & 140 deletions

apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,15 @@ torch::Tensor bwd_cuda(
113113

114114
// Apply Dropout Mask and Scale by Dropout Probability
115115
// Softmax Grad
116-
dispatch_masked_scale_softmax_backward<half, half, float,false>(
116+
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
117117
static_cast<half*>(output_grads.data_ptr()),
118118
static_cast<half*>(output_grads.data_ptr()),
119119
reinterpret_cast<half const*>(softmax_results.data_ptr()),
120120
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
121121
1.0/(1.0-dropout_prob),
122122
k_seq_len,
123123
k_seq_len,
124-
attn_batches*q_seq_len);
124+
attn_batches*q_seq_len, stream);
125125
//backward pass is completely in-place
126126
return output_grads;
127127
}

apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,17 @@ torch::Tensor bwd_cuda(
115115
// Apply Dropout Mask and Scale by Dropout Probability
116116
// Softmax Grad
117117
if (padding_mask == nullptr) {
118-
dispatch_masked_scale_softmax_backward<half, half, float,false>(
118+
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
119119
static_cast<half*>(output_grads.data_ptr()),
120120
static_cast<half*>(output_grads.data_ptr()),
121121
reinterpret_cast<half const*>(softmax_results.data_ptr()),
122122
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
123123
1.0/(1.0-dropout_prob),
124124
k_seq_len,
125125
k_seq_len,
126-
attn_batches*q_seq_len);
126+
attn_batches*q_seq_len, stream);
127127
} else{
128-
dispatch_masked_scale_softmax_backward_masked_out<half, half, float,false>(
128+
dispatch_masked_scale_softmax_backward_masked_out_stream<half, half, float,false>(
129129
static_cast<half*>(output_grads.data_ptr()),
130130
static_cast<half*>(output_grads.data_ptr()),
131131
reinterpret_cast<half const*>(softmax_results.data_ptr()),
@@ -135,7 +135,7 @@ torch::Tensor bwd_cuda(
135135
k_seq_len,
136136
k_seq_len,
137137
attn_batches*q_seq_len,
138-
heads);
138+
heads, stream);
139139

140140
}
141141
//backward pass is completely in-place
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#pragma once
2+
//Philox CUDA.
3+
4+
// Counter-based Philox random number generator for device code.
// State is a 128-bit counter (uint4) and a 64-bit key (uint2). Each call to
// operator() encrypts the current counter with the key through 7 rounds,
// returns the resulting four 32-bit words, and advances the counter by one.
class Philox {
5+
public:
6+
// Builds a generator from three 64-bit inputs:
//   seed        - split into the two 32-bit key words (low -> key.x, high -> key.y).
//   subsequence - split into the upper counter words (low -> counter.z, high -> counter.w).
//   offset      - divided by 4 and added to the lower counter words (counter.x/.y);
//                 presumably because every counter value yields four outputs.
__device__ inline Philox(unsigned long long seed,
7+
unsigned long long subsequence,
8+
unsigned long long offset) {
9+
key.x = (unsigned int)seed;
10+
key.y = (unsigned int)(seed >> 32);
11+
counter = make_uint4(0, 0, 0, 0);
12+
counter.z = (unsigned int)(subsequence);
13+
counter.w = (unsigned int)(subsequence >> 32);
14+
STATE = 0;
15+
// Skip ahead by offset/4 counter steps.
incr_n(offset / 4);
16+
}
17+
// Produces the next block of four 32-bit random words and advances the counter.
// STATE is only ever assigned 0 in this class, so the branch below runs on every call.
__device__ inline uint4 operator()() {
18+
if(STATE == 0) {
19+
// Work on copies so the stored counter and key are not altered by the rounds.
uint4 counter_ = counter;
20+
uint2 key_ = key;
21+
//7-round philox: 6 rounds with a key bump after each, plus one final round below
22+
for(int i = 0; i < 6; i++) {
23+
counter_ = single_round(counter_, key_);
24+
key_.x += (kPhilox10A); key_.y += (kPhilox10B);
25+
}
26+
output = single_round(counter_, key_);
27+
// Advance the stored counter so the next call yields a different block.
incr();
28+
}
29+
//return all four 32-bit words (uint4) directly; the commented-out code below
//is the per-word variant that handed out one word per call
30+
//unsigned long ret;
31+
//switch(STATE) {
32+
// case 0: ret = output.x; break;
33+
// case 1: ret = output.y; break;
34+
// case 2: ret = output.z; break;
35+
// case 3: ret = output.w; break;
36+
//}
37+
//STATE = (STATE + 1) % 4;
38+
return output;
39+
}
40+
private:
41+
// 128-bit counter; x is the least significant word, w the most significant.
uint4 counter;
42+
// Result of the most recent round computation.
uint4 output;
43+
// 64-bit key derived from the seed.
uint2 key;
44+
// Set to 0 in the constructor and never changed by the visible code.
unsigned int STATE;
45+
// Adds the 64-bit value n to the low 64 bits of the counter (x, y) and
// propagates a carry into z and then w. Carries are detected through
// unsigned wraparound comparisons.
__device__ inline void incr_n(unsigned long long n) {
46+
unsigned int nlo = (unsigned int)(n);
47+
unsigned int nhi = (unsigned int)(n >> 32);
48+
counter.x += nlo;
49+
// counter.x wrapped around: carry one into the high half of n.
if (counter.x < nlo)
50+
nhi++;
51+
counter.y += nhi;
52+
// No wraparound in counter.y: nothing to carry further.
if (nhi <= counter.y)
53+
return;
54+
if (++counter.z)
55+
return;
56+
++counter.w;
57+
}
58+
// Increments the 128-bit counter by one; a word that wraps to 0 carries
// into the next more significant word.
__device__ inline void incr() {
59+
if (++counter.x)
60+
return;
61+
if (++counter.y)
62+
return;
63+
if (++counter.z)
64+
return;
65+
++counter.w;
66+
}
67+
// 32x32 -> 64-bit multiply: returns the low 32 bits of a*b and stores the
// high 32 bits (from __umulhi) in *result_high.
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
68+
unsigned int *result_high) {
69+
*result_high = __umulhi(a, b);
70+
return a*b;
71+
}
72+
// One Philox round: multiplies ctr.x and ctr.z by the two round multipliers,
// then XORs the high halves with ctr.y / ctr.w and the key words; the low
// halves are passed through into the y and w positions of the result.
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
73+
unsigned int hi0;
74+
unsigned int hi1;
75+
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
76+
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
77+
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
78+
return ret;
79+
}
80+
// Per-round key increments (added to key.x / key.y after each looped round).
static const unsigned long kPhilox10A = 0x9E3779B9;
81+
static const unsigned long kPhilox10B = 0xBB67AE85;
82+
// Round multipliers applied to ctr.x and ctr.z in single_round.
static const unsigned long kPhiloxSA = 0xD2511F53;
83+
static const unsigned long kPhiloxSB = 0xCD9E8D57;
84+
};
85+
// Inverse of 2^32.
86+
#define M_RAN_INVM32 2.3283064e-10f
87+
// Converts four 32-bit unsigned integers to four floats by scaling each
// component with M_RAN_INVM32 (the inverse of 2^32).
// NOTE(review): the result is nominally in [0, 1); the largest inputs may
// round up to exactly 1.0f in single precision - confirm callers tolerate this.
__device__ __inline__ float4 uniform4(uint4 x) {
88+
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,x.w * M_RAN_INVM32);
89+
90+
}

apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ std::vector<torch::Tensor> bwd_cuda(
2424
torch::Tensor const& output_grads,
2525
torch::Tensor const& matmul2_results,
2626
torch::Tensor const& dropout_results,
27-
torch::Tensor const& softmax_results,
27+
// torch::Tensor const& softmax_results,
28+
torch::Tensor const& bmm1_results,
29+
torch::Tensor const& pad_mask,
2830
torch::Tensor const& input_lin_results,
2931
torch::Tensor const& inputs,
3032
torch::Tensor const& input_weights,
@@ -60,6 +62,7 @@ std::vector<torch::Tensor> fwd(
6062
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
6163
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
6264
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
65+
AT_ASSERTM(use_mask , "no mask is not supported");
6366

6467
if (use_mask) {
6568
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
@@ -85,7 +88,8 @@ std::vector<torch::Tensor> bwd(
8588
torch::Tensor const& output_grads,
8689
torch::Tensor const& matmul2_results,
8790
torch::Tensor const& dropout_results,
88-
torch::Tensor const& softmax_results,
91+
torch::Tensor const& bmm1_results,
92+
torch::Tensor const& pad_mask,
8993
torch::Tensor const& input_lin_results,
9094
torch::Tensor const& inputs,
9195
torch::Tensor const& input_weights,
@@ -97,7 +101,6 @@ std::vector<torch::Tensor> bwd(
97101
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
98102
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
99103
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
100-
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
101104
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
102105
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
103106
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
@@ -107,7 +110,6 @@ std::vector<torch::Tensor> bwd(
107110
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
108111
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
109112
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
110-
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
111113
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
112114
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
113115
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
@@ -119,7 +121,8 @@ std::vector<torch::Tensor> bwd(
119121
output_grads,
120122
matmul2_results,
121123
dropout_results,
122-
softmax_results,
124+
bmm1_results,
125+
pad_mask,
123126
input_lin_results,
124127
inputs,
125128
input_weights,

apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ std::vector<torch::Tensor> fwd_cuda(
6363
auto mask_options = act_options.dtype(torch::kUInt8);
6464

6565
torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
66-
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
66+
torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
6767
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
6868
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
6969
torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
@@ -75,7 +75,8 @@ std::vector<torch::Tensor> fwd_cuda(
7575
void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);
7676

7777
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
78-
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
78+
void* bmm1_results_ptr = static_cast<void*>(bmm1_results.data_ptr());
79+
void* dropout_results_ptr = static_cast<void*>(dropout_results.data_ptr());
7980

8081
char a_layout_t{'t'};
8182
char a_layout_n{'n'};
@@ -119,38 +120,36 @@ std::vector<torch::Tensor> fwd_cuda(
119120
lead_dim,
120121
batch_stride,
121122
beta_zero,
122-
static_cast<half*>(softmax_results_ptr),
123+
static_cast<half*>(bmm1_results_ptr),
123124
k_seq_len,
124125
k_seq_len*q_seq_len,
125126
attn_batches);
126127
// Padded Softmax
127128
bool softmax_success = false;
128-
if (pad_mask == nullptr) {
129-
softmax_success = dispatch_softmax<half, half, float>(
130-
reinterpret_cast<half*>(softmax_results_ptr),
131-
reinterpret_cast<const half*>(softmax_results_ptr),
132-
k_seq_len,
133-
k_seq_len,
134-
attn_batches*q_seq_len);
129+
if (is_training) {
130+
softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float>(
131+
reinterpret_cast<half*>(dropout_results_ptr),
132+
(is_training) ? reinterpret_cast<uint8_t*>(dropout_mask.data_ptr<uint8_t>()) : nullptr,
133+
reinterpret_cast<const half*>(bmm1_results_ptr),
134+
pad_mask,
135+
attn_batches*q_seq_len*q_seq_len,
136+
k_seq_len,
137+
k_seq_len,
138+
attn_batches*q_seq_len,
139+
attn_batches*q_seq_len/sequences,
140+
1.0f-dropout_prob,
141+
stream);
135142
} else {
136143
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
137-
reinterpret_cast<half*>(softmax_results_ptr),
138-
reinterpret_cast<const half*>(softmax_results_ptr),
144+
reinterpret_cast<half*>(dropout_results_ptr),//this is actually softmax results, but making it consistent for the next function
145+
reinterpret_cast<const half*>(bmm1_results_ptr),
139146
pad_mask,
140147
k_seq_len,
141148
k_seq_len,
142149
attn_batches*q_seq_len,
143150
attn_batches*q_seq_len/sequences);
144151
}
145152

146-
147-
if (is_training) {
148-
//use at:: function so that C++ version generates the same random mask as python version
149-
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
150-
dropout_results = std::get<0>(dropout_tuple);
151-
dropout_mask = std::get<1>(dropout_tuple);
152-
}
153-
154153
// Matmul2
155154
gemm_switch_fp32accum( state,
156155
a_layout_n,
@@ -162,7 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
162161
static_cast<const half*>(v_lin_results_ptr),
163162
lead_dim,
164163
batch_stride,
165-
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
164+
static_cast<const half*>(dropout_results.data_ptr()),
166165
k_seq_len,
167166
k_seq_len*q_seq_len,
168167
beta_zero,
@@ -199,7 +198,7 @@ std::vector<torch::Tensor> fwd_cuda(
199198

200199
return {
201200
input_lin_results,
202-
softmax_results,
201+
bmm1_results,
203202
dropout_results,
204203
dropout_mask,
205204
matmul2_results,
@@ -212,7 +211,8 @@ std::vector<torch::Tensor> bwd_cuda(
212211
torch::Tensor const& output_grads,
213212
torch::Tensor const& matmul2_results,
214213
torch::Tensor const& dropout_results,
215-
torch::Tensor const& softmax_results,
214+
torch::Tensor const& bmm1_results,
215+
torch::Tensor const& pad_mask,
216216
torch::Tensor const& input_lin_results,
217217
torch::Tensor const& inputs,
218218
torch::Tensor const& input_weights,
@@ -350,15 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
350350

351351
// Apply Dropout Mask and Scale by Dropout Probability
352352
// Softmax Grad
353-
dispatch_masked_scale_softmax_backward<half, half, float,false>(
353+
dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
354354
static_cast<half*>(matmul2_grads.data_ptr()),
355-
static_cast<half*>(matmul2_grads.data_ptr()),
356-
reinterpret_cast<half const*>(softmax_results.data_ptr()),
355+
static_cast<half* const>(matmul2_grads.data_ptr()),
356+
reinterpret_cast<half const*>(bmm1_results.data_ptr()),
357+
reinterpret_cast<half const*>(pad_mask.data_ptr()),
357358
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
358359
1.0/(1.0-dropout_prob),
359360
k_seq_len,
360361
k_seq_len,
361-
attn_batches*q_seq_len);
362+
attn_batches*q_seq_len/sequences,
363+
attn_batches*q_seq_len,
364+
stream);
362365

363366
// Matmul1 Dgrad1
364367
gemm_switch_fp32accum( state,

apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,15 +361,15 @@ std::vector<torch::Tensor> bwd_cuda(
361361

362362
// Apply Dropout Mask and Scale by Dropout Probability
363363
// Softmax Grad
364-
dispatch_masked_scale_softmax_backward<half, half, float,false>(
364+
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
365365
static_cast<half*>(matmul2_grads.data_ptr()),
366366
static_cast<half*>(matmul2_grads.data_ptr()),
367367
reinterpret_cast<half const*>(softmax_results.data_ptr()),
368368
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
369369
1.0/(1.0-dropout_prob),
370370
k_seq_len,
371371
k_seq_len,
372-
attn_batches*q_seq_len);
372+
attn_batches*q_seq_len, stream);
373373

374374
// Matmul1 Dgrad1
375375
gemm_switch_fp32accum( state,

0 commit comments

Comments
 (0)