Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit c40e9df

Browse files
committed
ggml : add initial BF16 support
1 parent c917b67 commit c40e9df

File tree

3 files changed

+105
-24
lines changed

3 files changed

+105
-24
lines changed

common/common.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -2160,6 +2160,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
21602160
if (s == "f32") {
21612161
return GGML_TYPE_F32;
21622162
}
2163+
if (s == "bf16") {
2164+
return GGML_TYPE_BF16;
2165+
}
21632166
if (s == "f16") {
21642167
return GGML_TYPE_F16;
21652168
}

ggml/src/ggml-metal.m

+62-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
5959
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
6060
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
61+
GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
6162
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
6263
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
6364
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
@@ -83,6 +84,10 @@
8384
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
8485
GGML_METAL_KERNEL_TYPE_NORM,
8586
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
87+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
88+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
89+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
90+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
8691
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
8792
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
8893
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -131,6 +136,7 @@
131136
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
132137
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
133138
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
139+
GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
134140
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
135141
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
136142
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
@@ -194,7 +200,9 @@
194200
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
195201
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
196202
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
203+
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
197204
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
205+
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
198206
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
199207
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
200208
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -514,6 +522,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
514522
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
515523
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
516524
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
525+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true);
517526
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
518527
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
519528
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
@@ -539,6 +548,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
539548
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
540549
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
541550
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
551+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, ctx->support_simdgroup_reduction);
552+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, ctx->support_simdgroup_reduction);
553+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, ctx->support_simdgroup_reduction);
554+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, ctx->support_simdgroup_reduction);
542555
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
543556
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
544557
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
@@ -587,6 +600,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
587600
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
588601
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
589602
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
603+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm);
590604
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
591605
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
592606
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
@@ -649,8 +663,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
649663
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
650664
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
651665
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
666+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true);
652667
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
653668
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
669+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
654670
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
655671
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
656672
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -736,8 +752,13 @@ static void ggml_metal_free(struct ggml_metal_context * ctx) {
736752

737753
static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
738754
for (size_t i = 0, n = 3; i < n; ++i) {
739-
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
740-
return false;
755+
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16 &&
756+
op->op != GGML_OP_GET_ROWS &&
757+
op->op != GGML_OP_MUL_MAT &&
758+
op->op != GGML_OP_VIEW &&
759+
op->op != GGML_OP_CPY) {
760+
printf("op = %s, src[%zu] = %s\n", ggml_op_name(op->op), i, ggml_type_name(op->src[i]->type));
761+
GGML_ASSERT(false);
741762
}
742763
}
743764

@@ -811,6 +832,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
811832
case GGML_TYPE_F32:
812833
switch (op->type) {
813834
case GGML_TYPE_F32:
835+
case GGML_TYPE_BF16:
814836
case GGML_TYPE_F16:
815837
case GGML_TYPE_Q8_0:
816838
case GGML_TYPE_Q4_0:
@@ -830,6 +852,14 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
830852
default:
831853
return false;
832854
}
855+
case GGML_TYPE_BF16:
856+
switch (op->type) {
857+
case GGML_TYPE_F32:
858+
case GGML_TYPE_F16:
859+
return true;
860+
default:
861+
return false;
862+
}
833863
default:
834864
return false;
835865
};
@@ -1581,6 +1611,7 @@ static enum ggml_status ggml_metal_graph_compute(
15811611
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
15821612
switch (src0->type) {
15831613
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1614+
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
15841615
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
15851616
default: break;
15861617
}
@@ -1589,6 +1620,7 @@ static enum ggml_status ggml_metal_graph_compute(
15891620

15901621
switch (src0->type) {
15911622
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1623+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
15921624
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
15931625
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
15941626
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
@@ -1665,6 +1697,25 @@ static enum ggml_status ggml_metal_graph_compute(
16651697
nrows = 4;
16661698
}
16671699
} break;
1700+
case GGML_TYPE_BF16:
1701+
{
1702+
nth0 = 32;
1703+
nth1 = 1;
1704+
if (src1t == GGML_TYPE_F32) {
1705+
if (ne11 * ne12 < 4) {
1706+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
1707+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1708+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
1709+
nrows = ne11;
1710+
} else {
1711+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
1712+
nrows = 4;
1713+
}
1714+
} else {
1715+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
1716+
nrows = 4;
1717+
}
1718+
} break;
16681719
case GGML_TYPE_Q4_0:
16691720
{
16701721
nth0 = 8;
@@ -2161,6 +2212,7 @@ static enum ggml_status ggml_metal_graph_compute(
21612212

21622213
switch (src0->type) {
21632214
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
2215+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
21642216
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
21652217
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
21662218
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
@@ -2776,6 +2828,7 @@ static enum ggml_status ggml_metal_graph_compute(
27762828

27772829
switch (dstt) {
27782830
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2831+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
27792832
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
27802833
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
27812834
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
@@ -2794,6 +2847,13 @@ static enum ggml_status ggml_metal_graph_compute(
27942847
default: GGML_ASSERT(false && "not implemented");
27952848
};
27962849
} break;
2850+
case GGML_TYPE_BF16:
2851+
{
2852+
switch (dstt) {
2853+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
2854+
default: GGML_ASSERT(false && "not implemented");
2855+
};
2856+
} break;
27972857
default: GGML_ASSERT(false && "not implemented");
27982858
}
27992859

0 commit comments

Comments
 (0)