Commit a5feacb

[SDPA] [MPS] Fixes regression in 2.8.0 for scaled_dot_product_attention using mps (#164364)
Fixes #163597

- Updates the fast SDPA implementations to take in query tensor stride info, similar to key and value, instead of assuming a contiguous query layout.
- Updates tests with additional transpose/permutation layouts. The new tests catch the regression.

### Benchmarking with script found in [implementation PR](#152781)

Times are averaged over 100000 iterations; this change should not have any significant performance difference. Tested on an M3 Pro.

### Vector Fast Path (q_len=1, k_len=256)

- Before: 0.160 ms
- After: 0.157 ms

### Vector 2-pass (q_len=1, k_len=4096)

- Before: 0.342 ms
- After: 0.339 ms

### Vector Fast Path (q_len=8, k_len=256)

- Before: 0.228 ms
- After: 0.231 ms

### Vector 2-pass (q_len=8, k_len=4096)

- Before: 0.432 ms
- After: 0.436 ms

Pull Request resolved: #163598
Approved by: https://github.com/malfet
(cherry picked from commit 1c12d74)
Co-authored-by: Vismai Khanderao <[email protected]>
1 parent 71282c8 commit a5feacb
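
The regression in #163597 shows up when the query tensor handed to `scaled_dot_product_attention` is non-contiguous (for example, transposed). The sketch below is a minimal, hypothetical reproduction in the spirit of the new tests; the shapes and tolerances are illustrative assumptions, not values taken from the issue or the test suite.

```python
import torch
import torch.nn.functional as F

# Minimal sketch (illustrative shapes/tolerances): compare MPS SDPA against a
# CPU reference when the query tensor is non-contiguous, the case this fix
# addresses. Before the fix, the MPS fast vector kernel assumed a contiguous
# query layout, so the two results could diverge for a transposed q.
if torch.backends.mps.is_available():
    batch, heads, q_len, s_len, head_dim = 1, 2, 4, 16, 64
    # Allocate q with head_dim and q_len swapped, then view it back -> non-contiguous query
    q = torch.randn(batch, heads, head_dim, q_len, device="mps").mT
    k = torch.randn(batch, heads, s_len, head_dim, device="mps")
    v = torch.randn(batch, heads, s_len, head_dim, device="mps")

    out_mps = F.scaled_dot_product_attention(q, k, v)
    out_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu())
    torch.testing.assert_close(out_mps.cpu(), out_ref, atol=1e-4, rtol=1e-4)
```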

File tree

3 files changed: +82 additions, -55 deletions


aten/src/ATen/native/mps/kernels/Attention.metal

Lines changed: 41 additions & 37 deletions
@@ -14,8 +14,8 @@ template <typename T, int D, int V = D>
     device T* out [[buffer(3)]],
     const constant uint& gqa_factor [[buffer(4)]],
     const constant uint& N [[buffer(5)]],
-    const constant uint2& k_head_seq_stride [[buffer(6)]],
-    const constant uint2& v_head_seq_stride [[buffer(7)]],
+    const constant uint3& qkv_head_strides [[buffer(6)]],
+    const constant uint3& qkv_seq_strides [[buffer(7)]],
     const constant float& scale [[buffer(8)]],
     const device bool* mask [[buffer(9)]],
     const constant uint3& mask_strides [[buffer(10)]],
@@ -28,10 +28,12 @@ template <typename T, int D, int V = D>
   constexpr uint BD = 32;
   constexpr uint qk_per_thread = D / BD;
   constexpr uint v_per_thread = V / BD;
-  const uint k_head_stride = k_head_seq_stride.x;
-  const uint k_seq_stride = k_head_seq_stride.y;
-  const uint v_head_stride = v_head_seq_stride.x;
-  const uint v_seq_stride = v_head_seq_stride.y;
+  const uint q_head_stride = qkv_head_strides.x;
+  const uint q_seq_stride = qkv_seq_strides.x;
+  const uint k_head_stride = qkv_head_strides.y;
+  const uint k_seq_stride = qkv_seq_strides.y;
+  const uint v_head_stride = qkv_head_strides.z;
+  const uint v_seq_stride = qkv_seq_strides.z;
   const uint mask_head_stride = mask_strides.x;
   const uint mask_kv_seq_stride = mask_strides.y;
   const uint mask_q_seq_stride = mask_strides.z;
@@ -54,9 +56,9 @@ template <typename T, int D, int V = D>
   const int kv_head_idx = head_idx / gqa_factor;
   const int Q = tpg.y;
   const int group_offset = head_idx * Q + q_seq_idx;
-  const int q_offset = group_offset;
   const int o_offset = group_offset;
-  queries += q_offset * D + simd_lid * qk_per_thread;
+  queries += head_idx * q_head_stride + q_seq_idx * q_seq_stride +
+      simd_lid * qk_per_thread;
   keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
       simd_lid * qk_per_thread;
   values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
@@ -156,8 +158,8 @@ template <typename T, int D, int V = D>
     device float* maxs [[buffer(5)]],
     const constant uint& gqa_factor [[buffer(6)]],
     const constant uint& N [[buffer(7)]],
-    const constant uint2& k_head_seq_stride [[buffer(8)]],
-    const constant uint2& v_head_seq_stride [[buffer(9)]],
+    const constant uint3& qkv_head_strides [[buffer(8)]],
+    const constant uint3& qkv_seq_strides [[buffer(9)]],
     const constant float& scale [[buffer(10)]],
     const device bool* mask [[buffer(11)]],
     const constant uint3& mask_strides [[buffer(12)]],
@@ -170,10 +172,12 @@ template <typename T, int D, int V = D>
   constexpr int BD = 32;
   constexpr int qk_per_thread = D / BD;
   constexpr int v_per_thread = V / BD;
-  const int k_head_stride = k_head_seq_stride.x;
-  const int k_seq_stride = k_head_seq_stride.y;
-  const int v_head_stride = v_head_seq_stride.x;
-  const int v_seq_stride = v_head_seq_stride.y;
+  const int q_head_stride = qkv_head_strides.x;
+  const int q_seq_stride = qkv_seq_strides.x;
+  const int k_head_stride = qkv_head_strides.y;
+  const int k_seq_stride = qkv_seq_strides.y;
+  const int v_head_stride = qkv_head_strides.z;
+  const int v_seq_stride = qkv_seq_strides.z;
   const int mask_kv_seq_stride = mask_strides.x;
   const int mask_q_seq_stride = mask_strides.y;
   const int mask_head_stride = mask_strides.z;
@@ -196,10 +200,10 @@ template <typename T, int D, int V = D>
   const int head_idx = tid.x;
   const int q_seq_idx = tid.y;
   const int o_offset = head_idx * tpg.y + q_seq_idx;
-  const int q_offset = o_offset;
   const int kv_head_idx = head_idx / gqa_factor;
 
-  queries += q_offset * D + simd_lid * qk_per_thread;
+  queries += head_idx * q_head_stride + q_seq_idx * q_seq_stride +
+      simd_lid * qk_per_thread;
   keys += kv_head_idx * k_head_stride +
       (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
   values += kv_head_idx * v_head_stride +
@@ -520,25 +524,25 @@ kernel void attention(
   }
 }
 
-#define INSTANTIATE_SDPA_VECTOR(DTYPE, QK_DIM, VALUE_DIM)              \
-  template [[host_name("sdpa_vector_" #DTYPE "_" #QK_DIM               \
-                       "_" #VALUE_DIM)]] kernel void                   \
-  sdpa_vector<DTYPE, QK_DIM, VALUE_DIM>(                               \
-      const device DTYPE* queries [[buffer(0)]],                       \
-      const device DTYPE* keys [[buffer(1)]],                          \
-      const device DTYPE* values [[buffer(2)]],                        \
-      device DTYPE* out [[buffer(3)]],                                 \
-      const constant uint& gqa_factor [[buffer(4)]],                   \
-      const constant uint& N [[buffer(5)]],                            \
-      const constant uint2& k_head_seq_stride [[buffer(6)]],           \
-      const constant uint2& v_head_seq_stride [[buffer(7)]],           \
-      const constant float& scale [[buffer(8)]],                       \
-      const device bool* mask [[buffer(9)]],                           \
-      const constant uint3& mask_strides [[buffer(10)]],               \
-      const constant bool& has_mask [[buffer(11)]],                    \
-      uint3 tid [[threadgroup_position_in_grid]],                      \
-      uint3 tpg [[threadgroups_per_grid]],                             \
-      uint simd_gid [[simdgroup_index_in_threadgroup]],                \
+#define INSTANTIATE_SDPA_VECTOR(DTYPE, QK_DIM, VALUE_DIM)             \
+  template [[host_name("sdpa_vector_" #DTYPE "_" #QK_DIM              \
+                       "_" #VALUE_DIM)]] kernel void                  \
+  sdpa_vector<DTYPE, QK_DIM, VALUE_DIM>(                              \
+      const device DTYPE* queries [[buffer(0)]],                      \
+      const device DTYPE* keys [[buffer(1)]],                         \
+      const device DTYPE* values [[buffer(2)]],                       \
+      device DTYPE* out [[buffer(3)]],                                \
+      const constant uint& gqa_factor [[buffer(4)]],                  \
+      const constant uint& N [[buffer(5)]],                           \
+      const constant uint3& qkv_head_strides [[buffer(6)]],           \
+      const constant uint3& qkv_seq_strides [[buffer(7)]],            \
+      const constant float& scale [[buffer(8)]],                      \
+      const device bool* mask [[buffer(9)]],                          \
+      const constant uint3& mask_strides [[buffer(10)]],              \
+      const constant bool& has_mask [[buffer(11)]],                   \
+      uint3 tid [[threadgroup_position_in_grid]],                     \
+      uint3 tpg [[threadgroups_per_grid]],                            \
+      uint simd_gid [[simdgroup_index_in_threadgroup]],               \
       uint simd_lid [[thread_index_in_simdgroup]]);
 
 #define INSTANTIATE_SDPA_VECTOR_2PASS_1(DTYPE, QK_DIM, VALUE_DIM)      \
@@ -553,8 +557,8 @@ kernel void attention(
      device float* maxs [[buffer(5)]],                                \
      const constant uint& gqa_factor [[buffer(6)]],                   \
      const constant uint& N [[buffer(7)]],                            \
-      const constant uint2& k_head_seq_stride [[buffer(8)]],          \
-      const constant uint2& v_head_seq_stride [[buffer(9)]],          \
+      const constant uint3& qkv_head_strides [[buffer(8)]],           \
+      const constant uint3& qkv_seq_strides [[buffer(9)]],            \
      const constant float& scale [[buffer(10)]],                      \
      const device bool* mask [[buffer(11)]],                          \
      const constant uint3& mask_strides [[buffer(12)]],               \
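
The kernel-side change boils down to replacing the hard-coded query offset `q_offset * D` with stride-based addressing (`head_idx * q_head_stride + q_seq_idx * q_seq_stride`). Below is a rough illustration in plain Python rather than Metal (and, as in the kernel, the last head_dim axis is still assumed contiguous): the fixed offset only lands on the right element when the query is contiguous, while the stride-based offset works for any layout. Names and sizes here are illustrative.

```python
import torch

NH, Q, D = 2, 4, 8
flat = torch.arange(NH * Q * D, dtype=torch.float32)  # backing buffer; value == flat index

# Two logically identical (NH, Q, D) queries over the same buffer:
# one contiguous, one materialized as (NH, D, Q) and transposed back.
q_contiguous = flat.view(NH, Q, D)
q_transposed = flat.view(NH, D, Q).mT  # same shape, strides (Q*D, 1, Q)

head_idx, q_seq_idx, d = 1, 2, 3
for q in (q_contiguous, q_transposed):
    fixed = (head_idx * Q + q_seq_idx) * D + d  # what the old kernel effectively assumed
    strided = head_idx * q.stride(0) + q_seq_idx * q.stride(1) + d * q.stride(2)
    want = q[head_idx, q_seq_idx, d]
    print(q.is_contiguous(),
          bool(flat[fixed] == want),    # True only for the contiguous layout
          bool(flat[strided] == want))  # True for both layouts
```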

aten/src/ATen/native/mps/operations/Attention.mm

Lines changed: 8 additions & 4 deletions
@@ -182,6 +182,8 @@
   uint maxSeqLength = k_.size(2);
   uint N = k_.size(2);
   uint B = q_.size(0) * q_.size(1);
+  uint q_head_stride = q_.stride(1);
+  uint q_seq_stride = q_.stride(2);
   uint k_head_stride = k_.stride(1);
   uint k_seq_stride = k_.stride(2);
   uint v_head_stride = v_.stride(1);
@@ -209,8 +211,8 @@
       out,
       1,
       N,
-      std::array<uint32_t, 2>{k_head_stride, k_seq_stride},
-      std::array<uint32_t, 2>{v_head_stride, v_seq_stride},
+      std::array<uint32_t, 3>{q_head_stride, k_head_stride, v_head_stride},
+      std::array<uint32_t, 3>{q_seq_stride, k_seq_stride, v_seq_stride},
       scale_factor);
 
   if (has_mask) {
@@ -257,6 +259,8 @@
   uint B = batchSize * num_heads;
   uint gqa_factor = q_.size(1) / k_.size(1);
 
+  uint q_head_stride = q_.stride(1);
+  uint q_seq_stride = q_.stride(2);
   uint k_head_stride = k_.stride(1);
   uint k_seq_stride = k_.stride(2);
   uint v_head_stride = v_.stride(1);
@@ -294,8 +298,8 @@
       maxs,
       gqa_factor,
       N,
-      std::array<uint32_t, 2>{k_head_stride, k_seq_stride},
-      std::array<uint32_t, 2>{v_head_stride, v_seq_stride},
+      std::array<uint32_t, 3>{q_head_stride, k_head_stride, v_head_stride},
+      std::array<uint32_t, 3>{q_seq_stride, k_seq_stride, v_seq_stride},
       scale_factor);
 
   if (has_mask) {
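
On the host side, the change reads the query's head and sequence strides (`q_.stride(1)`, `q_.stride(2)`) and packs them together with the key/value strides into two three-element arrays, which the kernel receives as the uint3 buffers above. The short sketch below (plain PyTorch, CPU tensors, illustrative sizes) shows the values that would now be forwarded for a contiguous versus a transposed query of the same logical shape:

```python
import torch

# Per-tensor head stride (dim 1) and sequence stride (dim 2) of a [B, NH, L, D]
# tensor -- the quantities the host code now forwards for q as well as k and v.
B, NH, L, D = 1, 2, 16, 64
q_contiguous = torch.randn(B, NH, L, D)
q_transposed = torch.randn(B, NH, D, L).mT  # same shape, non-contiguous

for name, q in (("contiguous", q_contiguous), ("transposed", q_transposed)):
    print(f"{name}: head_stride={q.stride(1)}, seq_stride={q.stride(2)}")
# contiguous: head_stride=1024 (= L*D), seq_stride=64 (= D)
# transposed: head_stride=1024 (= L*D), seq_stride=1
```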

test/test_mps.py

Lines changed: 33 additions & 14 deletions
@@ -9472,17 +9472,37 @@ def get_mps_memory_usage():
         # 5 MB different maximum allowed value(could be decreased even more)
         torch.testing.assert_close(memory_footprints[-1], memory_footprints[0], atol=5, rtol=1)
 
-    def generate_qkv(self, batch, NH, q_len, s_len, head_dim, contiguous, dtype):
-        if contiguous:
+    def generate_qkv(self, batch: int, NH: int, q_len: int, s_len: int, head_dim: int, layout: str, dtype: torch.dtype):
+        if layout == "contiguous":
             q = torch.randn(batch, NH, q_len, head_dim, dtype=dtype, device="mps")
             k = torch.randn(batch, NH, s_len, head_dim, dtype=dtype, device="mps")
-        else:
+        elif layout == "mT":
+            # Transpose head dimension and length
             q = torch.randn(batch, NH, head_dim, q_len, dtype=dtype, device="mps").mT
             k = torch.randn(batch, NH, head_dim, s_len, dtype=dtype, device="mps").mT
+        elif layout == "transpose_seq_head":
+            # Transpose length and number of heads
+            q = torch.randn(batch, q_len, NH, head_dim, dtype=dtype, device="mps").transpose(1, 2)
+            k = torch.randn(batch, s_len, NH, head_dim, dtype=dtype, device="mps").transpose(1, 2)
+        elif layout == "permute":
+            # Permute head dimension and length
+            q = torch.randn(batch, head_dim, NH, q_len, dtype=dtype, device="mps").permute(0, 2, 3, 1)
+            k = torch.randn(batch, head_dim, NH, s_len, dtype=dtype, device="mps").permute(0, 2, 3, 1)
+        else:
+            raise ValueError(f"Unknown layout: {layout}")
+
         v = torch.randn(batch, NH, s_len, head_dim, dtype=dtype, device="mps")
         return q, k, v
 
-    def run_fast_attention_test(self, q, k, v, with_mask, dropout_p=0.0, is_causal=False):
+    def run_fast_attention_test(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        with_mask: bool,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+    ):
         q_len = q.shape[2]
         s_len = k.shape[2]
 
@@ -9523,48 +9543,47 @@ def run_fast_attention_test(self, q, k, v, with_mask, dropout_p=0.0, is_causal=F
         self._compare_tensors(y.cpu(), y_ref)
 
     @parametrize("dtype", [torch.float16, torch.float32])
-    @parametrize("contiguous", [True, False])
+    @parametrize("layout", ["contiguous", "mT", "transpose_seq_head", "permute"])
     @parametrize("head_dim", [64, 96, 128])  # 64, 96, 128 are for the fast kernel
     @parametrize("with_mask", [True, False])
-    def test_fast_vector_attention(self, dtype, contiguous, head_dim, with_mask):
+    def test_fast_vector_attention(self, dtype: torch.dtype, layout: str, head_dim: int, with_mask: bool):
         torch.manual_seed(1729)
         batch = 1
         NH = 2
         q_len = 4  # <8 so that vector fast is eligible
         s_len = 16  # smaller than 1024 so that we use the one–pass variant
-        q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, contiguous, dtype)
+        q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, layout, dtype)
         self.run_fast_attention_test(q, k, v, with_mask)
 
     @parametrize("dtype", [torch.float32])  # float16 underflows sometimes, which leads to flaky tests
-    @parametrize("contiguous", [True, False])
+    @parametrize("layout", ["contiguous", "mT", "transpose_seq_head", "permute"])
     @parametrize("with_mask", [True, False])
-    def test_fast_vector_attention_2pass(self, dtype, contiguous, with_mask):
+    def test_fast_vector_attention_2pass(self, dtype: torch.dtype, layout: str, with_mask: bool):
         torch.manual_seed(1729)
         batch = 1
         NH = 32
         q_len = 8
         s_len = 1024  # large enough to trigger the two–pass path
         head_dim = 64  # supported head dimension for vector attention
-        q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, contiguous, dtype)
+        q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, layout, dtype)
         self.run_fast_attention_test(q, k, v, with_mask)
 
     @unittest.skip("Full attention fast kernel not implemented yet")
     @parametrize("dtype", [torch.float16, torch.float32])
-    @parametrize("contiguous", [True, False])
+    @parametrize("layout", ["contiguous", "mT"])
     @parametrize("head_dim", [64, 80, 128])  # 64, 80, 128 are for the fast kernel
     @parametrize("with_mask", [True, False])
-    def test_fast_full_attention(self, dtype, contiguous, head_dim, with_mask):
+    def test_fast_full_attention(self, dtype: torch.dtype, layout: str, head_dim: int, with_mask: bool):
         torch.manual_seed(1729)
         batch = 1
         NH = 2
         q_len = 32  # threshold to trigger full fast attention path
         s_len = 16
-        q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, contiguous, dtype)
+        q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, layout, dtype)
        self.run_fast_attention_test(q, k, v, with_mask)
 
 
 
-
 class TestSDPAMetaDispatchMode(TorchDispatchMode):
     """
     TorchDispatchMode which intercepts the
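
As a standalone illustration of what the new `layout` parametrization exercises, the sketch below mirrors the four layouts from `generate_qkv` (on CPU, so it runs without an MPS device) and prints the strides each one produces for the same logical `[batch, NH, q_len, head_dim]` shape; only the first is contiguous. Sizes are illustrative.

```python
import torch

# Mirrors the four query layouts the updated tests parametrize over; each has
# the same logical [batch, NH, q_len, head_dim] shape but different strides.
batch, NH, q_len, head_dim = 1, 2, 4, 64
layouts = {
    "contiguous": torch.randn(batch, NH, q_len, head_dim),
    "mT": torch.randn(batch, NH, head_dim, q_len).mT,
    "transpose_seq_head": torch.randn(batch, q_len, NH, head_dim).transpose(1, 2),
    "permute": torch.randn(batch, head_dim, NH, q_len).permute(0, 2, 3, 1),
}
for name, q in layouts.items():
    print(f"{name:>18}: strides={q.stride()} contiguous={q.is_contiguous()}")
```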