@@ -2694,138 +2694,132 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }

-static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
-                                   const sycl::half *src1_as_f16, char *dst,
-                                   const void **ptrs_src, void **ptrs_dst,
-                                   int64_t ne12, int64_t ne13, int64_t ne23,
-                                   size_t nb02, size_t nb03, size_t nb12,
-                                   size_t nb13, size_t nbd2, size_t nbd3,
-                                   int64_t r2, int64_t r3,
-                                   const sycl::nd_item<3> &item_ct1) {
-    int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                  item_ct1.get_local_id(2);
-    int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
-                  item_ct1.get_local_id(1);
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+                                   const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
+                                   size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
+                                   int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
+    const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
+    const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);

     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }

-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
+
+    const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
+    const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
+    uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);

-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (char *) dst + i12*nbd2 + i13*nbd3;
+    ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
+    ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
+    ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
 }

-static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
-                                           const ggml_tensor *src0,
-                                           const ggml_tensor *src1,
-                                           ggml_tensor *dst) try {
+static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
+                                           const ggml_tensor * src1, ggml_tensor * dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);

     GGML_TENSOR_BINARY_OP_LOCALS

+    // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
+    // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
+    GGML_ASSERT(ggml_is_contiguous(dst));

     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();;
+    queue_ptr queue = ctx.stream();

-    void * src0_ddq = src0->data;
-    sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf = (float *) dst->data;
+    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });

-    // convert src1 to fp16
+    const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
+    float * dst_ddf = static_cast<float *>(dst->data);
+
+    const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src1 = ggml_type_size(src1->type);
+    GGML_ASSERT(nb10 == type_size_src1);
+
+    // SRC1 strides
+    int64_t s11 = nb11 / type_size_src1;
+    int64_t s12 = nb12 / type_size_src1;
+    int64_t s13 = nb13 / type_size_src1;
     ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
+
+    // convert src1 to fp16
     if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_sycl != nullptr);
-        to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+        src1_f16 = src1_f16_alloc.get();
+        s11 = ne10;
+        s12 = ne11 * s11;
+        s13 = ne12 * s12;
     }
-    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
-                                                       : src1_f16_alloc.get();

-    char * dst_t;
+    ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
+    char * dst_t = reinterpret_cast<char *>(dst_ddf);

-    dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
-    dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;

     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];

     const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
+    const float beta_f32  = 0.0f;

     const void * alpha = &alpha_f32;
     const void * beta = &beta_f32;

-    dst_t = (char *) dst_ddf;
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);

     // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;

     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-            (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
-            cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
+        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+                                                    oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                                                    src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
+                                                    mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
     } else {
-        const int ne23 = ne12*ne13;
+        const int ne23 = ne12 * ne13;

-        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
-        ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
+        ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
         ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);

         sycl::range<3> block_dims(1, ne12, ne13);
-        /*
-        DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(main_stream->get_device(),
-                                         {sycl::aspect::fp16});
-
-            main_stream->submit([&](sycl::handler &cgh) {
-                const void **ptrs_src_get = ptrs_src.get();
-                void **ptrs_dst_get = ptrs_dst.get();
-                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
-                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
-                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
-                                 [=](sycl::nd_item<3> item_ct1) {
-                                     k_compute_batched_ptrs(
-                                         src0_as_f16, src1_f16,
-                                         dst_t, ptrs_src_get,
-                                         ptrs_dst_get, ne12, ne13, ne23,
-                                         nb02, nb03, nb12_scaled, nb13_scaled,
-                                         nbd2, nbd3, r2, r3, item_ct1);
-                                 });
+        queue->submit([&](sycl::handler & cgh) {
+            const void ** ptrs_src_get = ptrs_src.get();
+            void ** ptrs_dst_get = ptrs_dst.get();
+            size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+            size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+            cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                                       nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
             });
-        }
+        });
+
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+            *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
-            (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
+            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+            (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
     }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+} catch (const sycl::exception & exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }

 inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
@@ -2966,7 +2960,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // The kernel from the if path is faster for that specific case, but does not support all mul mats.
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
@@ -3873,9 +3867,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
-                if (!ggml_is_contiguous(b)) {
-                    return false;
-                }
                 ggml_type a_type = a->type;
                 if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
                     a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
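Note (not part of the patch): the broadcast indexing that `k_compute_batched_ptrs` sets up for the strided `gemm_batch` call can be illustrated with a small host-side sketch. The shapes and byte strides (`nb02`, `nb12`, `nbd2`, ...) below are made-up example values, not taken from the diff; only the index arithmetic mirrors the kernel.

```cpp
// Plain C++ illustration of the pointer-array setup performed on the device by
// k_compute_batched_ptrs. Every size and stride here is an illustrative assumption.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Example batch shape: src1/dst have ne12 x ne13 batches, src0 has ne02 x ne03.
    const int64_t ne02 = 2, ne03 = 1;
    const int64_t ne12 = 4, ne13 = 1;
    const int64_t ne23 = ne12 * ne13;                   // total number of GEMMs
    const int64_t r2 = ne12 / ne02, r3 = ne13 / ne03;   // broadcast factors

    // Byte strides across dims 2 and 3 (illustrative values only).
    const size_t nb02 = 1024, nb03 = nb02 * ne02;
    const size_t nb12 = 2048, nb13 = nb12 * ne12;
    const size_t nbd2 = 512,  nbd3 = nbd2 * ne12;

    std::vector<uint8_t> src0(nb03 * ne03), src1(nb13 * ne13), dst(nbd3 * ne13);
    std::vector<const void *> ptrs_src(2 * ne23);
    std::vector<void *>       ptrs_dst(ne23);

    for (int64_t i13 = 0; i13 < ne13; ++i13) {
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            // src0 is broadcast: several (i12, i13) pairs map to the same (i02, i03).
            const int64_t i03 = i13 / r3;
            const int64_t i02 = i12 / r2;

            ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0.data() + i02 * nb02 + i03 * nb03;
            ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1.data() + i12 * nb12 + i13 * nb13;
            ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst.data()  + i12 * nbd2 + i13 * nbd3;
        }
    }

    printf("prepared %lld pointer triples for the batched GEMM\n", (long long) ne23);
    return 0;
}
```

Each (i12, i13) pair gets its own GEMM; src0 matrices are reused across r2 and r3 consecutive batches, which is what the integer divisions implement.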