Skip to content

Commit d3cfd8e

Browse files
vulkan: implement GGML_OP_SILU_BACK
1 parent dcb7430 commit d3cfd8e

File tree

3 files changed

+48
-0
lines changed

3 files changed

+48
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ struct vk_device_struct {
244244
vk_pipeline pipeline_gelu_f32;
245245
vk_pipeline pipeline_gelu_quick_f32;
246246
vk_pipeline pipeline_silu_f32;
247+
vk_pipeline pipeline_silu_back_f32;
247248
vk_pipeline pipeline_relu_f32;
248249
vk_pipeline pipeline_leaky_relu_f32;
249250
vk_pipeline pipeline_tanh_f32;
@@ -2176,6 +2177,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21762177
ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
21772178
ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
21782179
ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2180+
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
21792181
ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
21802182
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
21812183
ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -5257,6 +5259,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
52575259
case GGML_OP_CONT:
52585260
case GGML_OP_DUP:
52595261
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
5262+
case GGML_OP_SILU_BACK:
5263+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5264+
return ctx->device->pipeline_silu_back_f32;
5265+
}
5266+
return nullptr;
52605267
case GGML_OP_NORM:
52615268
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
52625269
return ctx->device->pipeline_norm_f32;
@@ -6130,6 +6137,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
61306137
}, dryrun);
61316138
}
61326139

6140+
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6141+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6142+
}
6143+
61336144
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
61346145
float * op_params = (float *)dst->op_params;
61356146

@@ -7127,6 +7138,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71277138
case GGML_OP_CPY:
71287139
case GGML_OP_CONT:
71297140
case GGML_OP_DUP:
7141+
case GGML_OP_SILU_BACK:
71307142
case GGML_OP_NORM:
71317143
case GGML_OP_GROUP_NORM:
71327144
case GGML_OP_RMS_NORM:
@@ -7181,6 +7193,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71817193
case GGML_OP_CPY:
71827194
case GGML_OP_CONT:
71837195
case GGML_OP_DUP:
7196+
case GGML_OP_SILU_BACK:
71847197
case GGML_OP_NORM:
71857198
case GGML_OP_GROUP_NORM:
71867199
case GGML_OP_RMS_NORM:
@@ -7270,6 +7283,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72707283
case GGML_OP_DUP:
72717284
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
72727285

7286+
break;
7287+
case GGML_OP_SILU_BACK:
7288+
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
7289+
72737290
break;
72747291
case GGML_OP_NORM:
72757292
ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7421,6 +7438,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
74217438
case GGML_OP_CPY:
74227439
case GGML_OP_CONT:
74237440
case GGML_OP_DUP:
7441+
case GGML_OP_SILU_BACK:
74247442
case GGML_OP_NORM:
74257443
case GGML_OP_GROUP_NORM:
74267444
case GGML_OP_RMS_NORM:
@@ -8347,6 +8365,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
83478365
case GGML_OP_MUL:
83488366
case GGML_OP_DIV:
83498367
case GGML_OP_CONCAT:
8368+
case GGML_OP_SILU_BACK:
83508369
case GGML_OP_RMS_NORM_BACK:
83518370
case GGML_OP_UPSCALE:
83528371
case GGML_OP_SCALE:
@@ -8874,6 +8893,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
88748893
} else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
88758894
const float eps = ((float *) tensor->op_params)[0];
88768895
tensor_clone = ggml_rms_norm_back(ggml_ctx, src0_clone, src1_clone, eps);
8896+
} else if (tensor->op == GGML_OP_SILU_BACK) {
8897+
tensor_clone = ggml_silu_back(ggml_ctx, src0_clone, src1_clone);
88778898
} else if (tensor->op == GGML_OP_SOFT_MAX) {
88788899
if (src1 != nullptr) {
88798900
tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#version 450
2+
3+
#include "generic_head.comp"
4+
#include "types.comp"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
11+
layout (binding = 1) readonly buffer X {B_TYPE data_x[];};
12+
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
13+
14+
void main() {
15+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
16+
17+
if (i >= p.KX) {
18+
return;
19+
}
20+
21+
// Compute derivative of SiLU(x): 1/(1+exp(-x)) - x*exp(-x)/(1+exp(-x))^2
22+
23+
const float xi = float(data_x[i]);
24+
const float s = 1.0f / (1.0f + exp(-xi));
25+
data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s)));
26+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ void process_shaders() {
475475
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
476476
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
477477
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
478+
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
478479
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
479480
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
480481
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

0 commit comments

Comments
 (0)