@@ -341,7 +341,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0 && !g_ggml_sycl_disable_optimize) {
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); // used to release it when destroy ctx.
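For context, the `ggml_tensor_extra_gpu` attached here is what later lets the backend reorder a weight tensor once and remember that it already did so (see the `optimized_feature.reorder` flag further down in this diff). Below is a minimal standalone sketch of that lifecycle; the stub types and function names are invented for illustration and only mirror the fields visible in this patch, they are not the real backend structs.

```cpp
#include <vector>

struct optimized_feature_flags { bool reorder = false; };
struct tensor_extra            { optimized_feature_flags optimized_feature; };

struct tensor_stub {
    int            type;             // e.g. GGML_TYPE_Q4_0 or GGML_TYPE_Q4_K
    tensor_extra * extra = nullptr;  // attached by init_tensor for eligible types
};

// init_tensor: give eligible quantized weights somewhere to record per-tensor state.
void init_tensor_stub(tensor_stub & t, std::vector<tensor_extra *> & owned,
                      bool type_supports_reorder, bool optimize_enabled) {
    if (type_supports_reorder && optimize_enabled) {
        auto * extra = new tensor_extra{};
        t.extra = extra;
        owned.push_back(extra);  // released when the buffer context is destroyed
    }
}

// opt_for_reorder-style use: reorder the weight the first time it is needed, then mark it.
void maybe_reorder_stub(tensor_stub & t) {
    if (t.extra == nullptr || t.extra->optimized_feature.reorder) {
        return;  // not eligible, or already reordered
    }
    // ... reorder_qw(src0, stream) would run here in the real backend ...
    t.extra->optimized_feature.reorder = true;
}
```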
@@ -2840,6 +2840,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
@@ -2858,6 +2859,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
@@ -2883,16 +2885,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void reorder_qw(char *data_device, const int ncols, const int nrows,
-                       size_t size, size_t offset, dpct::queue_ptr stream) {
-    auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
             .wait()));
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
+    auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
     stream->parallel_for(
@@ -2906,18 +2908,59 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
             *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
         }
         *(d_ptr + ib) = x[ib].d;
-    });
+    }).wait_and_throw();
+
+    sycl::free(tmp_buf, *stream);
+}
+
+static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr = data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) {
+        const block_q4_K * x = (const block_q4_K *) tmp_buf;
+        const int ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    }).wait_and_throw();
 
     sycl::free(tmp_buf, *stream);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
-    char * data_device = (char *) src0->data;
+    uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_ABORT("reorder_qw() called with unsupported type");
+            break;
+    }
 }
 
 static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
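The new `reorder_qw_q4_k` rewrites an array of `block_q4_K` super-blocks (each holding a `half2` d/dmin pair, `K_SCALE_SIZE` = 12 packed scale bytes, and `QK_K / 2` = 128 bytes of 4-bit quants) into three contiguous planes: all `qs`, then all `scales`, then all `dm`. The sketch below shows the same transform on the host only; it assumes the standard ggml Q4_K constants and uses raw 16-bit words in place of `sycl::half`, so it compiles without SYCL or the ggml headers. It is an illustration of the target layout, not the backend code, which performs the copy inside the SYCL `parallel_for` shown above.

```cpp
#include <cstdint>
#include <cstring>
#include <cstddef>
#include <vector>

constexpr int QK_K         = 256;  // values per Q4_K super-block (assumed ggml default)
constexpr int K_SCALE_SIZE = 12;   // packed 6-bit scales/mins per super-block

struct half2_bits { uint16_t d, dmin; };  // 16-bit stand-in for sycl::half2 storage
struct block_q4_K_bits {
    half2_bits dm;                    // super-block scale and min
    uint8_t    scales[K_SCALE_SIZE];  // quantized scales and mins
    uint8_t    qs[QK_K / 2];          // 4-bit quants, two per byte
};

// Host-side equivalent of reorder_qw_q4_k: rewrite an array of blocks into
// three contiguous planes (all qs, then all scales, then all dm) in place.
void reorder_q4_k_host(uint8_t * data, std::size_t size) {
    const std::size_t nblocks = size / sizeof(block_q4_K_bits);
    std::vector<uint8_t> tmp(data, data + size);                       // keep the original AoS copy
    const auto * x = reinterpret_cast<const block_q4_K_bits *>(tmp.data());

    uint8_t *    qs_ptr     = data;
    uint8_t *    scales_ptr = qs_ptr + (QK_K / 2) * nblocks;
    half2_bits * dm_ptr     = reinterpret_cast<half2_bits *>(scales_ptr + K_SCALE_SIZE * nblocks);

    for (std::size_t ib = 0; ib < nblocks; ++ib) {
        std::memcpy(qs_ptr + ib * (QK_K / 2), x[ib].qs, QK_K / 2);
        std::memcpy(scales_ptr + ib * K_SCALE_SIZE, x[ib].scales, K_SCALE_SIZE);
        dm_ptr[ib] = x[ib].dm;
    }
}
```

With this layout a mul-mat-vec kernel can stream the 4-bit quants of consecutive blocks contiguously and fetch scales and d/dmin from their own planes, which is the layout the reorder-aware Q4_K kernels elsewhere in this patch (not shown in these hunks) are written against.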
@@ -2960,8 +3003,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
     extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;
 
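The two predicates above simply centralize the type and shape checks that `ggml_sycl_mul_mat` previously spelled out inline (next hunk). Purely as an illustration of how such predicates feed a kernel choice, here is a simplified, hypothetical dispatcher; the enum and function name are invented for this sketch, and the real `ggml_sycl_mul_mat` keeps the plain booleans shown in the following hunk together with additional conditions.

```cpp
// Hypothetical sketch only: one plausible priority order for the vector kernels.
enum class mul_mat_path {
    dequantize_mul_mat_vec,  // dmmv: single-column src1 on a dmmv-supported type
    mul_mat_vec_q,           // mmvq: quantized src0, batch <= MMVQ_MAX_BATCH_SIZE
    fallback,                // mmq or the generic dequantize + GEMM path
};

static mul_mat_path pick_mul_mat_path(bool use_dmmv, bool use_mmvq) {
    if (use_dmmv && !use_mmvq) {
        return mul_mat_path::dequantize_mul_mat_vec;
    }
    if (use_mmvq) {
        return mul_mat_path::mul_mat_vec_q;
    }
    return mul_mat_path::fallback;
}
```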
@@ -2983,14 +3036,11 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }
 
+    // TODO: make these into functions, add mmvq check for reorder
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;