Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b3d55bc

Browse files
author
Alexander Komarov
authored
replaced call to kernel_mul_mv_f16_f32_l4 with kernel_mul_mv_f16_f32_l4_large
replaced call to kernel_mul_mv_f16_f32_l4 with kernel_mul_mv_f16_f32_l4_large for vectors larger than 128 elements.
1 parent cd2322c commit b3d55bc

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

ggml-metal.metal

+58
Original file line numberDiff line numberDiff line change
@@ -1598,6 +1598,64 @@ kernel void kernel_mul_mv_f16_f32_l4(
15981598
}
15991599
}
16001600

1601+
kernel void kernel_mul_mv_f16_f32_l4_large(
1602+
device const char * src0,
1603+
device const char * src1,
1604+
device float * dst,
1605+
constant int64_t & ne00,
1606+
constant int64_t & ne01,
1607+
constant int64_t & ne02,
1608+
constant uint64_t & nb00,
1609+
constant uint64_t & nb01,
1610+
constant uint64_t & nb02,
1611+
constant int64_t & ne10,
1612+
constant int64_t & ne11,
1613+
constant int64_t & ne12,
1614+
constant uint64_t & nb10,
1615+
constant uint64_t & nb11,
1616+
constant uint64_t & nb12,
1617+
constant int64_t & ne0,
1618+
constant int64_t & ne1,
1619+
constant uint & r2,
1620+
constant uint & r3,
1621+
uint3 tgpig[[threadgroup_position_in_grid]],
1622+
uint tiisg[[thread_index_in_simdgroup]]) {
1623+
1624+
const int nrows = ne11;
1625+
const int64_t base_r0 = tgpig.x*32;
1626+
const int64_t im = tgpig.z;
1627+
threadgroup float partial_sums[32]; // Shared memory for partial sums for each SIMD group
1628+
1629+
const uint i12 = im%ne12;
1630+
const uint i13 = im/ne12;
1631+
1632+
for (int j = 0; j < 32; ++j) {
1633+
const int64_t r0 = base_r0 + j;
1634+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1635+
device const half4 * x4 = (device const half4 *) (src0 + offset0);
1636+
1637+
partial_sums[tiisg] = 0.0f;
1638+
for (int r1 = 0; r1 < nrows; ++r1) {
1639+
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
1640+
1641+
for (int i = tiisg; i < ne00/4; i += 32) {
1642+
for (int k = 0; k < 4; ++k) partial_sums[tiisg] += (float) x4[i][k] * y4[i][k];
1643+
}
1644+
1645+
// Barrier to ensure all threads have written their partial sums
1646+
threadgroup_barrier(mem_flags::mem_threadgroup);
1647+
float sumf = simd_sum(partial_sums[tiisg]);
1648+
// Barrier to ensure reduction is complete before writing the result
1649+
threadgroup_barrier(mem_flags::mem_threadgroup);
1650+
1651+
if (tiisg == 0) {
1652+
dst[im*ne1*ne0 + r1*ne0 + r0] = sumf;
1653+
}
1654+
threadgroup_barrier(mem_flags::mem_threadgroup);
1655+
}
1656+
}
1657+
}
1658+
16011659
static float rope_yarn_ramp(const float low, const float high, const int i0) {
16021660
const float y = (i0 / 2 - low) / max(0.001f, high - low);
16031661
return 1.0f - min(1.0f, max(0.0f, y));

0 commit comments

Comments
 (0)