Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cd0a1f1

Browse files
authored
Add a generic fused softmax (#1440)
* new kernel Signed-off-by: Yi Dong <[email protected]> * added the unit tests Signed-off-by: Yi Dong <[email protected]> * clean up unittest Signed-off-by: Yi Dong <[email protected]> * use float Signed-off-by: Yi Dong <[email protected]> * more clean up Signed-off-by: Yi Dong <[email protected]> * remove the long seq test case
1 parent 31cbdd1 commit cd0a1f1

6 files changed

Lines changed: 731 additions & 0 deletions

File tree

apex/transformer/functional/fused_softmax.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,33 @@ def scaled_masked_softmax(inputs, mask, scale):
9898
return ScaledMaskedSoftmax.apply(*args)
9999

100100

101+
class GenericScaledMaskedSoftmax(torch.autograd.Function):
    """Autograd wrapper around the ``generic_scaled_masked_softmax_cuda`` extension.

    ``forward`` passes ``inputs``, ``mask`` and the scale to the extension's
    ``forward`` function and returns its result. ``backward`` passes the
    incoming gradient, the saved softmax result and the scale to the
    extension's ``backward`` function. Only ``inputs`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        # The extension is imported at call time rather than at module import time.
        import generic_scaled_masked_softmax_cuda

        # The scale is wrapped in a tensor so it can be stored via save_for_backward.
        scale_t = torch.tensor([scale])
        softmax_results = generic_scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        # Import the same module that forward() uses. Importing
        # ``generic_scaled_masked_softmax_cuda_new`` here left the name
        # ``generic_scaled_masked_softmax_cuda`` unbound for the call below.
        import generic_scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = generic_scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
        # One return slot per forward() argument: no gradient for mask or scale.
        return input_grads, None, None
119+
120+
121+
def generic_scaled_masked_softmax(inputs, mask, scale):
    """Apply ``GenericScaledMaskedSoftmax`` to ``inputs`` with autocast disabled.

    The three arguments are first passed through ``_cast_if_autocast_enabled``;
    the autograd function is then invoked inside a
    ``torch.cuda.amp.autocast(enabled=False)`` context and its result returned.
    """
    # input is 4D tensor (b, np, sq, sk)
    cast_args = _cast_if_autocast_enabled(inputs, mask, scale)
    with torch.cuda.amp.autocast(enabled=False):
        softmax_output = GenericScaledMaskedSoftmax.apply(*cast_args)
    return softmax_output
126+
127+
101128
class FusedScaleMaskSoftmax(torch.nn.Module):
102129
"""
103130
fused operation: scaling + mask + softmax
@@ -209,3 +236,30 @@ def get_batch_per_block(sq, sk, b, np):
209236
import scaled_masked_softmax_cuda
210237

211238
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
239+
240+
class GenericFusedScaleMaskSoftmax(FusedScaleMaskSoftmax):
    """
    Generic version of FusedScaleMaskSoftmax.
    It removes the seq-len limitations and has slight performance degradation compared with FusedScaleMaskSoftmax

    fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: flag to indicate if input in fp16 data format.
        input_in_bf16: flag to indicate if input in bf16 data format.
        scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax in performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self, input_in_fp16, input_in_bf16, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale,
    ):
        # The attention mask type is always AttnMaskType.padding for this variant.
        super().__init__(
            input_in_fp16,
            input_in_bf16,
            AttnMaskType.padding,
            scaled_masked_softmax_fusion,
            mask_func,
            softmax_in_fp32,
            scale,
        )
        # NOTE(review): this replaces whatever the base class stored under
        # ``scaled_masked_softmax_fusion`` with a function object, which is
        # always truthy, so is_kernel_available() no longer depends on the
        # value the caller passed in. Confirm whether a different attribute
        # of the base class was meant to hold the kernel function.
        self.scaled_masked_softmax_fusion = generic_scaled_masked_softmax

    def is_kernel_available(self, mask, b, np, sq, sk):
        """Return True when fusion is requested and ``sk`` is positive, else False."""
        # user want to fuse; sk must be 1 ~
        return bool(self.scaled_masked_softmax_fusion and 0 < sk)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/* coding=utf-8
2+
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <cuda_fp16.h>
18+
#include <torch/extension.h>
19+
#include <vector>
20+
21+
namespace multihead_attn
22+
{
23+
namespace fused_softmax
24+
{
25+
namespace generic_scaled_masked_softmax
26+
{
27+
28+
torch::Tensor fwd_cuda(
29+
torch::Tensor const &input,
30+
torch::Tensor const &mask,
31+
float scale_factor);
32+
33+
torch::Tensor bwd_cuda(
34+
torch::Tensor const &output_grads,
35+
torch::Tensor const &softmax_results,
36+
float scale_factor);
37+
38+
torch::Tensor fwd(
39+
torch::Tensor const &input,
40+
torch::Tensor const &mask,
41+
float scale_factor)
42+
{
43+
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
44+
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
45+
(input.scalar_type() == at::ScalarType::BFloat16),
46+
"Only fp16 and bf16 are supported");
47+
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
48+
49+
return fwd_cuda(input, mask, scale_factor);
50+
}
51+
52+
torch::Tensor bwd(
53+
torch::Tensor const &output_grads,
54+
torch::Tensor const &softmax_results,
55+
float scale_factor)
56+
{
57+
58+
AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
59+
AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
60+
61+
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
62+
(output_grads.scalar_type() == at::ScalarType::BFloat16),
63+
"Only fp16 and bf16 are supported");
64+
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
65+
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
66+
"Only fp16 and bf16 are supported");
67+
68+
return bwd_cuda(output_grads, softmax_results, scale_factor);
69+
}
70+
71+
} // end namespace generic_scaled_masked_softmax
72+
} // end namespace fused_softmax
73+
} // end namespace multihead_attn
74+
75+
// Python bindings: exposes the validating wrappers fwd/bwd as
// `forward` and `backward` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  namespace gsms = multihead_attn::fused_softmax::generic_scaled_masked_softmax;

  m.def("forward", &gsms::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");

  m.def("backward", &gsms::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}

0 commit comments

Comments
 (0)