@@ -240,6 +240,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
@@ -2118,6 +2119,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5270,6 +5272,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
+    case GGML_OP_RMS_NORM_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rms_norm_back_f32;
+        }
+        return nullptr;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(dst)) {
         case GGML_UNARY_OP_SILU:
@@ -5627,6 +5634,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     switch (op) {
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SUM_ROWS:
         {
@@ -6144,6 +6152,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
 }
 
+static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+}
+
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
@@ -7117,6 +7130,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
@@ -7170,6 +7184,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
@@ -7267,6 +7282,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM:
         ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_RMS_NORM_BACK:
+        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
@@ -7405,6 +7424,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
@@ -8327,6 +8347,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_CONCAT:
+        case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_UPSCALE:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
@@ -8850,6 +8871,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_rms_norm_back(ggml_ctx, src0_clone, src1_clone, eps);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
             tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
0 commit comments