@@ -2844,7 +2844,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : (subgroup_size_16 * 4);
         uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? device->subgroup_size : (device->subgroup_size * 4);
 
-        const bool s = device->subgroup_add && device->architecture != vk_device_architecture::AMD_GCN;
+        const bool s = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
 
         for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[s], arr_dmmv_f32_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
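The first hunk swaps the gate for the subgroup-reduction shader variants from `subgroup_add` to `subgroup_arithmetic` (the `s` flag indexes the `arr_dmmv_*` variant arrays, and AMD GCN stays on the non-subgroup path). At the Vulkan level there is no standalone "add" capability: `subgroupAdd()` and friends all sit behind the single arithmetic feature bit. A minimal sketch of that query, with a hypothetical helper name:

```cpp
// Hedged sketch: where the subgroup-arithmetic capability comes from in core
// Vulkan 1.1. The helper name is made up; ggml-vulkan caches this in its own
// device flags.
#include <vulkan/vulkan.h>

static bool supports_subgroup_arithmetic(VkPhysicalDevice phys_dev) {
    VkPhysicalDeviceSubgroupProperties subgroup_props = {};
    subgroup_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;

    VkPhysicalDeviceProperties2 props2 = {};
    props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    props2.pNext = &subgroup_props;
    vkGetPhysicalDeviceProperties2(phys_dev, &props2);

    // subgroupAdd(), subgroupMul(), subgroupMin()/Max() etc. are all covered
    // by this one feature bit; there is no finer-grained "add only" bit.
    return (subgroup_props.supportedOperations & VK_SUBGROUP_FEATURE_ARITHMETIC_BIT) != 0;
}
```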
@@ -2895,8 +2895,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[s], arr_dmmv_iq4_nl_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[s], arr_dmmv_mxfp4_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
         }
+    }
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
         if (device->integer_dot_product) {
             const uint32_t subgroup_size = (device->subgroup_size_control && device->vendor_id == VK_VENDOR_ID_INTEL) ? device->subgroup_min_size : device->subgroup_size;
             if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
@@ -2913,8 +2915,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
                 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_q8_1_f32_len, mul_mat_vec_q8_0_q8_1_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {subgroup_size, 1*rm_stdq, i+1}, 1, true);
             }
         }
-#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
     }
+#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
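Beyond the gating change, these hunks restructure the integer-dot pipeline creation: the q8_1 pipelines are not indexed by the workgroup-size variant `w`, so they get their own column-count loop inside the preprocessor guard, and the `#endif` moves past that loop's closing brace so the braces balance whether or not the guard compiles in. The Intel special case pins the subgroup size to `subgroup_min_size`; a sketch of where that min/max range comes from, assuming Vulkan 1.3 (the EXT_subgroup_size_control structs are equivalent) and made-up helper and struct names:

```cpp
// Hedged sketch: querying the device's supported subgroup size range, which
// backs values like device->subgroup_min_size. Names here are hypothetical.
#include <vulkan/vulkan.h>

struct subgroup_size_range { uint32_t min_size; uint32_t max_size; };

static subgroup_size_range query_subgroup_size_range(VkPhysicalDevice phys_dev) {
    VkPhysicalDeviceSubgroupSizeControlProperties sc_props = {};
    sc_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_PROPERTIES;

    VkPhysicalDeviceProperties2 props2 = {};
    props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    props2.pNext = &sc_props;
    vkGetPhysicalDeviceProperties2(phys_dev, &props2);

    // Intel GPUs typically report 8..32 here; pinning the pipeline to the
    // minimum gives the reduction a predictable width.
    return { sc_props.minSubgroupSize, sc_props.maxSubgroupSize };
}
```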
@@ -5690,7 +5692,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
 
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        if (quantize_y) {
+            y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+        }
         const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
         if (
             (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
@@ -5702,7 +5707,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
             ctx->prealloc_size_x = x_sz_upd;
         }
         if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
+            ctx->prealloc_size_y = y_sz_upd;
         }
         if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
             ctx->prealloc_size_split_k = split_k_size;
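The recurring change in the remaining hunks is the padding of the quantized `src1` staging buffer: the round-up now happens where `y_sz_upd` is computed, and the granularity goes from 128 to 144 bytes. 144 is exactly four Q8_1 blocks (32 int8 quants plus two fp16 scales = 36 bytes each); my reading, not stated in the diff, is that the Vulkan quantize shader writes q8_1 data in packed four-block (`block_q8_1_x4`-style) groups, so the buffer has to cover whole 144-byte groups. A compile-time check of the arithmetic:

```cpp
// Why 144: sizeof(block_q8_1) = 32 + 2*2 = 36 bytes, and 4 * 36 = 144.
// The four-block packing is an assumption inferred from the constant.
#include <cstdint>

constexpr uint64_t pad_to_q8_1_x4(uint64_t bytes) {
    return (bytes + 144 - 1) / 144 * 144;  // CEIL_DIV(bytes, 144) * 144
}

static_assert(pad_to_q8_1_x4(1)   == 144, "a partial group rounds up to one group");
static_assert(pad_to_q8_1_x4(144) == 144, "an exact fit is unchanged");
// The old 128-byte round-up could undersize the buffer relative to that:
static_assert((160 + 128 - 1) / 128 * 128 == 256, "128-byte padding yields 256");
static_assert(pad_to_q8_1_x4(160) == 288,         "but whole q8_1 groups need 288");
```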
@@ -5756,7 +5761,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
     } else if (quantize_y) {
         d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -5809,10 +5814,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
     }
 
+    uint32_t y_sz_total = y_sz * ne12 * ne13;
+    if (quantize_y) {
+        y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+    }
+
     // compute
     ggml_vk_matmul(
         ctx, subctx, pipeline,
-        { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
+        { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
         { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
         ne01, ne11, ne10,
         ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
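Note the same padded size is now computed at three sites in this function: the dryrun prealloc sizing, the `GGML_ASSERT`, and the subbuffer extent handed to `ggml_vk_matmul`. The old code rounded to 128 in one spot and asserted a different formula in another, which is exactly the kind of drift this commit fixes. A hypothetical helper (not in the tree) that would keep the sites in lockstep:

```cpp
// Hypothetical helper, not in ggml-vulkan.cpp: one place to compute the
// padded quantized-y size. CEIL_DIV mirrors ggml's macro of the same name.
#include <cstdint>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

static uint64_t ggml_vk_padded_y_size(uint64_t y_sz, uint64_t ne12, uint64_t ne13,
                                      bool quantize_y) {
    uint64_t total = y_sz * ne12 * ne13;
    if (quantize_y) {
        total = CEIL_DIV(total, 144) * 144;  // whole 144-byte q8_1 groups
    }
    return total;
}
```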
@@ -5930,7 +5940,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        if (quantize_y) {
+            y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+        }
         if (
             (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
             (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
@@ -5940,7 +5953,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ctx->prealloc_size_x = x_sz_upd;
         }
         if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
+            ctx->prealloc_size_y = y_sz_upd;
         }
 
         // Request descriptor sets
@@ -5985,7 +5998,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         d_Y = ctx->prealloc_y;
     } else if (quantize_y) {
         d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 128) * 128);
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -6043,14 +6056,20 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         groups_x = CEIL_DIV(groups_x, groups_z);
     }
 
+    // TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time
+    uint32_t y_sz_total = y_sz * ne12 * ne13;
+    if (quantize_y) {
+        y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+    }
+
     // compute
     const vk_mat_vec_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
         stride_batch_x, stride_batch_y, stride_batch_d,
         (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
     };
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-        { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
+        { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz_total }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
         pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
 
     if (x_non_contig) {
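One detail visible in the surrounding context of the last hunk: the dispatch spreads matrix rows over the x dimension and batches over y, and when the row count would exceed the device's x-dimension workgroup limit, rows are split across z via `groups_x = CEIL_DIV(groups_x, groups_z)`. A hedged sketch of that splitting idea (the split factor and limit handling are illustrative, not copied from the source):

```cpp
// Illustrative sketch of the groups_x/groups_z split seen above: shrink the
// x dimension by spreading rows across z when x would exceed the device cap.
#include <cstdint>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

struct dispatch_dims { uint32_t x; uint32_t z; };

static dispatch_dims split_rows(uint32_t rows, uint32_t max_groups_x) {
    dispatch_dims d = { rows, 1 };
    if (d.x > max_groups_x) {
        d.z = 64;                  // illustrative split factor
        d.x = CEIL_DIV(d.x, d.z);  // each z slice covers a chunk of the rows
    }
    return d;
}
```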