@@ -2923,6 +2923,15 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     return false;
 }

+inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
 static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -2942,13 +2951,14 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }

-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-
-    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
-    int64_t min_compute_capability = INT_MAX;
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
+                              ggml_tensor * dst) {
+    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
+    int64_t    min_compute_capability = INT_MAX;

     if (split) {
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+        ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+            (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
             // skip devices that are not going to do any work:
@@ -2961,7 +2971,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             }
         }
     } else {
-        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
+        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }

     // check data types and tensor shapes for custom matrix multiplication kernels:
@@ -2984,8 +2994,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
 #endif // SYCL_USE_XMX

     // mmvq path is faster in the CUDA backend.
-    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+        // Dispatch becomes obscure with the reorder: when the reorder optimization is
+        // enabled, MMVQ takes precedence over DMMV, so the current if-else implementation
+        // requires disabling DMMV if both conditions are met.
+        || (ctx.opt_feature.reorder && ggml_sycl_supports_reorder_mmvq(src0->type))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+    }

     if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // TODO: Refactor and cleanup of mul mat dispatching.
@@ -3004,14 +3019,17 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // KQ + KQV multi-batch
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-        // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
+        constexpr bool convert_src1_to_q8_1 = false;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
     } else if (use_mul_mat_vec_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
     } else if (use_mul_mat_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
     } else {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+        constexpr bool convert_src1_to_q8_1 = false;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
     }
 }

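For reviewers who want to see the effect of the new condition in isolation, below is a minimal, self-contained C++ sketch of the precedence rule the hunk above encodes. This is not the backend code: the types, flags, and helper names here are simplified stand-ins for the real ggml SYCL identifiers.

// Standalone sketch (assumed, illustrative names) of the MMVQ-over-DMMV precedence.
#include <cstdio>

enum class backend_kind { oneapi_cuda, other };

struct dispatch_flags {
    bool use_dequantize_mul_mat_vec;  // DMMV is a viable kernel for this case
    bool use_mul_mat_vec_q;           // MMVQ is a viable kernel for this case
};

// Stand-in for ggml_sycl_supports_reorder_mmvq(): in the patch, only Q4_0
// participates in the reordered MMVQ path.
static bool supports_reorder_mmvq(bool is_q4_0) {
    return is_q4_0;
}

// Mirrors the new condition: MMVQ wins over DMMV either on the CUDA backend or
// when the reorder optimization is enabled for a supported quantization type.
static dispatch_flags apply_precedence(dispatch_flags f, backend_kind be,
                                       bool reorder_enabled, bool is_q4_0) {
    if (be == backend_kind::oneapi_cuda ||
        (reorder_enabled && supports_reorder_mmvq(is_q4_0))) {
        f.use_dequantize_mul_mat_vec = f.use_dequantize_mul_mat_vec && !f.use_mul_mat_vec_q;
    }
    return f;
}

int main() {
    dispatch_flags f{true, true};  // both kernels are candidates for this shape/type
    dispatch_flags r = apply_precedence(f, backend_kind::other,
                                        /*reorder_enabled=*/true, /*is_q4_0=*/true);
    // With reorder enabled on Q4_0, DMMV is disabled and MMVQ is taken: DMMV=0 MMVQ=1
    std::printf("DMMV=%d MMVQ=%d\n", r.use_dequantize_mul_mat_vec, r.use_mul_mat_vec_q);
    return 0;
}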