58
58
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
59
59
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
60
60
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
61
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
61
62
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
62
63
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
63
64
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
83
84
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
84
85
GGML_METAL_KERNEL_TYPE_NORM,
85
86
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
87
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
88
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
89
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
90
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
86
91
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
87
92
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
88
93
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
131
136
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
132
137
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
133
138
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
139
+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
134
140
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
135
141
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
136
142
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
194
200
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
195
201
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
196
202
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
203
+ GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
197
204
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
205
+ GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
198
206
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
199
207
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
200
208
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -514,6 +522,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
514
522
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
515
523
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
516
524
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
525
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true);
517
526
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
518
527
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
519
528
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
@@ -539,6 +548,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
539
548
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
540
549
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
541
550
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
551
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, ctx->support_simdgroup_reduction);
552
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, ctx->support_simdgroup_reduction);
553
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, ctx->support_simdgroup_reduction);
554
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, ctx->support_simdgroup_reduction);
542
555
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
543
556
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
544
557
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
@@ -587,6 +600,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
587
600
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
588
601
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
589
602
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
603
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm);
590
604
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
591
605
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
592
606
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
@@ -649,8 +663,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
649
663
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
650
664
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
651
665
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
666
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true);
652
667
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
653
668
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
669
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
654
670
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
655
671
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
656
672
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -736,8 +752,13 @@ static void ggml_metal_free(struct ggml_metal_context * ctx) {
736
752
737
753
static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
738
754
// BF16 sources are only accepted for the ops listed below; any other op that
// reads a BF16 source is reported as unsupported. This is a capability query,
// so it must answer `false` rather than abort the process.
for (size_t i = 0, n = 3; i < n; ++i) {
    const struct ggml_tensor * src = op->src[i];

    if (src == NULL || src->type != GGML_TYPE_BF16) {
        continue;
    }

    switch (op->op) {
        case GGML_OP_GET_ROWS:
        case GGML_OP_MUL_MAT:
        case GGML_OP_VIEW:
        case GGML_OP_CPY:
            break;
        default:
            return false;
    }
}
743
764
@@ -811,6 +832,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
811
832
case GGML_TYPE_F32:
812
833
switch (op->type) {
813
834
case GGML_TYPE_F32:
835
+ case GGML_TYPE_BF16:
814
836
case GGML_TYPE_F16:
815
837
case GGML_TYPE_Q8_0:
816
838
case GGML_TYPE_Q4_0:
@@ -830,6 +852,14 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
830
852
default:
831
853
return false;
832
854
}
855
+ case GGML_TYPE_BF16:
856
+ switch (op->type) {
857
+ case GGML_TYPE_F32:
858
+ case GGML_TYPE_F16:
859
+ return true;
860
+ default:
861
+ return false;
862
+ }
833
863
default:
834
864
return false;
835
865
};
@@ -1581,6 +1611,7 @@ static enum ggml_status ggml_metal_graph_compute(
1581
1611
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1582
1612
switch (src0->type) {
1583
1613
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1614
+ case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1584
1615
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1585
1616
default: break;
1586
1617
}
@@ -1589,6 +1620,7 @@ static enum ggml_status ggml_metal_graph_compute(
1589
1620
1590
1621
switch (src0->type) {
1591
1622
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1623
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
1592
1624
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1593
1625
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1594
1626
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
@@ -1665,6 +1697,25 @@ static enum ggml_status ggml_metal_graph_compute(
1665
1697
nrows = 4;
1666
1698
}
1667
1699
} break;
1700
+ case GGML_TYPE_BF16:
1701
+ {
1702
+ nth0 = 32;
1703
+ nth1 = 1;
1704
+ if (src1t == GGML_TYPE_F32) {
1705
+ if (ne11 * ne12 < 4) {
1706
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
1707
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1708
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
1709
+ nrows = ne11;
1710
+ } else {
1711
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
1712
+ nrows = 4;
1713
+ }
1714
+ } else {
1715
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
1716
+ nrows = 4;
1717
+ }
1718
+ } break;
1668
1719
case GGML_TYPE_Q4_0:
1669
1720
{
1670
1721
nth0 = 8;
@@ -2161,6 +2212,7 @@ static enum ggml_status ggml_metal_graph_compute(
2161
2212
2162
2213
switch (src0->type) {
2163
2214
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
2215
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
2164
2216
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
2165
2217
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
2166
2218
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
@@ -2776,6 +2828,7 @@ static enum ggml_status ggml_metal_graph_compute(
2776
2828
2777
2829
switch (dstt) {
2778
2830
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2831
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
2779
2832
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2780
2833
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
2781
2834
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
@@ -2794,6 +2847,13 @@ static enum ggml_status ggml_metal_graph_compute(
2794
2847
default: GGML_ASSERT(false && "not implemented");
2795
2848
};
2796
2849
} break;
2850
+ case GGML_TYPE_BF16:
2851
+ {
2852
+ switch (dstt) {
2853
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
2854
+ default: GGML_ASSERT(false && "not implemented");
2855
+ };
2856
+ } break;
2797
2857
default: GGML_ASSERT(false && "not implemented");
2798
2858
}
2799
2859
0 commit comments