@@ -252,6 +252,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+    vk_pipeline pipeline_soft_max_back_f32;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -2195,6 +2196,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
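Note that the new pipeline declares `sizeof(vk_op_push_constants)` rather than the soft_max-specific block used by the forward pipelines. As a reference sketch, this is the generic layout that implies, assuming it matches the `vk_op_push_constants` definition elsewhere in ggml-vulkan.cpp (the field names are from that assumption, not from this diff):

```cpp
// Sketch of the generic push-constant block the new pipeline reuses.
// The ggml_vk_soft_max_back() call further down fills it as:
//   KX = src0->ne[0] (row length), KY = src0->ne[1] (rows),
//   param1 = scale, param2 = max_bias (both from dst->op_params).
struct vk_op_push_constants {
    uint32_t KX;
    uint32_t KY;
    float    param1;
    float    param2;
};
```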
@@ -5359,6 +5361,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SOFT_MAX_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_soft_max_back_f32;
+        }
+        return nullptr;
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
         {
@@ -5690,6 +5697,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
         {
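Adding the label here routes SOFT_MAX_BACK through the same dispatch-sizing branch as the other row-reducing ops. The body of that branch is not part of this diff; as a hypothetical sketch of the pattern these ops share (one workgroup per row of src0, with the shader reducing across the row):

```cpp
{
    // Hypothetical sketch, not the verbatim upstream body:
    // launch one workgroup per row of src0.
    const uint32_t nr = (uint32_t) ggml_nrows(src0);
    elements = { nr, 1, 1 };
} break;
```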
@@ -6397,6 +6405,11 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }
 
+static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+}
+
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
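The host wrapper above only packs push constants; the gradient math lives in the `soft_max_back` shader, which this diff does not show. For the standard softmax vector-Jacobian product that ggml implements, src0 is the upstream gradient dy and src1 is the forward softmax output y, and each row comes out as dx = scale * y * (dy - dot(y, dy)). A scalar C++ reference of that per-row computation, offered as a sketch of what the shader is expected to compute rather than the shader itself:

```cpp
// Reference for one row of soft_max_back:
//   dx[i] = scale * y[i] * (dy[i] - sum_j y[j] * dy[j])
// dy = upstream gradient (src0), y = softmax output (src1).
static void soft_max_back_row_ref(const float * dy, const float * y,
                                  float * dx, int n, float scale) {
    float dot = 0.0f;                       // sum_j y[j] * dy[j]
    for (int j = 0; j < n; ++j) {
        dot += y[j] * dy[j];
    }
    for (int i = 0; i < n; ++i) {
        dx[i] = scale * y[i] * (dy[i] - dot);
    }
}
```

param2 (max_bias) is forwarded for interface symmetry with the forward op; ggml's CPU implementation of this backward op appears to assert max_bias == 0, so the ALiBi path is not exercised here.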
@@ -7353,6 +7366,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_MUL_MAT:
@@ -7415,6 +7429,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_ARGSORT:
@@ -7549,6 +7564,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SOFT_MAX:
         ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_SOFT_MAX_BACK:
+        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_ROPE:
         ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
@@ -7688,6 +7707,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_ROPE_BACK:
     case GGML_OP_RESHAPE:
@@ -8636,6 +8656,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_PAD:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
@@ -9038,6 +9059,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
         }
+    } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+        tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
         tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
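For validation, the debug path above rebuilds the op with ggml's CPU reference, `ggml_soft_max_ext_back`, whose trailing arguments are scale and max_bias. A minimal standalone usage sketch of that reference (tensor shapes and the 16 MB context size are arbitrary choices here, and `ggml_graph_compute_with_ctx` is assumed to live in ggml-cpu.h as in trees of this vintage):

```cpp
#include "ggml.h"
#include "ggml-cpu.h"   // assumed location of ggml_graph_compute_with_ctx

int main(void) {
    struct ggml_init_params params = { /*mem_size*/ 16u*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // dy: upstream gradient, y: output of the forward softmax (same shape).
    struct ggml_tensor * dy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);
    struct ggml_tensor * y  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);

    // dx = soft_max_back(dy, y) with scale = 1.0f and max_bias = 0.0f.
    struct ggml_tensor * dx = ggml_soft_max_ext_back(ctx, dy, y, 1.0f, 0.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, dx);
    // ... fill dy->data and y->data, then compute on the CPU:
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    ggml_free(ctx);
    return 0;
}
```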