Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b1c7600

Browse files
authored
[FMHA] Remove zero fill for softmax output (#1565)
* Remove zero fill for softmax output * Initial commit for mha_fill_kernel
1 parent be819ea commit b1c7600

3 files changed

Lines changed: 75 additions & 3 deletions

File tree

apex/contrib/csrc/fmha/fmha_api.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
#include "fmha.h"
3232

33+
extern at::Tensor & mha_fill(at::Tensor &self, const at::Tensor &start_index);
3334
void set_params(Fused_multihead_attention_fprop_params &params,
3435
// sizes
3536
const size_t b,
@@ -93,6 +94,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
9394
const bool zero_tensors,
9495
c10::optional<at::Generator> gen_) {
9596

97+
using namespace torch::indexing;
9698
auto dprops = at::cuda::getCurrentDeviceProperties();
9799
TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) ||
98100
(dprops->major == 9 && dprops->minor == 0));
@@ -143,8 +145,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
143145
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
144146

145147
if( zero_tensors ) {
146-
ctx.zero_();
147-
s.zero_();
148+
mha_fill(ctx, cu_seqlens.index({Slice(-1,None)}));
148149
}
149150

150151
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -189,6 +190,7 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
189190
const int max_seq_len, // max sequence length to choose the kernel
190191
const bool zero_tensors
191192
) {
193+
using namespace torch::indexing;
192194
auto dprops = at::cuda::getCurrentDeviceProperties();
193195
TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) ||
194196
(dprops->major == 9 && dprops->minor == 0));
@@ -239,7 +241,7 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
239241
auto dqkv = torch::empty_like(qkv);
240242

241243
if( zero_tensors ) {
242-
dqkv.zero_();
244+
mha_fill(dqkv, cu_seqlens.index({Slice(-1,None)}));
243245
}
244246

245247
Fused_multihead_attention_fprop_params params;
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/******************************************************************************
2+
* Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Redistribution and use in source and binary forms, with or without
5+
* modification, are permitted provided that the following conditions are met:
6+
* * Redistributions of source code must retain the above copyright
7+
* notice, this list of conditions and the following disclaimer.
8+
* * Redistributions in binary form must reproduce the above copyright
9+
* notice, this list of conditions and the following disclaimer in the
10+
* documentation and/or other materials provided with the distribution.
11+
* * Neither the name of the NVIDIA CORPORATION nor the
12+
* names of its contributors may be used to endorse or promote products
13+
* derived from this software without specific prior written permission.
14+
*
15+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16+
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17+
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19+
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20+
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21+
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22+
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23+
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24+
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
*
26+
******************************************************************************/
27+
28+
#include <torch/extension.h>
29+
#include <ATen/cuda/CUDAContext.h>
30+
#include <ATen/Dispatch.h>
31+
32+
constexpr int block_size = 512;
33+
constexpr int ctas_per_sm = 4;
34+
35+
template <typename scalar_t>
36+
__global__ void
37+
__launch_bounds__(block_size)
38+
mha_fill_kernel(scalar_t* out_tensor,
39+
const int32_t* const start_row,
40+
const size_t num_rows) {
41+
size_t row_stride = gridDim.y * blockDim.x;
42+
size_t row_index = blockIdx.x + (size_t)start_row[0];
43+
size_t col_index = blockIdx.y * blockDim.x + threadIdx.x;
44+
while (row_index < num_rows) {
45+
out_tensor[row_index*row_stride + col_index] = 0;
46+
row_index += gridDim.x;
47+
}
48+
}
49+
50+
at::Tensor & mha_fill(at::Tensor &self, const at::Tensor &start_index) {
51+
auto max_tokens = self.size(0);
52+
auto self_2d = self.view({max_tokens, -1});
53+
auto fcd_size = self_2d.size(1);
54+
TORCH_CHECK (self.is_contiguous(), "input not contiguous");
55+
TORCH_CHECK (fcd_size % block_size == 0, "input size not aligned to block size");
56+
const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
57+
uint64_t num_blk_y = (uint64_t)(fcd_size / block_size);
58+
uint64_t num_blk_x = (uint64_t)std::ceil(num_mp * ctas_per_sm / num_blk_y);
59+
dim3 dim_grid(num_blk_x, num_blk_y);
60+
dim3 dim_block(block_size);
61+
62+
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
63+
at::ScalarType::Half, at::ScalarType::BFloat16, self_2d.scalar_type(), "mha_padding_fill_", [&]() {
64+
mha_fill_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
65+
self_2d.data_ptr<scalar_t>(), start_index.data_ptr<int32_t>(), max_tokens);
66+
C10_CUDA_KERNEL_LAUNCH_CHECK();
67+
});
68+
return self;
69+
}

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
554554
name="fmhalib",
555555
sources=[
556556
"apex/contrib/csrc/fmha/fmha_api.cpp",
557+
"apex/contrib/csrc/fmha/src/fmha_fill.cu",
557558
"apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu",
558559
"apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu",
559560
"apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu",

0 commit comments

Comments
 (0)