From 16843dba332636774576022068b1b7da719e39e2 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 4 May 2025 09:13:52 +0300
Subject: [PATCH] metal : pad mm results

---
 ggml/src/ggml-metal/ggml-metal.m     | 44 ++++++++++++++++++++--
 ggml/src/ggml-metal/ggml-metal.metal | 56 ++++++++++++++--------------
 2 files changed, 68 insertions(+), 32 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index d92392edb7eb1..c44772475b630 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -2443,7 +2443,7 @@ static bool ggml_metal_encode_node(

 #if 0
     // cpy to tmp buffer in MTLHeap
-    id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
+    id<MTLBuffer> h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
     if (!h_src0) {
         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
         return false;
@@ -2947,6 +2947,12 @@ static bool ggml_metal_encode_node(
                        default: break;
                    }

+                   id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, sizeof(float)*GGML_PAD(ne01, 64)*GGML_PAD(ne11, 32)*ne12*ne13);
+                   if (!h_dst) {
+                       GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
+                       return false;
+                   }
+
                    id<MTLComputePipelineState> pipeline = nil;

                    switch (src0->type) {
@@ -2986,8 +2992,8 @@ static bool ggml_metal_encode_node(
                        /*.nb11 =*/ nb11,
                        /*.nb12 =*/ nb12,
                        /*.nb13 =*/ nb13,
-                       /*.ne0  =*/ ne0,
-                       /*.ne1  =*/ ne1,
+                       /*.ne0  =*/ GGML_PAD(ne01, 64),
+                       /*.ne1  =*/ GGML_PAD(ne11, 32),
                        /*.r2   =*/ r2,
                        /*.r3   =*/ r3,
                    };
@@ -2996,10 +3002,40 @@ static bool ggml_metal_encode_node(
                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
-                   [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+                 //[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+                   [encoder setBuffer:h_dst offset:0 atIndex:3];

                    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+                   ggml_metal_kargs_cpy args_cpy = {
+                       /*.ne00 =*/ ne0,
+                       /*.ne01 =*/ ne1,
+                       /*.ne02 =*/ ne2,
+                       /*.ne03 =*/ ne3,
+                       /*.nb00 =*/ nb0,
+                       /*.nb01 =*/ nb0*GGML_PAD(ne01, 64),
+                       /*.nb02 =*/ nb0*GGML_PAD(ne01, 64)*GGML_PAD(ne11, 32),
+                       /*.nb03 =*/ nb0*GGML_PAD(ne01, 64)*GGML_PAD(ne11, 32)*ne12,
+                       /*.ne0  =*/ ne0,
+                       /*.ne1  =*/ ne1,
+                       /*.ne2  =*/ ne2,
+                       /*.ne3  =*/ ne3,
+                       /*.nb0  =*/ nb0,
+                       /*.nb1  =*/ nb1,
+                       /*.nb2  =*/ nb2,
+                       /*.nb3  =*/ nb3,
+                   };
+
+                   [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
+
+                   [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
+                   [encoder setBuffer:h_dst offset:0 atIndex:1];
+                   [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+                   const int nth_cpy = MIN(1024, ne0);
+
+                   [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
                } else {
                    id<MTLComputePipelineState> pipeline = nil;

diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 9f4147e93974d..401fc445956e3 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -6305,34 +6305,34 @@ kernel void kernel_mul_mm(
        }
    } else {
        // block is smaller than 64x32, we should avoid writing data outside of the matrix
-       threadgroup_barrier(mem_flags::mem_threadgroup);
-       threadgroup float * temp_str = ((threadgroup float *) shmem) \
-                                     + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
-       for (short i = 0; i < 8; i++) {
-           simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
-       }
-
-       threadgroup_barrier(mem_flags::mem_threadgroup);
-
-       if (sgitg == 0) {
-           for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-               device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
-               device float4 * D4 = (device float4 *) D;
-
-               threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
-               threadgroup float4 * C4 = (threadgroup float4 *) C;
-
-               int i = 0;
-               for (; i < n_rows/4; i++) {
-                   *(D4 + i) = *(C4 + i);
-               }
-
-               i *= 4;
-               for (; i < n_rows; i++) {
-                   *(D + i) = *(C + i);
-               }
-           }
-       }
+       //threadgroup_barrier(mem_flags::mem_threadgroup);
+       //threadgroup float * temp_str = ((threadgroup float *) shmem) \
+       //                              + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
+       //for (short i = 0; i < 8; i++) {
+       //    simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+       //}
+
+       //threadgroup_barrier(mem_flags::mem_threadgroup);
+
+       //if (sgitg == 0) {
+       //    for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+       //        device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
+       //        device float4 * D4 = (device float4 *) D;
+
+       //        threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
+       //        threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+       //        int i = 0;
+       //        for (; i < n_rows/4; i++) {
+       //            *(D4 + i) = *(C4 + i);
+       //        }
+
+       //        i *= 4;
+       //        for (; i < n_rows; i++) {
+       //            *(D + i) = *(C + i);
+       //        }
+       //    }
+       //}
    }
}
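
Not part of the patch: a small host-side sketch, in plain C, of the sizing arithmetic behind the new h_dst pool allocation and the copy-back strides (args_cpy.nb01/nb02/nb03). It assumes GGML_PAD(x, n) rounds x up to a multiple of n, which is what the macro in ggml.h does for the power-of-two alignments (64, 32) used here; the PAD helper and the shape values below are made up for illustration.

    #include <stdint.h>
    #include <stdio.h>

    // Rounds x up to a multiple of n; matches GGML_PAD's result for the
    // power-of-two alignments (64, 32) used in the patch.
    #define PAD(x, n) (((x) + (n) - 1) / (n) * (n))

    int main(void) {
        // Hypothetical shapes, for illustration only.
        const int64_t ne01 = 100; // dst rows    (src0 rows)
        const int64_t ne11 = 9;   // dst columns (src1 columns)
        const int64_t ne12 = 4;   // batch dimensions
        const int64_t ne13 = 1;

        // The mm kernel writes full 64x32 tiles, so the temporary result is padded:
        const int64_t ne0_pad = PAD(ne01, 64);
        const int64_t ne1_pad = PAD(ne11, 32);

        // Size of the pool allocation for h_dst, as computed in the patch:
        const size_t nb0    = sizeof(float);
        const size_t sz_dst = nb0*ne0_pad*ne1_pad*ne12*ne13;

        // Strides the copy-back pass uses to read the padded buffer
        // (args_cpy.nb01/nb02/nb03 in the patch):
        const size_t nb01_pad = nb0*ne0_pad;
        const size_t nb02_pad = nb01_pad*ne1_pad;
        const size_t nb03_pad = nb02_pad*ne12;

        printf("padded result: %lld x %lld, pool allocation: %zu bytes\n",
               (long long) ne0_pad, (long long) ne1_pad, sz_dst);
        printf("padded strides: nb01 = %zu, nb02 = %zu, nb03 = %zu\n",
               nb01_pad, nb02_pad, nb03_pad);

        // The CPY_F32_F32 pass then copies only ne01 x ne11 elements per matrix
        // from the padded buffer into the tightly packed dst tensor.
        return 0;
    }

Because the mm kernel now always writes into a padded scratch buffer, the partial-tile path in kernel_mul_mm (the else branch commented out above) is no longer needed; the extra copy pass trims the padding when writing the real dst.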