@@ -1598,6 +1598,64 @@ kernel void kernel_mul_mv_f16_f32_l4(
1598
1598
}
1599
1599
}
1600
1600
1601
+ kernel void kernel_mul_mv_f16_f32_l4_large (
1602
+ device const char * src0,
1603
+ device const char * src1,
1604
+ device float * dst,
1605
+ constant int64_t & ne00,
1606
+ constant int64_t & ne01,
1607
+ constant int64_t & ne02,
1608
+ constant uint64_t & nb00,
1609
+ constant uint64_t & nb01,
1610
+ constant uint64_t & nb02,
1611
+ constant int64_t & ne10,
1612
+ constant int64_t & ne11,
1613
+ constant int64_t & ne12,
1614
+ constant uint64_t & nb10,
1615
+ constant uint64_t & nb11,
1616
+ constant uint64_t & nb12,
1617
+ constant int64_t & ne0,
1618
+ constant int64_t & ne1,
1619
+ constant uint & r2,
1620
+ constant uint & r3,
1621
+ uint3 tgpig[[threadgroup_position_in_grid]],
1622
+ uint tiisg[[thread_index_in_simdgroup]]) {
1623
+
1624
+ const int nrows = ne11;
1625
+ const int64_t base_r0 = tgpig.x *32 ;
1626
+ const int64_t im = tgpig.z ;
1627
+ threadgroup float partial_sums[32 ]; // Shared memory for partial sums for each SIMD group
1628
+
1629
+ const uint i12 = im%ne12;
1630
+ const uint i13 = im/ne12;
1631
+
1632
+ for (int j = 0 ; j < 32 ; ++j) {
1633
+ const int64_t r0 = base_r0 + j;
1634
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1635
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);
1636
+
1637
+ partial_sums[tiisg] = 0 .0f ;
1638
+ for (int r1 = 0 ; r1 < nrows; ++r1) {
1639
+ device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
1640
+
1641
+ for (int i = tiisg; i < ne00/4 ; i += 32 ) {
1642
+ for (int k = 0 ; k < 4 ; ++k) partial_sums[tiisg] += (float ) x4[i][k] * y4[i][k];
1643
+ }
1644
+
1645
+ // Barrier to ensure all threads have written their partial sums
1646
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1647
+ float sumf = simd_sum (partial_sums[tiisg]);
1648
+ // Barrier to ensure reduction is complete before writing the result
1649
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1650
+
1651
+ if (tiisg == 0 ) {
1652
+ dst[im*ne1*ne0 + r1*ne0 + r0] = sumf;
1653
+ }
1654
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1655
+ }
1656
+ }
1657
+ }
1658
+
1601
1659
static float rope_yarn_ramp (const float low, const float high, const int i0) {
1602
1660
const float y = (i0 / 2 - low) / max (0 .001f , high - low);
1603
1661
return 1 .0f - min (1 .0f , max (0 .0f , y));
0 commit comments