Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a255aa6

Browse files
authored
fmha: add sm_100 and sm_120 (#1890)
1 parent 6445b06 commit a255aa6

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

apex/contrib/csrc/fmha/fmha_api.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
9797
using namespace torch::indexing;
9898
auto dprops = at::cuda::getCurrentDeviceProperties();
9999
TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) ||
100-
(dprops->major == 9 && dprops->minor == 0));
100+
(dprops->major == 9 && dprops->minor == 0) ||
101+
(dprops->major == 10 && dprops->minor == 0) ||
102+
(dprops->major == 12 && dprops->minor == 0));
101103
auto stream = at::cuda::getCurrentCUDAStream().stream();
102104
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);
103105

@@ -193,7 +195,9 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
193195
using namespace torch::indexing;
194196
auto dprops = at::cuda::getCurrentDeviceProperties();
195197
TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) ||
196-
(dprops->major == 9 && dprops->minor == 0));
198+
(dprops->major == 9 && dprops->minor == 0) ||
199+
(dprops->major == 10 && dprops->minor == 0) ||
200+
(dprops->major == 12 && dprops->minor == 0));
197201
int seq_len = 512;
198202
auto launch = &run_fmha_dgrad_fp16_512_64_sm80;
199203
if( max_seq_len <= 128 ) {

apex/contrib/test/fmha/test_fmha.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def py_mha(qkv, amask, b, s, h, d):
6161

6262

6363
@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}")
64-
@unittest.skipIf(not _get_device_properties() == (8, 0), "FMHA only supports sm80")
64+
@unittest.skipIf(_get_device_properties() not in [(8, 0), (9, 0), (10, 0), (12, 0)], "FMHA only supports sm80")
6565
class TestFMHA(unittest.TestCase):
6666

6767
def run_test(self, s: int, b: int, zero_tensors: bool):

0 commit comments

Comments
 (0)