
Commit 4cce5e9

gcp authored and ggerganov committed
metal : copy kernels for quant to F32/F16 conversions (ggml-org#12017)
metal: use dequantize_q templates

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 3e841d8 commit 4cce5e9
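
The new kernels let the Metal backend run GGML_OP_CPY (and GGML_OP_CONT) nodes that dequantize Q4_0, Q4_1, Q5_0, Q5_1 or Q8_0 source tensors directly to F32 or F16 on the GPU. A minimal sketch of the graph-level operation this covers, using the public ggml C API (the helper name and shapes are illustrative, not part of this commit):

    #include "ggml.h"

    // Sketch: build a CPY node that dequantizes a Q4_0 tensor into F32.
    // With this commit, the Metal backend executes this node on the GPU.
    // Note: the row length ne[0] must be a multiple of the Q4_0 block size (32).
    static struct ggml_tensor * dequant_to_f32(struct ggml_context * ctx,
                                               struct ggml_tensor  * src_q4_0) {
        // destination with the same shape but F32 element type
        struct ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
                src_q4_0->ne[0], src_q4_0->ne[1], src_q4_0->ne[2], src_q4_0->ne[3]);

        // GGML_OP_CPY converts between element types while copying
        return ggml_cpy(ctx, src_q4_0, dst);
    }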

File tree

2 files changed, +120 -5 lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 77 additions & 5 deletions
@@ -407,6 +407,16 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1022,16 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,   cpy_f32_q5_0,   true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,   cpy_f32_q5_1,   true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,   cpy_q4_0_f32,   true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,   cpy_q4_0_f16,   true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,   cpy_q4_1_f32,   true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,   cpy_q4_1_f16,   true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,   cpy_q5_0_f32,   true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,   cpy_q5_0_f16,   true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,   cpy_q5_1_f32,   true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,   cpy_q5_1_f16,   true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,   cpy_q8_0_f32,   true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,   cpy_q8_0_f16,   true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,    sqr,    true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,   sqrt,   true);
@@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                     default:
                         return false;
                 }
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_Q8_0:
+                switch (op->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                        return true;
+                    default:
+                        return false;
+                }
             default:
                 return false;
         };
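
ggml_metal_supports_op now accepts the ten new source/destination combinations (five quant formats x F32/F16), so the scheduler no longer has to route these CPY nodes to another backend. At the API level this is the path ggml_cast produces; a one-line sketch, assuming ctx and a Q8_0 tensor q already exist:

    // a CPY node from Q8_0 to F16 -- now handled by the Metal backend
    struct ggml_tensor * h = ggml_cast(ctx, q, GGML_TYPE_F16);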
@@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node(
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
                 id<MTLComputePipelineState> pipeline = nil;

                 switch (src0t) {
@@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node(
                         switch (dstt) {
                             case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline;  break;
                             case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
-                            default: GGML_ASSERT(false && "not implemented");
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q4_0:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q4_1:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q5_0:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q5_1:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        };
+                    } break;
+                case GGML_TYPE_Q8_0:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
+                            default: GGML_ABORT("not implemented");
                         };
                     } break;
                 default: GGML_ABORT("not implemented");
@@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];

+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
             } break;
         case GGML_OP_SET:
             {
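
Moving the nth computation below the pipeline selection does not change it: the grid stays one threadgroup per source row, and nth is the row length measured in blocks of the source type, capped at 1024. Worked numbers, assuming a Q4_0 source with ne00 = 4096 elements per row (ggml's Q4_0 block holds 32 elements):

    // nth  = MIN(1024, ne00/ggml_blck_size(GGML_TYPE_Q4_0))
    //      = MIN(1024, 4096/32) = 128 threads per threadgroup
    // grid = ne01 x ne02 x ne03 threadgroups, one per source row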

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 43 additions & 0 deletions
@@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl(
     }
 }

+template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_cpy_q_f32(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 =  n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T4x4    * dst_data = (device       T4x4    *)(dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+        T4x4 temp;
+        dequantize_func(src_data + i00/nl, i00%nl, temp);
+        dst_data[i00] = temp;
+    }
+}
+
+typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
+
+template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
+
+template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
+
 kernel void kernel_concat(
         constant ggml_metal_kargs_concat & args,
         device const char * src0,
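
The kernel supports non-contiguous destinations by linearizing the source row index and re-expanding it in destination coordinates: each threadgroup owns one source row (i01, i02, i03), computes the flat element offset n of that row's first element, and unflattens n under the destination's extents. Threads then stride over the row in 16-element (4x4) chunks, dequantizing one chunk per iteration; nl is the number of 4x4 chunks per quant block (2 for all five formats here, since each block holds 32 values). A host-side C sketch of the same index arithmetic, assuming the extents are passed as ne0[]/ne[] arrays in place of the kernel's args fields:

    #include <stdint.h>

    // Sketch of the source-row -> destination-coordinate mapping used by
    // kernel_cpy_q_f32. ne0[] holds the source extents, ne[] the destination's.
    static void map_row(const int64_t ne0[4], const int64_t ne[4],
                        int64_t i01, int64_t i02, int64_t i03,
                        int64_t out[4]) {
        // flat index of the first element of source row (i01, i02, i03)
        const int64_t n = i03*ne0[2]*ne0[1]*ne0[0] + i02*ne0[1]*ne0[0] + i01*ne0[0];

        // re-expand the flat index under the destination extents
        out[3] =  n/(ne[2]*ne[1]*ne[0]);
        out[2] = (n - out[3]*ne[2]*ne[1]*ne[0])/(ne[1]*ne[0]);
        out[1] = (n - out[3]*ne[2]*ne[1]*ne[0] - out[2]*ne[1]*ne[0])/ne[0];
        out[0] =  n - out[3]*ne[2]*ne[1]*ne[0] - out[2]*ne[1]*ne[0] - out[1]*ne[0];
    }

Because a whole source row lands at a single destination offset, the destination is indexed once per row and then written in T4x4 chunks through dst_data.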
