Skip to content

Commit 187451b

Browse files
committed
sycl : Implemented reorder Q4_0 mmvq
Signed-off-by: Alberto Cabrera <[email protected]>
1 parent 11d07e1 commit 187451b

File tree

6 files changed

+307
-120
lines changed

6 files changed

+307
-120
lines changed

ggml/src/ggml-sycl/backend.hpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,24 @@
1313
#ifndef GGML_SYCL_BACKEND_HPP
1414
#define GGML_SYCL_BACKEND_HPP
1515

16-
#include "concat.hpp"
1716
#include "common.hpp"
17+
#include "concat.hpp"
1818
#include "conv.hpp"
1919
#include "convert.hpp"
20+
#include "cpy.hpp"
2021
#include "dequantize.hpp"
2122
#include "dmmv.hpp"
23+
#include "element_wise.hpp"
24+
#include "gla.hpp"
25+
#include "im2col.hpp"
2226
#include "mmq.hpp"
2327
#include "mmvq.hpp"
24-
#include "rope.hpp"
2528
#include "norm.hpp"
29+
#include "outprod.hpp"
30+
#include "quants.hpp"
31+
#include "rope.hpp"
2632
#include "softmax.hpp"
2733
#include "tsembd.hpp"
28-
#include "im2col.hpp"
2934
#include "wkv.hpp"
30-
#include "outprod.hpp"
31-
#include "element_wise.hpp"
32-
#include "cpy.hpp"
33-
#include "gla.hpp"
3435

35-
#endif // GGML_SYCL_BACKEND_HPP
36+
#endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/common.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -776,4 +776,9 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
776776
}
777777

778778
bool gpu_has_xmx(sycl::device &dev);
779+
780+
constexpr size_t safe_div(const size_t m, const size_t n) {
781+
return (m + n - 1) / n;
782+
}
783+
779784
#endif // GGML_SYCL_COMMON_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

+30-12
Original file line numberDiff line numberDiff line change
@@ -2887,6 +2887,15 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
28872887
return false;
28882888
}
28892889

2890+
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
2891+
switch (type) {
2892+
case GGML_TYPE_Q4_0:
2893+
return true;
2894+
default:
2895+
return false;
2896+
}
2897+
}
2898+
28902899
static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
28912900
switch (type) {
28922901
case GGML_TYPE_Q4_0:
@@ -2906,13 +2915,14 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
29062915
}
29072916
}
29082917

2909-
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2910-
2911-
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
2912-
int64_t min_compute_capability = INT_MAX;
2918+
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
2919+
ggml_tensor * dst) {
2920+
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
2921+
int64_t min_compute_capability = INT_MAX;
29132922

29142923
if (split) {
2915-
ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
2924+
ggml_backend_sycl_split_buffer_type_context * buft_ctx =
2925+
(ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
29162926
auto & tensor_split = buft_ctx->tensor_split;
29172927
for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
29182928
// skip devices that are not going to do any work:
@@ -2925,7 +2935,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
29252935
}
29262936
}
29272937
} else {
2928-
min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
2938+
min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
29292939
}
29302940

29312941
// check data types and tensor shapes for custom matrix multiplication kernels:
@@ -2948,8 +2958,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
29482958
#endif // SYCL_USE_XMX
29492959

29502960
// mmvq path is faster in the CUDA backend.
2951-
if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
2961+
if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
2962+
// Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
2963+
// is enabled takes precedence over DMMV, the current if-else implementation
2964+
// requires disabling DMMV if both conditions are met
2965+
|| (ctx.opt_feature.reorder && ggml_sycl_supports_reorder_mmvq(src0->type))) {
29522966
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
2967+
}
29532968

29542969
if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
29552970
// TODO: Refactor and cleanup of mul mat dispatching.
@@ -2968,14 +2983,17 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
29682983
// KQ + KQV multi-batch
29692984
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
29702985
} else if (use_dequantize_mul_mat_vec) {
2971-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
2972-
// save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
2986+
constexpr bool convert_src1_to_q8_1 = false;
2987+
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
29732988
} else if (use_mul_mat_vec_q) {
2974-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
2989+
constexpr bool convert_src1_to_q8_1 = true;
2990+
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
29752991
} else if (use_mul_mat_q) {
2976-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
2992+
constexpr bool convert_src1_to_q8_1 = true;
2993+
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
29772994
} else {
2978-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
2995+
constexpr bool convert_src1_to_q8_1 = false;
2996+
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
29792997
}
29802998
}
29812999

0 commit comments

Comments (0)