@@ -143,6 +143,7 @@ struct ggml_backend_opencl_context {
143
143
cl_kernel kernel_rms_norm;
144
144
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
145
145
cl_kernel kernel_soft_max, kernel_soft_max_4;
146
+ cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
146
147
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
147
148
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
148
149
cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
@@ -614,6 +615,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
614
615
CL_CHECK ((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel (backend_ctx->program , " kernel_diag_mask_inf_8" , &err), err));
615
616
CL_CHECK ((backend_ctx->kernel_soft_max = clCreateKernel (backend_ctx->program , " kernel_soft_max" , &err), err));
616
617
CL_CHECK ((backend_ctx->kernel_soft_max_4 = clCreateKernel (backend_ctx->program , " kernel_soft_max_4" , &err), err));
618
+ CL_CHECK ((backend_ctx->kernel_soft_max_f16 = clCreateKernel (backend_ctx->program , " kernel_soft_max_f16" , &err), err));
619
+ CL_CHECK ((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel (backend_ctx->program , " kernel_soft_max_4_f16" , &err), err));
617
620
CL_CHECK ((backend_ctx->kernel_rope_norm_f32 = clCreateKernel (backend_ctx->program , " kernel_rope_norm_f32" , &err), err));
618
621
CL_CHECK ((backend_ctx->kernel_rope_norm_f16 = clCreateKernel (backend_ctx->program , " kernel_rope_norm_f16" , &err), err));
619
622
CL_CHECK ((backend_ctx->kernel_rope_neox_f32 = clCreateKernel (backend_ctx->program , " kernel_rope_neox_f32" , &err), err));
@@ -1044,8 +1047,16 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
1044
1047
return true ;
1045
1048
case GGML_OP_DIAG_MASK_INF:
1046
1049
return op->ne [3 ] == 1 ;
1047
- case GGML_OP_ROPE:
1050
+ case GGML_OP_ROPE: {
1051
+ const int mode = ((const int32_t *) op->op_params )[2 ];
1052
+ if (mode & GGML_ROPE_TYPE_MROPE) {
1053
+ return false ;
1054
+ }
1055
+ if (mode & GGML_ROPE_TYPE_VISION) {
1056
+ return false ;
1057
+ }
1048
1058
return true ;
1059
+ }
1049
1060
default :
1050
1061
return false ;
1051
1062
}
@@ -3666,6 +3677,8 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
3666
3677
const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
3667
3678
const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
3668
3679
3680
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
3681
+
3669
3682
// Local size must be wave size. Each workgroup is a wave, working on a row,
3670
3683
// where a row corresponds to leading dimension.
3671
3684
int nth = MIN (32 , ne00);
@@ -3683,9 +3696,17 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
3683
3696
cl_kernel kernel;
3684
3697
3685
3698
if (ne00%4 == 0 ) {
3686
- kernel = backend_ctx->kernel_soft_max_4 ;
3699
+ if (use_f16) {
3700
+ kernel = backend_ctx->kernel_soft_max_4_f16 ;
3701
+ } else {
3702
+ kernel = backend_ctx->kernel_soft_max_4 ;
3703
+ }
3687
3704
} else {
3688
- kernel = backend_ctx->kernel_soft_max ;
3705
+ if (use_f16) {
3706
+ kernel = backend_ctx->kernel_soft_max_f16 ;
3707
+ } else {
3708
+ kernel = backend_ctx->kernel_soft_max ;
3709
+ }
3689
3710
}
3690
3711
3691
3712
CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
@@ -3766,7 +3787,8 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
3766
3787
const int nb2 = dst ? dst->nb [2 ] : 0 ;
3767
3788
const int nb3 = dst ? dst->nb [3 ] : 0 ;
3768
3789
3769
- GGML_ASSERT (ne10 == ne02);
3790
+ GGML_ASSERT (ne10 % ne02 == 0 );
3791
+ GGML_ASSERT (ne10 >= ne02);
3770
3792
3771
3793
int nth = MIN (64 , ne00);
3772
3794
0 commit comments