
Commit 1b2b86a

vulkan: implement GGML_OP_SOFT_MAX_BACK
1 parent d3cfd8e commit 1b2b86a

File tree

3 files changed: +74 additions, -0 deletions

  ggml/src/ggml-vulkan/ggml-vulkan.cpp
  ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp
  ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
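This commit adds the backward pass of softmax to the Vulkan backend. By the op's convention (visible in the bindings of the new shader and in the CPU cross-check below), src0 carries the incoming gradient g = dL/dy and src1 carries the forward softmax output y; per row, the kernel computes dx = scale * (g - dot(y, g)) * y.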

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 23 additions & 0 deletions
@@ -251,6 +251,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+    vk_pipeline pipeline_soft_max_back_f32;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -2188,6 +2189,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
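The trailing `{ device->subgroup_size }` is the pipeline's specialization-constant list: it fills constant_id 0, which the new shader reads as both its workgroup width (local_size_x_id = 0) and its BLOCK_SIZE.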
@@ -5330,6 +5332,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SOFT_MAX_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_soft_max_back_f32;
+        }
+        return nullptr;
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
         {
@@ -5643,6 +5650,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
         {
             const uint32_t nr = ggml_nrows(src0);
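Including SOFT_MAX_BACK in this branch means ggml_vk_op_f32 sizes the dispatch from nr = ggml_nrows(src0), i.e. one workgroup per row, which matches the one-row-per-workgroup layout of the shader.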
@@ -6203,6 +6211,11 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }

+static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+}
+
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
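The four initializer values reach the shader through the generic push-constant block declared by generic_head.comp. A minimal sketch of that layout, inferred from the shader's p.KX / p.param1 accesses rather than quoted from the source:

struct vk_op_push_constants {
    uint32_t KX;     // row width, set to src0->ne[0]
    uint32_t KY;     // row count in dim 1, set to src0->ne[1]
    float    param1; // op_params[0]: the softmax scale
    float    param2; // op_params[1]: max_bias (unused by this shader)
};

Only param1 is consumed by soft_max_back.comp; param2 is carried along because the op's parameters include max_bias (see the ggml_soft_max_ext_back call in the results check below).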
@@ -7145,6 +7158,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_MUL_MAT:
@@ -7201,6 +7215,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_ARGSORT:
@@ -7324,6 +7339,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SOFT_MAX:
         ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);

+        break;
+    case GGML_OP_SOFT_MAX_BACK:
+        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_ROPE:
         ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
@@ -7445,6 +7464,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_RESHAPE:
@@ -8376,6 +8396,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_PAD:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
@@ -8901,6 +8922,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
         }
+    } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+        tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
         tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
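The CPU cross-check recomputes the op with ggml_soft_max_ext_back and compares against the Vulkan result. The formula the shader implements follows from the softmax Jacobian: with y = softmax(x), we have dy_i/dx_j = y_i * (delta_ij - y_j), so

    dx_j = sum_i g_i * dy_i/dx_j = y_j * (g_j - sum_i g_i * y_i),

and the extra factor of scale accounts for the scaling applied to the logits in the forward pass.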
ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "generic_head.comp"
+#include "types.comp"
+
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+// In this shader Y = softmax(X) and X is not provided as input.
+
+layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_y[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum_yg[BLOCK_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint tid = gl_LocalInvocationID.x;
+
+    FLOAT_TYPE scale = p.param1;
+
+    // partial sums for thread in warp
+    sum_yg[tid] = FLOAT_TYPE(0.0f);
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]);
+        const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]);
+        sum_yg[tid] += yi * gi;
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sum_yg[tid] += sum_yg[tid + s];
+        }
+        barrier();
+    }
+
+    const FLOAT_TYPE dot_yg = sum_yg[0];
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        data_d[row*p.KX + col] = D_TYPE(scale
+            * (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg)
+            * FLOAT_TYPE(data_y[row*p.KX + col]));
+    }
+}
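For readability, here is a plain C++ transcription of what one workgroup computes for a single row — a reference sketch mirroring the shader, not code from the repository:

#include <cstddef>

// g: incoming gradient dL/dy; y: forward softmax output; dx: result dL/dx.
// KX is the row width (p.KX in the shader); scale mirrors p.param1.
static void soft_max_back_row(const float * g, const float * y, float * dx,
                              size_t KX, float scale) {
    // The shader's shared-memory tree reduction computes this dot product.
    float dot_yg = 0.0f;
    for (size_t i = 0; i < KX; ++i) {
        dot_yg += y[i] * g[i];
    }
    // Identical to the shader's final loop: dx_i = scale * (g_i - dot(y, g)) * y_i.
    for (size_t i = 0; i < KX; ++i) {
        dx[i] = scale * (g[i] - dot_yg) * y[i];
    }
}

The 512 and 262144 (= 512 * 512) factors in the shader's row computation decode a flat row index from the 3-D workgroup grid (x + 512*y + 512*512*z), matching how the host splits large row counts across dispatch dimensions.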

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
@@ -484,6 +484,7 @@ void process_shaders() {

     string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+    string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));

     string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
