Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1203099

Browse files
authored
Remove THCState from apex/contrib/multihead_attn (#1239)
* pass `self.mask_additive` * clang-format * removing THCState
1 parent 3c8f516 commit 1203099

22 files changed

Lines changed: 6534 additions & 7181 deletions
Lines changed: 49 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,74 @@
1-
#include <torch/extension.h>
21
#include <cuda_fp16.h>
2+
#include <torch/extension.h>
33
#include <vector>
44

55
namespace multihead_attn {
66
namespace fused_softmax {
77
namespace additive_mask_softmax_dropout {
88

9-
std::vector<torch::Tensor> fwd_cuda(
10-
bool is_training,
11-
int heads,
12-
torch::Tensor const& input,
13-
const half* pad_mask,
14-
float dropout_prob
15-
);
9+
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
10+
torch::Tensor const &input,
11+
const half *pad_mask, float dropout_prob);
1612

17-
torch::Tensor bwd_cuda(
18-
int heads,
19-
torch::Tensor const& output_grads,
20-
torch::Tensor const& softmax_results,
21-
torch::Tensor const& dropout_mask,
22-
float dropout_prob
23-
);
13+
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
14+
torch::Tensor const &softmax_results,
15+
torch::Tensor const &dropout_mask, float dropout_prob);
2416

2517
// C++ interface
2618

27-
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
28-
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
29-
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
30-
31-
std::vector<torch::Tensor> fwd(
32-
bool use_mask,
33-
bool is_training,
34-
int heads,
35-
torch::Tensor const& input,
36-
torch::Tensor const& pad_mask,
37-
float dropout_prob
38-
)
39-
{
40-
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
41-
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
19+
#define CHECK_CUDA(x) \
20+
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
21+
#define CHECK_CONTIGUOUS(x) \
22+
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
23+
#define CHECK_INPUT(x) \
24+
CHECK_CUDA(x); \
25+
CHECK_CONTIGUOUS(x)
4226

27+
std::vector<torch::Tensor> fwd(bool use_mask, bool is_training, int heads,
28+
torch::Tensor const &input,
29+
torch::Tensor const &pad_mask,
30+
float dropout_prob) {
31+
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
32+
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half,
33+
"Only HALF is supported");
4334
if (use_mask) {
44-
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
45-
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, "Only BYTE is supported");
35+
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
36+
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half,
37+
"Only BYTE is supported");
4638
}
4739

48-
return fwd_cuda(
49-
is_training,
50-
heads,
51-
input,
52-
use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr,
53-
dropout_prob
54-
);
40+
return fwd_cuda(is_training, heads, input,
41+
use_mask ? static_cast<const half *>(pad_mask.data_ptr())
42+
: nullptr,
43+
dropout_prob);
5544
}
5645

57-
torch::Tensor bwd(
58-
bool use_mask,
59-
int heads,
60-
torch::Tensor const& output_grads,
61-
torch::Tensor const& softmax_results,
62-
torch::Tensor const& dropout_mask,
63-
float dropout_prob
64-
)
65-
{
66-
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
67-
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
68-
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
46+
torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads,
47+
torch::Tensor const &softmax_results,
48+
torch::Tensor const &dropout_mask, float dropout_prob) {
49+
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
50+
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
51+
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
52+
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half,
53+
"Only HALF is supported");
54+
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half,
55+
"Only HALF is supported");
56+
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte,
57+
// "Only BYTE is supported");
6958

70-
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
71-
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
72-
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
73-
74-
return bwd_cuda(
75-
heads,
76-
output_grads,
77-
softmax_results,
78-
dropout_mask,
79-
dropout_prob
80-
);
59+
return bwd_cuda(heads, output_grads, softmax_results, dropout_mask,
60+
dropout_prob);
8161
}
8262

83-
} // end namespace mask_softmax_dropout
63+
} // namespace additive_mask_softmax_dropout
8464
} // end namespace fused_softmax
8565
} // end namespace multihead_attn
8666

8767
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
88-
m.def("forward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
89-
m.def("backward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
68+
m.def("forward",
69+
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd,
70+
"Self Multihead Attention masked softmax dropout -- Forward.");
71+
m.def("backward",
72+
&multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd,
73+
"Self Multihead Attention masked softmax dropout -- Backward.");
9074
}
91-
Lines changed: 60 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,131 +1,113 @@
1-
#include <vector>
2-
#include <math.h>
31
#include <iostream>
2+
#include <math.h>
3+
#include <vector>
44

55
#include <cuda.h>
6-
#include <cuda_runtime.h>
76
#include <cuda_fp16.h>
87
#include <cuda_profiler_api.h>
8+
#include <cuda_runtime.h>
99

1010
#include <ATen/ATen.h>
1111
#include <ATen/cuda/CUDAContext.h>
1212
#include <torch/extension.h>
1313

14-
#include "softmax.h"
1514
#include "dropout.h"
15+
#include "softmax.h"
1616

1717
// symbol to be automatically resolved by PyTorch libs
18-
extern THCState *state;
1918

2019
namespace multihead_attn {
2120
namespace fused_softmax {
2221
namespace additive_mask_softmax_dropout {
2322

24-
std::vector<torch::Tensor> fwd_cuda(
25-
bool is_training,
26-
int heads,
27-
torch::Tensor const& input,
28-
const half* pad_mask,
29-
float dropout_prob
30-
)
31-
{
32-
const int attn_batches = input.size(0);
33-
const int sequences = attn_batches / heads;
34-
const int q_seq_len = input.size(1);
35-
const int k_seq_len = q_seq_len;
36-
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
37-
38-
// There is no reason to use more than one stream as every kernel is
23+
std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads,
24+
torch::Tensor const &input,
25+
const half *pad_mask, float dropout_prob) {
26+
const int attn_batches = input.size(0);
27+
const int sequences = attn_batches / heads;
28+
const int q_seq_len = input.size(1);
29+
const int k_seq_len = q_seq_len;
30+
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
31+
32+
// There is no reason to use more than one stream as every kernel is
3933
// sequentially dependent
4034
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
41-
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
35+
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
4236
cublasSetStream(handle, stream);
4337

44-
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
45-
auto act_options = input.options().requires_grad(false);
38+
// 3 Intermediate Results + Output (Note: dropout intermediates are generated
39+
// by ATen library code)
40+
auto act_options = input.options().requires_grad(false);
4641
auto mask_options = act_options.dtype(torch::kUInt8);
4742

48-
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
49-
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
50-
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
43+
torch::Tensor softmax_results =
44+
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
45+
torch::Tensor dropout_results =
46+
torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
47+
torch::Tensor dropout_mask =
48+
torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
5149

5250
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
53-
void* input_ptr = static_cast<void*>(input.data_ptr());
54-
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
51+
void *input_ptr = static_cast<void *>(input.data_ptr());
52+
void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());
5553

5654
// Padded Softmax
5755
bool softmax_success = false;
5856
if (pad_mask == nullptr) {
5957
softmax_success = dispatch_softmax<half, half, float>(
60-
reinterpret_cast<half*>(softmax_results_ptr),
61-
reinterpret_cast<const half*>(input_ptr),
62-
k_seq_len,
63-
k_seq_len,
64-
attn_batches*q_seq_len);
58+
reinterpret_cast<half *>(softmax_results_ptr),
59+
reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len,
60+
attn_batches * q_seq_len);
6561
} else {
66-
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
67-
reinterpret_cast<half*>(softmax_results_ptr),
68-
reinterpret_cast<const half*>(input_ptr),
69-
pad_mask,
70-
k_seq_len,
71-
k_seq_len,
72-
attn_batches*q_seq_len,
73-
attn_batches*q_seq_len/sequences);
62+
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
63+
reinterpret_cast<half *>(softmax_results_ptr),
64+
reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len,
65+
k_seq_len, attn_batches * q_seq_len,
66+
attn_batches * q_seq_len / sequences);
7467
}
7568

76-
7769
if (is_training) {
78-
//use at:: function so that C++ version generates the same random mask as python version
79-
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
70+
// use at:: function so that C++ version generates the same random mask as
71+
// python version
72+
auto dropout_tuple =
73+
at::_fused_dropout(softmax_results, 1.0f - dropout_prob);
8074
dropout_results = std::get<0>(dropout_tuple);
8175
dropout_mask = std::get<1>(dropout_tuple);
8276
}
8377

8478
// Matmul2
8579

86-
return {
87-
dropout_results,
88-
dropout_mask,
89-
softmax_results
90-
};
80+
return {dropout_results, dropout_mask, softmax_results};
9181
}
9282

93-
torch::Tensor bwd_cuda(
94-
int heads,
95-
torch::Tensor const& output_grads,
96-
torch::Tensor const& softmax_results,
97-
torch::Tensor const& dropout_mask,
98-
float dropout_prob
99-
)
100-
{
101-
const int attn_batches = output_grads.size(0);
102-
const int q_seq_len = output_grads.size(1);
103-
const int k_seq_len = q_seq_len;
104-
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
83+
torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads,
84+
torch::Tensor const &softmax_results,
85+
torch::Tensor const &dropout_mask, float dropout_prob) {
86+
const int attn_batches = output_grads.size(0);
87+
const int q_seq_len = output_grads.size(1);
88+
const int k_seq_len = q_seq_len;
89+
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
10590
// TODO: Streams can be used in Backprop but I haven't added more than one
10691
// in my first attempt to create the code
10792
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
108-
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
93+
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
10994
cublasSetStream(handle, stream);
11095

11196
// Output Tensor Allocations
112-
// torch::Tensor input_grads = torch::empty_like(output_grads);
97+
// torch::Tensor input_grads = torch::empty_like(output_grads);
11398

114-
// Apply Dropout Mask and Scale by Dropout Probability
99+
// Apply Dropout Mask and Scale by Dropout Probability
115100
// Softmax Grad
116-
dispatch_masked_scale_softmax_backward_stream<half, half, float,false>(
117-
static_cast<half*>(output_grads.data_ptr()),
118-
static_cast<half*>(output_grads.data_ptr()),
119-
reinterpret_cast<half const*>(softmax_results.data_ptr()),
120-
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
121-
1.0/(1.0-dropout_prob),
122-
k_seq_len,
123-
k_seq_len,
124-
attn_batches*q_seq_len, stream);
125-
//backward pass is completely in-place
101+
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
102+
static_cast<half *>(output_grads.data_ptr()),
103+
static_cast<half *>(output_grads.data_ptr()),
104+
reinterpret_cast<half const *>(softmax_results.data_ptr()),
105+
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
106+
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
107+
attn_batches * q_seq_len, stream);
108+
// backward pass is completely in-place
126109
return output_grads;
127110
}
128-
}
129-
}
130-
}
131-
111+
} // namespace additive_mask_softmax_dropout
112+
} // namespace fused_softmax
113+
} // namespace multihead_attn

0 commit comments

Comments
 (0)