|
1 | | -#include <vector> |
2 | | -#include <math.h> |
3 | 1 | #include <iostream> |
| 2 | +#include <math.h> |
| 3 | +#include <vector> |
4 | 4 |
|
5 | 5 | #include <cuda.h> |
6 | | -#include <cuda_runtime.h> |
7 | 6 | #include <cuda_fp16.h> |
8 | 7 | #include <cuda_profiler_api.h> |
| 8 | +#include <cuda_runtime.h> |
9 | 9 |
|
10 | 10 | #include <ATen/ATen.h> |
11 | 11 | #include <ATen/cuda/CUDAContext.h> |
12 | 12 | #include <torch/extension.h> |
13 | 13 |
|
14 | | -#include "softmax.h" |
15 | 14 | #include "dropout.h" |
| 15 | +#include "softmax.h" |
16 | 16 |
|
17 | 17 | // symbol to be automatically resolved by PyTorch libs |
18 | | -extern THCState *state; |
19 | 18 |
|
20 | 19 | namespace multihead_attn { |
21 | 20 | namespace fused_softmax { |
22 | 21 | namespace additive_mask_softmax_dropout { |
23 | 22 |
|
24 | | -std::vector<torch::Tensor> fwd_cuda( |
25 | | - bool is_training, |
26 | | - int heads, |
27 | | - torch::Tensor const& input, |
28 | | - const half* pad_mask, |
29 | | - float dropout_prob |
30 | | - ) |
31 | | -{ |
32 | | - const int attn_batches = input.size(0); |
33 | | - const int sequences = attn_batches / heads; |
34 | | - const int q_seq_len = input.size(1); |
35 | | - const int k_seq_len = q_seq_len; |
36 | | - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; |
37 | | - |
38 | | - // There is no reason to use more than one stream as every kernel is |
| 23 | +std::vector<torch::Tensor> fwd_cuda(bool is_training, int heads, |
| 24 | + torch::Tensor const &input, |
| 25 | + const half *pad_mask, float dropout_prob) { |
| 26 | + const int attn_batches = input.size(0); |
| 27 | + const int sequences = attn_batches / heads; |
| 28 | + const int q_seq_len = input.size(1); |
| 29 | + const int k_seq_len = q_seq_len; |
| 30 | + const int dropout_elems = attn_batches * q_seq_len * k_seq_len; |
| 31 | + |
| 32 | + // There is no reason to use more than one stream as every kernel is |
39 | 33 | // sequentially dependent |
40 | 34 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
41 | | - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); |
| 35 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); |
42 | 36 | cublasSetStream(handle, stream); |
43 | 37 |
|
44 | | - // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code) |
45 | | - auto act_options = input.options().requires_grad(false); |
| 38 | + // 3 Intermediate Results + Output (Note: dropout intermediates are generated |
| 39 | + // by ATen library code) |
| 40 | + auto act_options = input.options().requires_grad(false); |
46 | 41 | auto mask_options = act_options.dtype(torch::kUInt8); |
47 | 42 |
|
48 | | - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); |
49 | | - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); |
50 | | - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); |
| 43 | + torch::Tensor softmax_results = |
| 44 | + torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); |
| 45 | + torch::Tensor dropout_results = |
| 46 | + torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); |
| 47 | + torch::Tensor dropout_mask = |
| 48 | + torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); |
51 | 49 |
|
52 | 50 | // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) |
53 | | - void* input_ptr = static_cast<void*>(input.data_ptr()); |
54 | | - void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr()); |
| 51 | + void *input_ptr = static_cast<void *>(input.data_ptr()); |
| 52 | + void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr()); |
55 | 53 |
|
56 | 54 | // Padded Softmax |
57 | 55 | bool softmax_success = false; |
58 | 56 | if (pad_mask == nullptr) { |
59 | 57 | softmax_success = dispatch_softmax<half, half, float>( |
60 | | - reinterpret_cast<half*>(softmax_results_ptr), |
61 | | - reinterpret_cast<const half*>(input_ptr), |
62 | | - k_seq_len, |
63 | | - k_seq_len, |
64 | | - attn_batches*q_seq_len); |
| 58 | + reinterpret_cast<half *>(softmax_results_ptr), |
| 59 | + reinterpret_cast<const half *>(input_ptr), k_seq_len, k_seq_len, |
| 60 | + attn_batches * q_seq_len); |
65 | 61 | } else { |
66 | | - softmax_success = dispatch_additive_masked_softmax<half, half, float>( |
67 | | - reinterpret_cast<half*>(softmax_results_ptr), |
68 | | - reinterpret_cast<const half*>(input_ptr), |
69 | | - pad_mask, |
70 | | - k_seq_len, |
71 | | - k_seq_len, |
72 | | - attn_batches*q_seq_len, |
73 | | - attn_batches*q_seq_len/sequences); |
| 62 | + softmax_success = dispatch_additive_masked_softmax<half, half, float>( |
| 63 | + reinterpret_cast<half *>(softmax_results_ptr), |
| 64 | + reinterpret_cast<const half *>(input_ptr), pad_mask, k_seq_len, |
| 65 | + k_seq_len, attn_batches * q_seq_len, |
| 66 | + attn_batches * q_seq_len / sequences); |
74 | 67 | } |
75 | 68 |
|
76 | | - |
77 | 69 | if (is_training) { |
78 | | - //use at:: function so that C++ version generates the same random mask as python version |
79 | | - auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob); |
| 70 | + // use at:: function so that C++ version generates the same random mask as |
| 71 | + // python version |
| 72 | + auto dropout_tuple = |
| 73 | + at::_fused_dropout(softmax_results, 1.0f - dropout_prob); |
80 | 74 | dropout_results = std::get<0>(dropout_tuple); |
81 | 75 | dropout_mask = std::get<1>(dropout_tuple); |
82 | 76 | } |
83 | 77 |
|
84 | 78 | // Matmul2 |
85 | 79 |
|
86 | | - return { |
87 | | - dropout_results, |
88 | | - dropout_mask, |
89 | | - softmax_results |
90 | | - }; |
| 80 | + return {dropout_results, dropout_mask, softmax_results}; |
91 | 81 | } |
92 | 82 |
|
93 | | -torch::Tensor bwd_cuda( |
94 | | - int heads, |
95 | | - torch::Tensor const& output_grads, |
96 | | - torch::Tensor const& softmax_results, |
97 | | - torch::Tensor const& dropout_mask, |
98 | | - float dropout_prob |
99 | | - ) |
100 | | -{ |
101 | | - const int attn_batches = output_grads.size(0); |
102 | | - const int q_seq_len = output_grads.size(1); |
103 | | - const int k_seq_len = q_seq_len; |
104 | | - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; |
| 83 | +torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, |
| 84 | + torch::Tensor const &softmax_results, |
| 85 | + torch::Tensor const &dropout_mask, float dropout_prob) { |
| 86 | + const int attn_batches = output_grads.size(0); |
| 87 | + const int q_seq_len = output_grads.size(1); |
| 88 | + const int k_seq_len = q_seq_len; |
| 89 | + const int dropout_elems = attn_batches * q_seq_len * k_seq_len; |
105 | 90 | // TODO: Streams can be used in Backprop but I haven't added more than one |
106 | 91 | // in my first attempt to create the code |
107 | 92 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); |
108 | | - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); |
| 93 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); |
109 | 94 | cublasSetStream(handle, stream); |
110 | 95 |
|
111 | 96 | // Output Tensor Allocations |
112 | | -// torch::Tensor input_grads = torch::empty_like(output_grads); |
| 97 | + // torch::Tensor input_grads = torch::empty_like(output_grads); |
113 | 98 |
|
114 | | - // Apply Dropout Mask and Scale by Dropout Probability |
| 99 | + // Apply Dropout Mask and Scale by Dropout Probability |
115 | 100 | // Softmax Grad |
116 | | - dispatch_masked_scale_softmax_backward_stream<half, half, float,false>( |
117 | | - static_cast<half*>(output_grads.data_ptr()), |
118 | | - static_cast<half*>(output_grads.data_ptr()), |
119 | | - reinterpret_cast<half const*>(softmax_results.data_ptr()), |
120 | | - static_cast<uint8_t const*>(dropout_mask.data_ptr()), |
121 | | - 1.0/(1.0-dropout_prob), |
122 | | - k_seq_len, |
123 | | - k_seq_len, |
124 | | - attn_batches*q_seq_len, stream); |
125 | | -//backward pass is completely in-place |
| 101 | + dispatch_masked_scale_softmax_backward_stream<half, half, float, false>( |
| 102 | + static_cast<half *>(output_grads.data_ptr()), |
| 103 | + static_cast<half *>(output_grads.data_ptr()), |
| 104 | + reinterpret_cast<half const *>(softmax_results.data_ptr()), |
| 105 | + static_cast<uint8_t const *>(dropout_mask.data_ptr()), |
| 106 | + 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, |
| 107 | + attn_batches * q_seq_len, stream); |
| 108 | + // backward pass is completely in-place |
126 | 109 | return output_grads; |
127 | 110 | } |
128 | | -} |
129 | | -} |
130 | | -} |
131 | | - |
| 111 | +} // namespace additive_mask_softmax_dropout |
| 112 | +} // namespace fused_softmax |
| 113 | +} // namespace multihead_attn |
0 commit comments