Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 810ffae

Browse files
authored
Update test_fused_softmax.py (#1782)
1 parent b496d85 commit 810ffae

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

tests/L0/run_transformer/test_fused_softmax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,8 @@ def test_forward(self, allmasked: bool=False):
346346

347347
def test_backward(self, allmasked: bool=False):
348348
import generic_scaled_masked_softmax_cuda
349+
prev_thresh = self.thresh
350+
self.thresh = {"atol": 1.5e-1, "rtol": 5e-3}
349351
for qlen, klen in self.q_k_lens:
350352
inputs = torch.normal(0, 2, (self.batch, self.attn, qlen, klen), dtype=self.dtype, device=self.device)
351353
backward = torch.rand_like(inputs, dtype=torch.float16, device=self.device)
@@ -359,6 +361,7 @@ def test_backward(self, allmasked: bool=False):
359361
softmax_results_torch = forward_torch_softmax(inputs, masks, self.scale_t)
360362
softmax_results_torch.backward(backward)
361363
self.assertEqual(back_grad, inputs.grad, **self.thresh, msg=f"(q, k) = ({qlen, klen})")
364+
self.thresh = prev_thresh
362365

363366
def test_allmasked(self):
364367
self.test_forward(True)

0 commit comments

Comments
 (0)