@@ -2887,6 +2887,15 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     return false;
 }
 
+inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
 static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -2906,13 +2915,14 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-
-    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
-    int64_t min_compute_capability = INT_MAX;
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
+                              ggml_tensor * dst) {
+    const bool split                  = ggml_backend_buffer_is_sycl_split(src0->buffer);
+    int64_t    min_compute_capability = INT_MAX;
 
     if (split) {
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+        ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+            (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
             // skip devices that are not going to do any work:
@@ -2925,7 +2935,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             }
         }
     } else {
-        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
+        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }
 
     // check data types and tensor shapes for custom matrix multiplication kernels:
@@ -2948,8 +2958,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
 #endif // SYCL_USE_XMX
 
     // mmvq path is faster in the CUDA backend.
-    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+        // Dispatch becomes obscure with the reorder: when the reorder optimization is
+        // enabled, MMVQ takes precedence over DMMV, so the current if-else implementation
+        // requires disabling DMMV if both conditions are met.
+        || (ctx.opt_feature.reorder && ggml_sycl_supports_reorder_mmvq(src0->type))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+    }
 
     if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // TODO: Refactor and cleanup of mul mat dispatching.
@@ -2968,14 +2983,17 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-        // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
+        constexpr bool convert_src1_to_q8_1 = false;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
     } else if (use_mul_mat_vec_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
     } else if (use_mul_mat_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
     } else {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+        constexpr bool convert_src1_to_q8_1 = false;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
     }
 }
 
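The precedence rule described in the new comment is easier to see outside the diff. Below is a minimal, self-contained sketch of that dispatch logic only; the names `pick_kernel`, `supports_reorder_mmvq`, `on_cuda_backend` and `reorder_enabled` are illustrative stand-ins, not the backend's API.

```cpp
#include <cstdio>

enum class qtype { q4_0, q4_k, f16 };

// Mirrors ggml_sycl_supports_reorder_mmvq: only Q4_0 is reorder-capable for now.
static bool supports_reorder_mmvq(qtype t) {
    return t == qtype::q4_0;
}

static const char * pick_kernel(qtype t, bool on_cuda_backend, bool reorder_enabled,
                                bool use_dmmv, bool use_mmvq) {
    // Same rule as the patched condition: MMVQ wins over DMMV either on the CUDA
    // backend or when the reorder optimization is enabled for a supported type,
    // because DMMV sits earlier in the if-else chain below.
    if (on_cuda_backend || (reorder_enabled && supports_reorder_mmvq(t))) {
        use_dmmv = use_dmmv && !use_mmvq;
    }
    if (use_dmmv) { return "dequantize_mul_mat_vec"; }
    if (use_mmvq) { return "mul_mat_vec_q"; }
    return "mul_mat_sycl";
}

int main() {
    // Q4_0 with reorder enabled: MMVQ takes precedence even off the CUDA backend.
    std::printf("%s\n", pick_kernel(qtype::q4_0, /*cuda*/ false, /*reorder*/ true, true, true));
    // Reorder disabled on a non-CUDA backend: DMMV keeps priority.
    std::printf("%s\n", pick_kernel(qtype::q4_0, false, false, true, true));
    return 0;
}
```

The sketch only models the boolean precedence; shape and split checks that feed `use_dmmv`/`use_mmvq` in the real `ggml_sycl_mul_mat` are omitted.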