Commit 685e02b

sycl: reordered Q4_K MMVQ

1 parent d61dda3

7 files changed: +273, -82 lines changed


ggml/src/ggml-sycl/convert.cpp (+23, -1)
@@ -183,6 +183,24 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
     }
 }
 
+template <typename dst_t>
+static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+    const int64_t nb = k / QK_K;
+    const size_t local_size = 32;
+    const size_t global_size = nb * local_size;
+
+    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+    stream->submit([&](sycl::handler & cgh) {
+        sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
+
+        cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
+                         [=](sycl::nd_item<1> item_ct1) {
+                             dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
+                         });
+    });
+}
+
 template <typename dst_t>
 static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
@@ -493,7 +511,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_sycl;
         case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_sycl;
+            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                return dequantize_row_q4_K_sycl_reorder;
+            } else {
+                return dequantize_row_q4_K_sycl;
+            }
         case GGML_TYPE_Q5_K:
             return dequantize_row_q5_K_sycl;
         case GGML_TYPE_Q6_K:
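Note: the reorder kernel reads src0 through flat byte offsets rather than `block_q4_K` records, so it only works after the weight buffer has been rewritten by `reorder_qw_q4_k` (see ggml-sycl.cpp below). A minimal sketch of the layout it indexes, assuming ggml's usual K-quant constants (`QK_K`, `K_SCALE_SIZE`); the struct and helper names here are illustrative, not part of the commit:

```cpp
#include <sycl/sycl.hpp>

// Illustrative only: byte offsets of super-block i inside a reordered Q4_K
// buffer of nb blocks, laid out as [ all qs | all scales | all dm ].
struct q4_k_reorder_offsets {
    size_t qs;      // 4-bit quants: QK_K / 2 bytes per block
    size_t scales;  // packed scales/mins: K_SCALE_SIZE (12) bytes per block
    size_t dm;      // super-block scale/min: one half2 per block
};

inline q4_k_reorder_offsets q4_k_block_offsets(size_t i, size_t nb) {
    const size_t qs_bytes     = nb * (QK_K / 2);
    const size_t scales_bytes = nb * K_SCALE_SIZE;
    return { i * (QK_K / 2),
             qs_bytes + i * K_SCALE_SIZE,
             qs_bytes + scales_bytes + i * sizeof(sycl::half2) };
}
```

These are the same `qs_offset` / `scales_offset` / `dm_offset` expressions that `dequantize_block_q4_K_reorder` computes in dequantize.hpp.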

ggml/src/ggml-sycl/dequantize.hpp (+59, -21)
@@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m)
 }
 #endif
 
+template <typename dst_t>
+inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall,
+                                   const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) {
+    const int is = 2 * il;
+    constexpr int n = 4;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, scales_local, sc, m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+
+    get_scale_min_k4(is + 1, scales_local, sc, m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir);
+    for (int l = 0; l < n; ++l) {
+        y[l + 0]  = d1 * (q_vec[l] & 0xF) - m1;
+        y[l + 32] = d2 * (q_vec[l] >> 4) - m2;
+    }
+}
+
 template<typename dst_t>
 static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
@@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
     const int64_t i = item_ct1.get_group(2);
 
 #if QK_K == 256
-    // assume 32 threads
     const int64_t tid = item_ct1.get_local_id(2);
-    const int64_t il = tid/8;
-    const int64_t ir = tid%8;
-    const int64_t is = 2*il;
-    const int64_t n = 4;
+    const int64_t il = tid / 8;
+    const int64_t ir = tid % 8;
 
-    dst_t * y = yy + i*QK_K + 64*il + n*ir;
+    dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
 
     const sycl::half2 dm = x[i].dm;
     const float dall = dm[0];
     const float dmin = dm[1];
 
-    if (tid < 12)
+    if (tid < 12) {
         scales_local[tid] = x[i].scales[tid];
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, scales_local, sc, m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, scales_local, sc, m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
-
-    sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
-    for (int l = 0; l < n; ++l) {
-        y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
-        y[l +32] = d2 * (q_vec[l] >> 4) - m2;
     }
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+    dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir);
 #else
     const int64_t tid = item_ct1.get_local_id(2);
     const uint8_t * q = x[i].qs;
@@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
 #endif
 }
 
+template <typename dst_t>
+static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local,
+                                          const sycl::nd_item<1> & item_ct1, int64_t nb) {
+    const int64_t i   = item_ct1.get_group(0);     // block index
+    const int64_t tid = item_ct1.get_local_id(0);  // thread index within block
+    const int64_t il  = tid / 8;
+    const int64_t ir  = tid % 8;
+
+    dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
+
+    const uint8_t * base          = static_cast<const uint8_t *>(vx);
+    const size_t    qs_offset     = i * (QK_K / 2);
+    const size_t    scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE;
+    const size_t    dm_offset     = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2);
+
+    const uint8_t * qs_ptr     = base + qs_offset;
+    const uint8_t * scales_ptr = base + scales_offset;
+    ggml_half2      dm_values  = *reinterpret_cast<const ggml_half2 *>(base + dm_offset);
+
+    const float dall = dm_values.x();
+    const float dmin = dm_values.y();
+
+    if (tid < 12) {
+        scales_local[tid] = scales_ptr[tid];
+    }
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+    dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir);
+}
+
 template<typename dst_t>
 static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {
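Note: `dequantize_q4_K_common` factors the shared inner loop out of the classic and reordered kernels; only the pointer arithmetic feeding it differs. A hypothetical scalar reference (not from the commit) of the per-thread work, where each thread expands 4 packed bytes into 8 outputs:

```cpp
#include <cstdint>

// Hypothetical scalar form of one thread's work in dequantize_q4_K_common:
// (d1, m1) scale/offset the low nibbles, (d2, m2) the high nibbles, and the
// two nibble streams land 32 elements apart in the output slice.
inline void dequant_q4_k_scalar(float * y, const uint8_t * q,
                                float d1, float m1, float d2, float m2) {
    for (int l = 0; l < 4; ++l) {
        y[l + 0]  = d1 * (q[l] & 0xF) - m1;  // low nibble  -> y[0..3]
        y[l + 32] = d2 * (q[l] >> 4) - m2;   // high nibble -> y[32..35]
    }
}
```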

ggml/src/ggml-sycl/dmmv.cpp (+7, -1)
@@ -1129,7 +1129,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
             dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q4_K:
-            dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                // reorder is currently not supported for dmmv
+                GGML_ABORT("Unimplemented dequantize case for q4_k reorder");
+            } else {
+                dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            }
             break;
         case GGML_TYPE_Q5_K:
             dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);

ggml/src/ggml-sycl/ggml-sycl.cpp (+64, -14)
@@ -341,7 +341,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0) {
+    if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
@@ -2840,6 +2840,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_dequantize(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
@@ -2858,6 +2859,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
@@ -2883,16 +2885,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void reorder_qw(char *data_device, const int ncols, const int nrows,
-                       size_t size, size_t offset, dpct::queue_ptr stream) {
-    auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
             .wait()));
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
+    auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
     stream->parallel_for(
@@ -2911,13 +2913,54 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
     sycl::free(tmp_buf, *stream);
 }
 
+static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto * tmp_buf = sycl::malloc_device<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr     = data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr     = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) {
+        const block_q4_K * x = (const block_q4_K *) tmp_buf;
+        const int ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    });
+
+    sycl::free(tmp_buf, *stream);
+}
+
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
-    char*data_device = (char*)src0->data;
+    uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_ABORT("reorder_qw() called with unsupported type");
+            break;
+    }
 }
 
 static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
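Note: `reorder_qw_q4_k` rewrites the buffer in place (reading from a temporary device copy), turning the array-of-structs `block_q4_K[]` into three struct-of-arrays regions. The same transform as plain host code, assuming `block_q4_K` has ggml's standard `qs` / `scales` / `dm` fields; this reference writes to a separate destination for simplicity and is illustrative, not part of the commit:

```cpp
#include <cstring>

// Illustrative host-side equivalent of reorder_qw_q4_k:
// interleaved block_q4_K records -> contiguous [ qs | scales | dm ] regions.
static void reorder_q4_k_host(const block_q4_K * src, uint8_t * dst, int nblocks) {
    uint8_t * qs_out     = dst;
    uint8_t * scales_out = dst + (QK_K / 2) * nblocks;
    uint8_t * dm_out     = scales_out + (size_t) K_SCALE_SIZE * nblocks;

    for (int ib = 0; ib < nblocks; ++ib) {
        std::memcpy(qs_out + ib * (QK_K / 2),         src[ib].qs,     QK_K / 2);
        std::memcpy(scales_out + ib * K_SCALE_SIZE,   src[ib].scales, K_SCALE_SIZE);
        std::memcpy(dm_out + ib * sizeof(src[ib].dm), &src[ib].dm,    sizeof(src[ib].dm));
    }
}
```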
@@ -2943,8 +2986,18 @@ static void opt_for_reorder(
     }
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;

@@ -2966,14 +3019,11 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
         min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }
 
+    // TODO: make these into functions, add mmvq check for reorder
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
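Note: extracting `can_use_dequantize_mul_mat_vec` and `can_use_mul_mat_vec_q` makes the kernel-eligibility conditions reusable and testable in isolation; both can hold at once (a batch of one satisfies either), and the body of `ggml_sycl_mul_mat` still decides between them alongside the mmq and reorder checks. A minimal illustration (not from the commit):

```cpp
// Illustrative only: both predicates may be true when src1->ne[1] == 1;
// larger batches up to MMVQ_MAX_BATCH_SIZE keep only the mmvq path eligible.
static bool any_vector_kernel_eligible(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    return can_use_dequantize_mul_mat_vec(src0, src1, dst) || can_use_mul_mat_vec_q(src0, src1, dst);
}
```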

ggml/src/ggml-sycl/mmvq.cpp (+29, -2)
@@ -24,6 +24,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy,
     const int blocks_per_row = ncols / block_traits::qk;
     constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
     constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
+    const int nblocks = nrows * (ncols / block_traits::qk);
 
     static_assert(blocks_per_subgroup > 0);
     static_assert(block_elements_per_subgroup > 0);
@@ -44,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy,
             // x block quant index when casting the quants to int
             const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
 
-            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
+            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
         }
     }

@@ -738,6 +739,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
     }
 }
 
+static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
+                                               const int nrows, dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+
+    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
+    constexpr size_t num_subgroups = 16;
+    GGML_ASSERT(block_num_y % num_subgroups == 0);
+
+    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
+    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
+
+    stream->submit([&](sycl::handler & cgh) {
+        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
+                         [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                             mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
+                                                                                           nrows, nd_item);
+                         });
+    });
+}
+
+
 static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,
@@ -1034,7 +1056,12 @@ void ggml_sycl_op_mul_mat_vec_q(
             mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            } else {
+                mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            }
             break;
        case GGML_TYPE_Q5_K:
             mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
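Note: the launch shape mirrors the existing Q4_0 reorder path: `GGML_SYCL_MMV_Y` rows of subgroups per work-group, `WARP_SIZE` work-items per subgroup, 16 subgroups per work-group. A worked instance with assumed constants (`WARP_SIZE = 32`, `GGML_SYCL_MMV_Y = 1` are taken as the backend's defaults here, so treat them as assumptions):

```cpp
// Worked example of reorder_mul_mat_vec_q4_k_q8_1_sycl's launch geometry.
constexpr int warp_size     = 32;    // assumed WARP_SIZE
constexpr int mmv_y         = 1;     // assumed GGML_SYCL_MMV_Y
constexpr int nrows         = 4096;  // example row count
constexpr int block_num_y   = (nrows + mmv_y - 1) / mmv_y;  // ceil_div -> 4096 subgroups
constexpr int num_subgroups = 16;
static_assert(block_num_y % num_subgroups == 0, "rows must tile into work-groups");

constexpr int wg_items    = num_subgroups * warp_size;  // 512 work-items per work-group
constexpr int total_items = block_num_y * warp_size;    // 131072 work-items overall
constexpr int work_groups = total_items / wg_items;     // 256 work-groups
static_assert(wg_items == 512 && work_groups == 256, "geometry as expected");
```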

ggml/src/ggml-sycl/quants.hpp (+22)
@@ -56,6 +56,28 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
 };
 
+template <> struct block_q_t<GGML_TYPE_Q4_K> {
+    struct traits {
+        static constexpr uint32_t qk = QK_K;
+        static constexpr uint32_t qi = QI4_K;
+        static constexpr uint32_t qr = QR4_K;
+        static constexpr uint32_t vdr_mmvq = 2;
+    };
+
+    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+
+    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+        auto nblocks = (nrows * (ncols / traits::qk));
+        return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
+    }
+
+    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+
+    constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
+
+    constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
+};
+
 } // namespace ggml_sycl_reordered
 
 #endif // GGML_SYCL_QUANTS_HPP
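Note: `get_block_offset` and `get_d_offset` encode the reordered layout for the mmvq dot kernels. A worked instance, assuming ggml's standard K-quant constants (`QK_K = 256`, `QR4_K = 2`, `K_SCALE_SIZE = 12`, `sizeof(ggml_half2) = 4`):

```cpp
// Worked example for a 4096 x 4096 Q4_K tensor under assumed constants.
constexpr int qk = 256, qr = 2, scale_size = 12, half2_size = 4;
constexpr int nrows = 4096, ncols = 4096;
constexpr int nblocks = nrows * (ncols / qk);     // 65536 super-blocks

constexpr int block_offset_1 = 1 * (qk / qr);     // 128: start of block 1's qs bytes
constexpr int d_offset_1     = nblocks * (qk / 2) // skip all qs bytes (8388608)
                             + nblocks * scale_size // skip all packed scales (786432)
                             + 1 * half2_size;    // land on block 1's dm pair
static_assert(block_offset_1 == 128);
static_assert(d_offset_1 == 8388608 + 786432 + 4);
```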
