@@ -251,6 +251,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+    vk_pipeline pipeline_soft_max_back_f32;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -2188,6 +2189,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
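Note that the new pipeline is registered with `sizeof(vk_op_push_constants)` rather than the soft-max-specific `vk_op_soft_max_push_constants`: the backward pass only needs the row geometry plus the two float op params. A sketch of what that generic block presumably looks like, with field names assumed from how the dispatch later in this diff packs its four values (the struct itself is not shown here):

```cpp
// Assumed layout of the generic push-constant block shared by row-wise ops.
// ggml_vk_soft_max_back() below fills it with
// { src0->ne[0], src0->ne[1], op_params[0], op_params[1] }
// = { row length, row count, scale, max_bias }.
struct vk_op_push_constants {
    uint32_t KX;   // elements per row (src0->ne[0])
    uint32_t KY;   // number of rows   (src0->ne[1])
    float param1;  // scale
    float param2;  // max_bias
};
```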
@@ -5330,6 +5332,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SOFT_MAX_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_soft_max_back_f32;
+        }
+        return nullptr;
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
         {
@@ -5643,6 +5650,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
         {
             const uint32_t nr = ggml_nrows(src0);
@@ -6203,6 +6211,11 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }
 
+static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+}
+
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
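For reference, `GGML_OP_SOFT_MAX_BACK` receives the incoming gradient `dy` as `src0` and the forward softmax output `y` as `src1` (the check path at the end of this diff rebuilds it as `ggml_soft_max_ext_back(src0, src1, scale, max_bias)`), and per row computes `dx = scale * y * (dy - dot(dy, y))`. A minimal CPU sketch of that per-row math, with illustrative names; the actual `soft_max_back_f32` shader is expected to implement the same reduction cooperatively within a workgroup:

```cpp
#include <cstddef>

// Per-row softmax backward: given dy (incoming gradient) and y = softmax(x),
// dx_i = scale * y_i * (dy_i - dot(dy, y)). Illustrative reference code only.
static void soft_max_back_row(const float * dy, const float * y, float * dx,
                              size_t n, float scale) {
    float dot = 0.0f;               // dot(dy, y), shared across the whole row
    for (size_t i = 0; i < n; ++i) {
        dot += dy[i] * y[i];
    }
    for (size_t i = 0; i < n; ++i) {
        dx[i] = scale * y[i] * (dy[i] - dot);
    }
}
```

The `{ device->subgroup_size }` specialization constant registered above presumably sizes the workgroup for that row-wise reduction, matching the forward soft_max pipelines.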
@@ -7145,6 +7158,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_MUL_MAT:
@@ -7201,6 +7215,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_ARGSORT:
@@ -7324,6 +7339,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SOFT_MAX:
         ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_SOFT_MAX_BACK:
+        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_ROPE:
         ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
@@ -7445,6 +7464,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_RESHAPE:
@@ -8376,6 +8396,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_PAD:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
@@ -8901,6 +8922,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
         }
+    } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+        tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
         tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
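The debug checker above validates the Vulkan result by recomputing the op with the CPU reference `ggml_soft_max_ext_back`. For orientation, a minimal standalone use of that reference op might look like this; context size, tensor shapes, and the graph-compute calls are illustrative assumptions, only the `ggml_soft_max_ext_back` signature is taken from the diff:

```cpp
#include "ggml.h"

// Build dx = soft_max_back(dy, y) on the CPU, the same reference the
// Vulkan backend is checked against. scale/max_bias land in op_params[0..1].
int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * dy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8); // incoming gradient
    struct ggml_tensor * y  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8); // forward softmax output

    // dx = scale * y * (dy - dot(dy, y)) per row, here with scale=1, max_bias=0
    struct ggml_tensor * dx = ggml_soft_max_ext_back(ctx, dy, y, 1.0f, 0.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, dx);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

    ggml_free(ctx);
    return 0;
}
```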