@@ -407,6 +407,16 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
407
407
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
408
408
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
409
409
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
410
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
411
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
412
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
413
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
414
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
415
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
416
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
417
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
418
+ GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
419
+ GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
410
420
GGML_METAL_KERNEL_TYPE_CONCAT,
411
421
GGML_METAL_KERNEL_TYPE_SQR,
412
422
GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1022,16 @@ @implementation GGMLMetalClass
1012
1022
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true );
1013
1023
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true );
1014
1024
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true );
1025
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true );
1026
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true );
1027
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true );
1028
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true );
1029
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true );
1030
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true );
1031
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true );
1032
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true );
1033
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true );
1034
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true );
1015
1035
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONCAT, concat, true );
1016
1036
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQR, sqr, true );
1017
1037
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SQRT, sqrt , true );
@@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1287
1307
default :
1288
1308
return false ;
1289
1309
}
1310
+ case GGML_TYPE_Q4_0:
1311
+ case GGML_TYPE_Q4_1:
1312
+ case GGML_TYPE_Q5_0:
1313
+ case GGML_TYPE_Q5_1:
1314
+ case GGML_TYPE_Q8_0:
1315
+ switch (op->type ) {
1316
+ case GGML_TYPE_F32:
1317
+ case GGML_TYPE_F16:
1318
+ return true ;
1319
+ default :
1320
+ return false ;
1321
+ }
1290
1322
default :
1291
1323
return false ;
1292
1324
};
@@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node(
3899
3931
case GGML_OP_CPY:
3900
3932
case GGML_OP_CONT:
3901
3933
{
3902
- GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
3903
-
3904
- int nth = MIN (1024 , ne00/ggml_blck_size (src0->type ));
3905
-
3906
3934
id <MTLComputePipelineState > pipeline = nil ;
3907
3935
3908
3936
switch (src0t) {
@@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node(
3936
3964
switch (dstt) {
3937
3965
case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline ; break ;
3938
3966
case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline ; break ;
3939
- default : GGML_ASSERT (false && " not implemented" );
3967
+ default : GGML_ABORT (" not implemented" );
3968
+ };
3969
+ } break ;
3970
+ case GGML_TYPE_Q4_0:
3971
+ {
3972
+ switch (dstt) {
3973
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline ; break ;
3974
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline ; break ;
3975
+ default : GGML_ABORT (" not implemented" );
3976
+ };
3977
+ } break ;
3978
+ case GGML_TYPE_Q4_1:
3979
+ {
3980
+ switch (dstt) {
3981
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline ; break ;
3982
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline ; break ;
3983
+ default : GGML_ABORT (" not implemented" );
3984
+ };
3985
+ } break ;
3986
+ case GGML_TYPE_Q5_0:
3987
+ {
3988
+ switch (dstt) {
3989
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline ; break ;
3990
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline ; break ;
3991
+ default : GGML_ABORT (" not implemented" );
3992
+ };
3993
+ } break ;
3994
+ case GGML_TYPE_Q5_1:
3995
+ {
3996
+ switch (dstt) {
3997
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline ; break ;
3998
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline ; break ;
3999
+ default : GGML_ABORT (" not implemented" );
4000
+ };
4001
+ } break ;
4002
+ case GGML_TYPE_Q8_0:
4003
+ {
4004
+ switch (dstt) {
4005
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline ; break ;
4006
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline ; break ;
4007
+ default : GGML_ABORT (" not implemented" );
3940
4008
};
3941
4009
} break ;
3942
4010
default : GGML_ABORT (" not implemented" );
@@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node(
3966
4034
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3967
4035
[encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
3968
4036
4037
+ GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
4038
+ int nth = MIN (1024 , ne00/ggml_blck_size (src0->type ));
4039
+
3969
4040
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4041
+
3970
4042
} break ;
3971
4043
case GGML_OP_SET:
3972
4044
{
0 commit comments