-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathgeneric_scaled_masked_softmax_cuda.cu
More file actions
89 lines (76 loc) · 3.84 KB
/
generic_scaled_masked_softmax_cuda.cu
File metadata and controls
89 lines (76 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/* coding=utf-8
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "generic_scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace generic_scaled_masked_softmax {
/**
 * Host-side entry point for the generic scaled masked softmax forward pass.
 *
 * Reads the problem sizes from `input`, validates the mask shape, allocates an
 * output tensor and forwards raw data pointers plus sizes to
 * dispatch_scaled_masked_softmax_forward_new, instantiated for the scalar type
 * selected by DISPATCH_HALF_AND_BFLOAT.
 *
 * @param input        4d tensor [batches, attn_heads, query_seq_len, key_seq_len].
 * @param mask         4d tensor [pad_batches, 1, query_seq_len, key_seq_len], where
 *                     pad_batches must be 1 or batches (asserted below). Its storage
 *                     is read as uint8_t.
 * @param scale_factor Passed through unchanged to the dispatcher.
 * @return A newly allocated tensor with the same sizes and options as `input`
 *         (requires_grad disabled).
 */
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) {
  // The dispatcher is given only raw pointers and sizes (no strides), so both
  // operands must be densely laid out. bwd_cuda already does this for its
  // inputs; contiguous() returns the same tensor when no copy is needed.
  auto input_contig = input.contiguous();
  auto mask_contig = mask.contiguous();

  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input_contig.size(0);
  const int pad_batches = mask_contig.size(0);
  const int attn_heads = input_contig.size(1);
  const int query_seq_len = input_contig.size(2);
  const int key_seq_len = input_contig.size(3);

  // The mask is either shared across the batch (1) or given per batch entry,
  // has a single head dimension, and matches the sequence lengths of the input.
  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
  TORCH_INTERNAL_ASSERT(mask_contig.size(1) == 1);
  TORCH_INTERNAL_ASSERT(mask_contig.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask_contig.size(3) == key_seq_len);

  // Output
  auto act_options = input_contig.options().requires_grad(false);
  torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Raw data pointers handed to the dispatcher.
  // NOTE(review): the mask pointer is reinterpreted as uint8_t below, which
  // assumes a 1-byte mask element type; that is not checked in this function.
  void* input_ptr = static_cast<void*>(input_contig.data_ptr());
  void* mask_ptr = static_cast<void*>(mask_contig.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(input_contig.scalar_type(), "dispatch_scaled_masked_softmax_forward",
                           dispatch_scaled_masked_softmax_forward_new<scalar_t, scalar_t, float>(
                               reinterpret_cast<scalar_t*>(softmax_results_ptr),
                               reinterpret_cast<const scalar_t*>(input_ptr), reinterpret_cast<const uint8_t*>(mask_ptr),
                               scale_factor, query_seq_len, key_seq_len, batches, attn_heads, pad_batches););
  return softmax_results;
}
/**
 * Host-side entry point for the generic scaled masked softmax backward pass.
 *
 * Makes both inputs contiguous, allocates a gradient tensor shaped like the
 * incoming gradients, and forwards raw data pointers plus sizes to
 * dispatch_scaled_masked_softmax_backward_new, instantiated for the scalar
 * type selected by DISPATCH_HALF_AND_BFLOAT.
 *
 * @param output_grads_    4d tensor [batches, attn_heads, seq_len, seq_len].
 * @param softmax_results_ Tensor whose data is passed to the dispatcher as a
 *                         read-only pointer of the same scalar type.
 *                         NOTE(review): presumably the forward output with the
 *                         same shape as output_grads_ - not checked here.
 * @param scale_factor     Passed through unchanged to the dispatcher.
 * @return The newly allocated `input_grad` tensor that was handed to the
 *         dispatcher as its first pointer argument.
 */
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) {
  // Only raw pointers and sizes reach the dispatcher, so use dense layouts.
  const torch::Tensor grads = output_grads_.contiguous();
  const torch::Tensor probs = softmax_results_.contiguous();

  // Problem sizes come from the gradient tensor:
  // [batches, attn_heads, query_seq_len, key_seq_len].
  const int batches = grads.size(0);
  const int attn_heads = grads.size(1);
  const int query_seq_len = grads.size(2);
  const int key_seq_len = grads.size(3);

  // Destination for the result; same options as the incoming gradients.
  torch::Tensor input_grad = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, grads.options());

  // Type-erased pointers; they are cast to scalar_t inside the dispatch macro.
  // grads_ptr is passed as a non-const pointer; whether the dispatcher writes
  // through it is not visible from this function.
  void* input_grad_ptr = input_grad.data_ptr();
  void* grads_ptr = grads.data_ptr();
  const void* probs_ptr = probs.data_ptr();

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward_new<scalar_t, scalar_t, float>(
          static_cast<scalar_t*>(input_grad_ptr),
          static_cast<scalar_t*>(grads_ptr), static_cast<scalar_t const*>(probs_ptr),
          scale_factor, query_seq_len, key_seq_len, batches, attn_heads););

  return input_grad;
}
} // namespace generic_scaled_masked_softmax
} // namespace fused_softmax
} // namespace multihead_attn