@@ -244,6 +244,7 @@ struct vk_device_struct {
244
244
vk_pipeline pipeline_gelu_f32;
245
245
vk_pipeline pipeline_gelu_quick_f32;
246
246
vk_pipeline pipeline_silu_f32;
247
+ vk_pipeline pipeline_silu_back_f32;
247
248
vk_pipeline pipeline_relu_f32;
248
249
vk_pipeline pipeline_leaky_relu_f32;
249
250
vk_pipeline pipeline_tanh_f32;
@@ -2176,6 +2177,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2176
2177
ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2177
2178
ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2178
2179
ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2180
+ ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2179
2181
ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2180
2182
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2181
2183
ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -5257,6 +5259,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5257
5259
case GGML_OP_CONT:
5258
5260
case GGML_OP_DUP:
5259
5261
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
5262
+ case GGML_OP_SILU_BACK:
5263
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5264
+ return ctx->device->pipeline_silu_back_f32;
5265
+ }
5266
+ return nullptr;
5260
5267
case GGML_OP_NORM:
5261
5268
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5262
5269
return ctx->device->pipeline_norm_f32;
@@ -6130,6 +6137,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
6130
6137
}, dryrun);
6131
6138
}
6132
6139
6140
+ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6141
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6142
+ }
6143
+
6133
6144
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6134
6145
float * op_params = (float *)dst->op_params;
6135
6146
@@ -7127,6 +7138,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7127
7138
case GGML_OP_CPY:
7128
7139
case GGML_OP_CONT:
7129
7140
case GGML_OP_DUP:
7141
+ case GGML_OP_SILU_BACK:
7130
7142
case GGML_OP_NORM:
7131
7143
case GGML_OP_GROUP_NORM:
7132
7144
case GGML_OP_RMS_NORM:
@@ -7181,6 +7193,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7181
7193
case GGML_OP_CPY:
7182
7194
case GGML_OP_CONT:
7183
7195
case GGML_OP_DUP:
7196
+ case GGML_OP_SILU_BACK:
7184
7197
case GGML_OP_NORM:
7185
7198
case GGML_OP_GROUP_NORM:
7186
7199
case GGML_OP_RMS_NORM:
@@ -7270,6 +7283,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7270
7283
case GGML_OP_DUP:
7271
7284
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
7272
7285
7286
+ break;
7287
+ case GGML_OP_SILU_BACK:
7288
+ ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
7289
+
7273
7290
break;
7274
7291
case GGML_OP_NORM:
7275
7292
ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7421,6 +7438,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7421
7438
case GGML_OP_CPY:
7422
7439
case GGML_OP_CONT:
7423
7440
case GGML_OP_DUP:
7441
+ case GGML_OP_SILU_BACK:
7424
7442
case GGML_OP_NORM:
7425
7443
case GGML_OP_GROUP_NORM:
7426
7444
case GGML_OP_RMS_NORM:
@@ -8347,6 +8365,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8347
8365
case GGML_OP_MUL:
8348
8366
case GGML_OP_DIV:
8349
8367
case GGML_OP_CONCAT:
8368
+ case GGML_OP_SILU_BACK:
8350
8369
case GGML_OP_RMS_NORM_BACK:
8351
8370
case GGML_OP_UPSCALE:
8352
8371
case GGML_OP_SCALE:
@@ -8874,6 +8893,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8874
8893
} else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
8875
8894
const float eps = ((float *) tensor->op_params)[0];
8876
8895
tensor_clone = ggml_rms_norm_back(ggml_ctx, src0_clone, src1_clone, eps);
8896
+ } else if (tensor->op == GGML_OP_SILU_BACK) {
8897
+ tensor_clone = ggml_silu_back(ggml_ctx, src0_clone, src1_clone);
8877
8898
} else if (tensor->op == GGML_OP_SOFT_MAX) {
8878
8899
if (src1 != nullptr) {
8879
8900
tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);