Commit ae199a7

sycl : Implemented reorder Q4_0 mmvq

Signed-off-by: Alberto Cabrera <[email protected]>
1 parent 02082f1 · commit ae199a7

6 files changed: +307 −120 lines
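
For context: ggml's standard Q4_0 format stores blocks of 32 weights, each block interleaving one half-precision scale with 16 bytes of packed 4-bit quants. The reordered layout that the new mmvq path reads is defined outside the hunks shown here (note the new quants.hpp include in backend.hpp below), so the following is only a rough sketch of the struct-of-arrays idea it is assumed to follow, with hypothetical names; the kernel's actual layout may differ.

    #include <cstddef>
    #include <cstdint>

    // ggml's regular Q4_0 block: scale and packed quants interleaved per block.
    #define QK4_0 32
    typedef uint16_t ggml_half;   // raw fp16 bits; stand-in for illustration only
    typedef struct {
        ggml_half d;              // per-block scale
        uint8_t   qs[QK4_0 / 2];  // 32 quants packed two per byte
    } block_q4_0;

    // Assumed struct-of-arrays view after the reorder: all quant bytes first,
    // then all scales, so a mul-mat-vec kernel reads each stream contiguously.
    struct q4_0_reorder_view {
        const uint8_t *   qs;  // nblocks * QK4_0 / 2 bytes
        const ggml_half * d;   // nblocks scales
    };

    inline q4_0_reorder_view make_q4_0_reorder_view(const void * reordered, size_t nblocks) {
        const uint8_t * base = static_cast<const uint8_t *>(reordered);
        return { base, reinterpret_cast<const ggml_half *>(base + nblocks * (QK4_0 / 2)) };
    }

Splitting the streams lets consecutive work-items load neighbouring quant bytes instead of striding over interleaved scales, which is the usual motivation for this kind of reorder.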

ggml/src/ggml-sycl/backend.hpp

Lines changed: 9 additions & 8 deletions
@@ -13,23 +13,24 @@
 #ifndef GGML_SYCL_BACKEND_HPP
 #define GGML_SYCL_BACKEND_HPP
 
-#include "concat.hpp"
 #include "common.hpp"
+#include "concat.hpp"
 #include "conv.hpp"
 #include "convert.hpp"
+#include "cpy.hpp"
 #include "dequantize.hpp"
 #include "dmmv.hpp"
+#include "element_wise.hpp"
+#include "gla.hpp"
+#include "im2col.hpp"
 #include "mmq.hpp"
 #include "mmvq.hpp"
-#include "rope.hpp"
 #include "norm.hpp"
+#include "outprod.hpp"
+#include "quants.hpp"
+#include "rope.hpp"
 #include "softmax.hpp"
 #include "tsembd.hpp"
-#include "im2col.hpp"
 #include "wkv.hpp"
-#include "outprod.hpp"
-#include "element_wise.hpp"
-#include "cpy.hpp"
-#include "gla.hpp"
 
-#endif // GGML_SYCL_BACKEND_HPP
+#endif  // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/common.hpp

Lines changed: 5 additions & 0 deletions
@@ -788,4 +788,9 @@ bool gpu_has_xmx(sycl::device &dev);
 void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                           const ggml_tensor *src1, ggml_tensor *dst,
                           const ggml_sycl_op_flatten_t op);
+
+constexpr size_t safe_div(const size_t m, const size_t n) {
+    return (m + n - 1) / n;
+}
+
 #endif // GGML_SYCL_COMMON_HPP
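
The safe_div helper added here is plain ceiling division. Below is a minimal, self-contained sketch of the kind of use it lends itself to, e.g. sizing a kernel launch; the row count and work-group size are illustrative assumptions, not values taken from this commit.

    #include <cstddef>
    #include <cstdio>

    // Same ceiling-division helper as the one added to common.hpp.
    constexpr size_t safe_div(const size_t m, const size_t n) {
        return (m + n - 1) / n;
    }

    int main() {
        const size_t nrows   = 1000;  // hypothetical problem size
        const size_t wg_size = 256;   // assumed work-group size, for illustration only
        static_assert(safe_div(7, 3) == 3, "rounds up instead of truncating");
        std::printf("%zu work-groups\n", safe_div(nrows, wg_size));  // prints: 4 work-groups
        return 0;
    }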

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 30 additions & 12 deletions
@@ -2923,6 +2923,15 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     return false;
 }
 
+inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
 static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -2942,13 +2951,14 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-
-    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
-    int64_t min_compute_capability = INT_MAX;
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
+                              ggml_tensor * dst) {
+    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
+    int64_t min_compute_capability = INT_MAX;
 
     if (split) {
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+        ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+            (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
             // skip devices that are not going to do any work:
@@ -2961,7 +2971,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
            }
        }
    } else {
-        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
+        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
    }
 
    // check data types and tensor shapes for custom matrix multiplication kernels:
@@ -2984,8 +2994,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
 #endif // SYCL_USE_XMX
 
     // mmvq path is faster in the CUDA backend.
-    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+        // Dispatch becomes obscure with the reorder: MMVQ takes precedence over DMMV
+        // when the reorder optimization is enabled, and the current if-else
+        // implementation requires disabling DMMV if both conditions are met.
+        || (ctx.opt_feature.reorder && ggml_sycl_supports_reorder_mmvq(src0->type))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+    }
 
     if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // TODO: Refactor and cleanup of mul mat dispatching.
@@ -3004,14 +3019,17 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-        // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
+        constexpr bool convert_src1_to_q8_1 = false;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
     } else if (use_mul_mat_vec_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
     } else if (use_mul_mat_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
     } else {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+        constexpr bool convert_src1_to_q8_1 = false;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
     }
 }
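
The new branch condition is easier to read in isolation. Below is a standalone sketch of the precedence rule the in-code comment describes, using simplified stand-in types; only the boolean logic mirrors the commit. Because the if-else chain later in the function tests use_dequantize_mul_mat_vec before use_mul_mat_vec_q, DMMV has to be cleared here whenever MMVQ should win, i.e. on the CUDA backend or when the reorder optimization is active for a type with a reordered MMVQ kernel.

    #include <cstdio>

    // Stand-in for the SYCL backend check; only the precedence logic is real.
    enum class backend_kind { cuda, other };

    static bool keep_dmmv(bool use_dmmv, bool use_mmvq, backend_kind backend,
                          bool reorder_enabled, bool reorder_mmvq_supported) {
        // MMVQ takes precedence over DMMV on CUDA, or when the reordered layout
        // has an MMVQ kernel for this tensor type.
        if (backend == backend_kind::cuda || (reorder_enabled && reorder_mmvq_supported)) {
            use_dmmv = use_dmmv && !use_mmvq;
        }
        return use_dmmv;
    }

    int main() {
        // Q4_0 with the reorder enabled: MMVQ wins even off the CUDA backend.
        std::printf("%d\n", keep_dmmv(true, true, backend_kind::other, true, true));   // 0
        // Reorder disabled: DMMV keeps precedence, as before this commit.
        std::printf("%d\n", keep_dmmv(true, true, backend_kind::other, false, true));  // 1
        return 0;
    }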
