Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7e1c22d

Browse files
chochowskicrcrpar
andauthored
contrib/fmha: Add option to zero out tensors before math (#1322)
* extend api to allow forced memory zeroing (empty() does not do it) * typo fix * ctx change * move zeroing flag to ctx * update test Co-authored-by: mchochowski <[email protected]> Co-authored-by: Masaki Kozuki <[email protected]>
1 parent 44c3043 commit 7e1c22d

3 files changed

Lines changed: 54 additions & 24 deletions

File tree

apex/contrib/csrc/fmha/fmha_api.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
8989
const float p_dropout,
9090
const int max_seq_len,
9191
const bool is_training,
92+
const bool zero_tensors,
9293
c10::optional<at::Generator> gen_) {
9394
auto dprops = at::cuda::getCurrentDeviceProperties();
9495
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
@@ -147,6 +148,11 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
147148

148149
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
149150

151+
if( zero_tensors ) {
152+
ctx.zero_();
153+
s.zero_();
154+
}
155+
150156
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
151157
gen_, at::cuda::detail::getDefaultCUDAGenerator());
152158

@@ -185,7 +191,8 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
185191
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
186192
const at::Tensor &cu_seqlens, // b+1
187193
const float p_dropout, // probability to drop
188-
const int max_seq_len // max sequence length to choose the kernel
194+
const int max_seq_len, // max sequence length to choose the kernel
195+
const bool zero_tensors
189196
) {
190197
auto dprops = at::cuda::getCurrentDeviceProperties();
191198
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
@@ -235,6 +242,10 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
235242

236243
auto dqkv = torch::empty_like(qkv);
237244

245+
if( zero_tensors ) {
246+
dqkv.zero_();
247+
}
248+
238249
Fused_multihead_attention_fprop_params params;
239250

240251
set_params(params,
@@ -264,6 +275,7 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num
264275
const float p_dropout,
265276
const int max_seq_len,
266277
const bool is_training,
278+
const bool zero_tensors,
267279
c10::optional<at::Generator> gen_) {
268280
int seq_len = 512;
269281
auto launch = &run_fmha_fp16_512_64_sm80_nl;
@@ -304,6 +316,11 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num
304316

305317
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
306318

319+
if( zero_tensors ) {
320+
ctx.zero_();
321+
s.zero_();
322+
}
323+
307324
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
308325

309326
Fused_multihead_attention_fprop_params params;
@@ -344,7 +361,8 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
344361
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
345362
const at::Tensor &cu_seqlens, // b+1
346363
const float p_dropout, // probability to drop
347-
const int max_seq_len // max sequence length to choose the kernel
364+
const int max_seq_len, // max sequence length to choose the kernel
365+
const bool zero_tensors
348366
) {
349367

350368
auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -378,6 +396,10 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
378396

379397
auto dqkv = torch::empty_like(qkv);
380398

399+
if( zero_tensors ) {
400+
dqkv.zero_();
401+
}
402+
381403
int num_chunks = 2;
382404
if( batch_size == 1 ) {
383405
num_chunks = 4;

apex/contrib/fmha/fmha.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,29 @@
3232

3333
class FMHAFun(torch.autograd.Function):
3434
@staticmethod
35-
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
35+
def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors):
3636
batch_size = cu_seqlens.numel() - 1
3737
if batch_size < 4:
38-
context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
38+
context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors, None)
3939
else:
40-
context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
40+
context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors, None)
4141
ctx.save_for_backward(qkv, S_dmask)
4242
ctx.cu_seqlens = cu_seqlens
4343
ctx.p_dropout = p_dropout
4444
ctx.max_s = max_s
45+
ctx.zero_tensors = zero_tensors
4546
return context
4647

4748
@staticmethod
4849
def backward(ctx, dout):
4950
qkv, S_dmask = ctx.saved_tensors
5051
batch_size = ctx.cu_seqlens.numel() - 1
5152
if batch_size < 4:
52-
dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
53+
dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
5354
else:
54-
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
55+
dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors)
5556

56-
return dqkv, None, None, None, None, None, None
57+
return dqkv, None, None, None, None, None, None, None
5758

5859
class FMHA(torch.nn.Module):
5960

@@ -67,8 +68,8 @@ def __init__(self, config):
6768
self.d = self.hidden_size // self.h
6869
assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"
6970

70-
def forward(self, qkv, cu_seqlens, max_s, is_training=True):
71+
def forward(self, qkv, cu_seqlens, max_s, is_training=True, zero_tensors=False):
7172

72-
ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
73+
ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training, zero_tensors)
7374

7475
return ctx.view(-1, self.hidden_size)

apex/contrib/test/fmha/test_fmha.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def py_mha(qkv, amask, b, s, h, d):
5151

5252
class TestFMHA(unittest.TestCase):
5353

54-
def run_test(self, s, b):
55-
print(f'Test s={s} b={b}')
54+
def run_test(self, s: int, b: int, zero_tensors: bool):
55+
print(f'Test s={s} b={b}, zero_tensors={zero_tensors}')
5656

5757
torch.manual_seed(1234)
5858
torch.cuda.manual_seed(1234)
@@ -77,9 +77,9 @@ def run_test(self, s, b):
7777
qkv.requires_grad = True
7878

7979
if b < 4:
80-
ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, None)
80+
ctx, S_ = mha.fwd_nl(qkv_vs, cu_seqlens, 0.0, s, True, zero_tensors, None)
8181
else:
82-
ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, None)
82+
ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, zero_tensors, None)
8383
ctx = ctx.view(b,s,h,d)
8484

8585
ctx_ref = py_mha(qkv, amask, b,s,h,d)
@@ -95,27 +95,34 @@ def run_test(self, s, b):
9595
dw2 = dw.permute(0,2,1,3).clone().detach().contiguous()
9696

9797
if b < 4:
98-
dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
98+
dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)
9999
else:
100-
dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
101-
100+
dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)
101+
102102
dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)
103-
103+
104104
self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
105105

106106
def test_128(self):
107-
self.run_test(128, 32)
107+
self.run_test(128, 32, False)
108+
self.run_test(128, 32, True)
108109

109110
def test_256(self):
110-
self.run_test(256, 32)
111+
self.run_test(256, 32, False)
112+
self.run_test(256, 32, True)
111113

112114
def test_384(self):
113-
self.run_test(384, 32)
115+
self.run_test(384, 32, False)
116+
self.run_test(384, 32, True)
114117

115118
def test_512(self):
116-
self.run_test(512, 32)
117-
self.run_test(512, 2)
118-
self.run_test(512, 3)
119+
self.run_test(512, 32, False)
120+
self.run_test(512, 32, True)
121+
self.run_test(512, 2, False)
122+
self.run_test(512, 2, True)
123+
self.run_test(512, 3, False)
124+
self.run_test(512, 3, True)
125+
119126

120127
if __name__ == '__main__':
121128
unittest.main()

0 commit comments

Comments
 (0)