@@ -14,8 +14,8 @@ template <typename T, int D, int V = D>
1414 device T* out [[buffer(3 )]],
1515 const constant uint& gqa_factor [[buffer(4 )]],
1616 const constant uint& N [[buffer(5 )]],
17- const constant uint2& k_head_seq_stride [[buffer(6 )]],
18- const constant uint2& v_head_seq_stride [[buffer(7 )]],
17+ const constant uint3& qkv_head_strides [[buffer(6 )]],
18+ const constant uint3& qkv_seq_strides [[buffer(7 )]],
1919 const constant float& scale [[buffer(8 )]],
2020 const device bool* mask [[buffer(9 )]],
2121 const constant uint3& mask_strides [[buffer(10 )]],
@@ -28,10 +28,12 @@ template <typename T, int D, int V = D>
2828 constexpr uint BD = 32 ;
2929 constexpr uint qk_per_thread = D / BD;
3030 constexpr uint v_per_thread = V / BD;
31- const uint k_head_stride = k_head_seq_stride.x ;
32- const uint k_seq_stride = k_head_seq_stride.y ;
33- const uint v_head_stride = v_head_seq_stride.x ;
34- const uint v_seq_stride = v_head_seq_stride.y ;
31+ const uint q_head_stride = qkv_head_strides.x ;
32+ const uint q_seq_stride = qkv_seq_strides.x ;
33+ const uint k_head_stride = qkv_head_strides.y ;
34+ const uint k_seq_stride = qkv_seq_strides.y ;
35+ const uint v_head_stride = qkv_head_strides.z ;
36+ const uint v_seq_stride = qkv_seq_strides.z ;
3537 const uint mask_head_stride = mask_strides.x ;
3638 const uint mask_kv_seq_stride = mask_strides.y ;
3739 const uint mask_q_seq_stride = mask_strides.z ;
@@ -54,9 +56,9 @@ template <typename T, int D, int V = D>
5456 const int kv_head_idx = head_idx / gqa_factor;
5557 const int Q = tpg.y ;
5658 const int group_offset = head_idx * Q + q_seq_idx;
57- const int q_offset = group_offset;
5859 const int o_offset = group_offset;
59- queries += q_offset * D + simd_lid * qk_per_thread;
60+ queries += head_idx * q_head_stride + q_seq_idx * q_seq_stride +
61+ simd_lid * qk_per_thread;
6062 keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
6163 simd_lid * qk_per_thread;
6264 values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
@@ -156,8 +158,8 @@ template <typename T, int D, int V = D>
156158 device float* maxs [[buffer(5 )]],
157159 const constant uint& gqa_factor [[buffer(6 )]],
158160 const constant uint& N [[buffer(7 )]],
159- const constant uint2& k_head_seq_stride [[buffer(8 )]],
160- const constant uint2& v_head_seq_stride [[buffer(9 )]],
161+ const constant uint3& qkv_head_strides [[buffer(8 )]],
162+ const constant uint3& qkv_seq_strides [[buffer(9 )]],
161163 const constant float& scale [[buffer(10 )]],
162164 const device bool* mask [[buffer(11 )]],
163165 const constant uint3& mask_strides [[buffer(12 )]],
@@ -170,10 +172,12 @@ template <typename T, int D, int V = D>
170172 constexpr int BD = 32 ;
171173 constexpr int qk_per_thread = D / BD;
172174 constexpr int v_per_thread = V / BD;
173- const int k_head_stride = k_head_seq_stride.x ;
174- const int k_seq_stride = k_head_seq_stride.y ;
175- const int v_head_stride = v_head_seq_stride.x ;
176- const int v_seq_stride = v_head_seq_stride.y ;
175+ const int q_head_stride = qkv_head_strides.x ;
176+ const int q_seq_stride = qkv_seq_strides.x ;
177+ const int k_head_stride = qkv_head_strides.y ;
178+ const int k_seq_stride = qkv_seq_strides.y ;
179+ const int v_head_stride = qkv_head_strides.z ;
180+ const int v_seq_stride = qkv_seq_strides.z ;
177181 const int mask_kv_seq_stride = mask_strides.x ;
178182 const int mask_q_seq_stride = mask_strides.y ;
179183 const int mask_head_stride = mask_strides.z ;
@@ -196,10 +200,10 @@ template <typename T, int D, int V = D>
196200 const int head_idx = tid.x ;
197201 const int q_seq_idx = tid.y ;
198202 const int o_offset = head_idx * tpg.y + q_seq_idx;
199- const int q_offset = o_offset;
200203 const int kv_head_idx = head_idx / gqa_factor;
201204
202- queries += q_offset * D + simd_lid * qk_per_thread;
205+ queries += head_idx * q_head_stride + q_seq_idx * q_seq_stride +
206+ simd_lid * qk_per_thread;
203207 keys += kv_head_idx * k_head_stride +
204208 (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
205209 values += kv_head_idx * v_head_stride +
@@ -520,25 +524,25 @@ kernel void attention(
520524 }
521525}
522526
523- #define INSTANTIATE_SDPA_VECTOR (DTYPE, QK_DIM, VALUE_DIM ) \
524- template [[host_name(" sdpa_vector_" #DTYPE " _" #QK_DIM \
525- " _" #VALUE_DIM)]] kernel void \
526- sdpa_vector<DTYPE, QK_DIM, VALUE_DIM>( \
527- const device DTYPE* queries [[buffer(0 )]], \
528- const device DTYPE* keys [[buffer(1 )]], \
529- const device DTYPE* values [[buffer(2 )]], \
530- device DTYPE* out [[buffer(3 )]], \
531- const constant uint& gqa_factor [[buffer(4 )]], \
532- const constant uint& N [[buffer(5 )]], \
533- const constant uint2& k_head_seq_stride [[buffer(6 )]], \
534- const constant uint2& v_head_seq_stride [[buffer(7 )]], \
535- const constant float & scale [[buffer(8 )]], \
536- const device bool * mask [[buffer(9 )]], \
537- const constant uint3& mask_strides [[buffer(10 )]], \
538- const constant bool & has_mask [[buffer(11 )]], \
539- uint3 tid [[threadgroup_position_in_grid]], \
540- uint3 tpg [[threadgroups_per_grid]], \
541- uint simd_gid [[simdgroup_index_in_threadgroup]], \
// INSTANTIATE_SDPA_VECTOR(DTYPE, QK_DIM, VALUE_DIM): explicitly instantiates
// the sdpa_vector<DTYPE, QK_DIM, VALUE_DIM> kernel template and gives it a
// host-visible name assembled from the stringized macro arguments.
// Buffers 6 and 7 now carry the packed uint3 stride arguments
// qkv_head_strides and qkv_seq_strides; the kernel bodies shown in this file
// read .x as the query stride, .y as the key stride and .z as the value stride.
// NOTE(review): the parameter list here presumably has to stay in sync with
// the sdpa_vector template declaration (same order and buffer indices) -
// confirm against the full kernel signature when changing either one.
527+ #define INSTANTIATE_SDPA_VECTOR (DTYPE, QK_DIM, VALUE_DIM ) \
528+ template [[host_name(" sdpa_vector_" #DTYPE " _" #QK_DIM \
529+ " _" #VALUE_DIM)]] kernel void \
530+ sdpa_vector<DTYPE, QK_DIM, VALUE_DIM>( \
531+ const device DTYPE* queries [[buffer(0 )]], \
532+ const device DTYPE* keys [[buffer(1 )]], \
533+ const device DTYPE* values [[buffer(2 )]], \
534+ device DTYPE* out [[buffer(3 )]], \
535+ const constant uint& gqa_factor [[buffer(4 )]], \
536+ const constant uint& N [[buffer(5 )]], \
537+ const constant uint3& qkv_head_strides [[buffer(6 )]], \
538+ const constant uint3& qkv_seq_strides [[buffer(7 )]], \
539+ const constant float & scale [[buffer(8 )]], \
540+ const device bool * mask [[buffer(9 )]], \
541+ const constant uint3& mask_strides [[buffer(10 )]], \
542+ const constant bool & has_mask [[buffer(11 )]], \
543+ uint3 tid [[threadgroup_position_in_grid]], \
544+ uint3 tpg [[threadgroups_per_grid]], \
545+ uint simd_gid [[simdgroup_index_in_threadgroup]], \
542546 uint simd_lid [[thread_index_in_simdgroup]]);
543547
544548#define INSTANTIATE_SDPA_VECTOR_2PASS_1 (DTYPE, QK_DIM, VALUE_DIM ) \
@@ -553,8 +557,8 @@ kernel void attention(
553557 device float * maxs [[buffer(5 )]], \
554558 const constant uint& gqa_factor [[buffer(6 )]], \
555559 const constant uint& N [[buffer(7 )]], \
556- const constant uint2& k_head_seq_stride [[buffer(8 )]], \
557- const constant uint2& v_head_seq_stride [[buffer(9 )]], \
560+ const constant uint3& qkv_head_strides [[buffer(8 )]], \
561+ const constant uint3& qkv_seq_strides [[buffer(9 )]], \
558562 const constant float & scale [[buffer(10 )]], \
559563 const device bool * mask [[buffer(11 )]], \
560564 const constant uint3& mask_strides [[buffer(12 )]], \
0 commit comments