
Commit dcb7430

vulkan: implement GGML_OP_RMS_NORM_BACK
1 parent 8f5dedc commit dcb7430

File tree

3 files changed: +80 -0 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp (new)
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 24 additions & 0 deletions

@@ -240,6 +240,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
@@ -2118,6 +2119,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5270,6 +5272,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
+    case GGML_OP_RMS_NORM_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rms_norm_back_f32;
+        }
+        return nullptr;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(dst)) {
             case GGML_UNARY_OP_SILU:
@@ -5627,6 +5634,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     switch (op) {
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SUM_ROWS:
         {
@@ -6144,6 +6152,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
 }

+static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+}
+
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
@@ -7117,6 +7130,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
@@ -7170,6 +7184,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
@@ -7267,6 +7282,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM:
         ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_RMS_NORM_BACK:
+        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
@@ -7405,6 +7424,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
@@ -8327,6 +8347,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
@@ -8850,6 +8871,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_rms_norm_back(ggml_ctx, src0_clone, src1_clone, eps);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
             tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
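
For orientation, here is a minimal host-side sketch of how such an op reaches this backend. It is hypothetical (not part of this commit): the context size and tensor shapes are illustrative, and the argument order of ggml_rms_norm_back — incoming gradient first, forward input second — is inferred from the ggml_vk_check_results_0 call above and the shader bindings below.

    #include "ggml.h"

    int main(void) {
        // Small CPU-side context; 16 MiB is an arbitrary illustrative size.
        struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_embd = 4096, n_tokens = 8;   // illustrative shapes
        struct ggml_tensor * grad = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); // dL/dy (src0)
        struct ggml_tensor * x    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); // forward input (src1)

        // eps lands in op_params[0], which ggml_vk_rms_norm_back reads above.
        struct ggml_tensor * dx = ggml_rms_norm_back(ctx, grad, x, 1e-6f);

        (void) dx; // a real caller would build and compute a graph containing dx
        ggml_free(ctx);
        return 0;
    }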
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer G {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer X {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum_xx[BLOCK_SIZE];
+shared FLOAT_TYPE sum_xg[BLOCK_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint tid = gl_LocalInvocationID.x;
+
+    // Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5
+
+    // partial sums for thread in warp
+    sum_xx[tid] = FLOAT_TYPE(0.0f);
+    sum_xg[tid] = FLOAT_TYPE(0.0f);
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]);
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]);
+        sum_xx[tid] += xi * xi;
+        sum_xg[tid] += xi * gi;
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sum_xx[tid] += sum_xx[tid + s];
+            sum_xg[tid] += sum_xg[tid + s];
+        }
+        barrier();
+    }
+
+    const FLOAT_TYPE eps = FLOAT_TYPE(p.param1);
+    const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX);
+    const FLOAT_TYPE scale_g = inversesqrt(mean + eps);
+    const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps);
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        data_d[row*p.KX + col] = D_TYPE(
+            scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) +
+            scale_x * FLOAT_TYPE(data_b[row*p.KX + col]));
+    }
+}
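
For reference, the quantity the shader computes per row can be written out. With row input $x \in \mathbb{R}^n$ (here $n = \texttt{KX}$), incoming gradient $g$, and $\mu = \frac{1}{n}\sum_j x_j^2$, the backward pass of RMS norm $y_i = x_i / \sqrt{\mu + \varepsilon}$ is

$$
\frac{\partial L}{\partial x_i}
  = \frac{g_i}{\sqrt{\mu + \varepsilon}}
  - \frac{\left(\sum_j x_j g_j\right) x_i}{n\,(\mu + \varepsilon)^{3/2}},
$$

where the first coefficient is the shader's `scale_g` and the second (negated) is `scale_x`: since `sum_xx + KX*eps` equals $n(\mu + \varepsilon)$, `scale_g * sum_xg / (sum_xx + KX*eps)` works out to $\sum_j x_j g_j \,/\, \big(n(\mu + \varepsilon)^{3/2}\big)$, matching the comment at the top of `main()`.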

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
@@ -427,6 +427,7 @@ void process_shaders() {
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));

     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
