@@ -500,6 +500,7 @@ struct vk_op_rope_push_constants {
     uint32_t s1;
     uint32_t s2;
     int32_t sections[4];
+    uint32_t is_back;
 };
 
 struct vk_op_soft_max_push_constants {
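The whole push-constant change is this one `is_back` flag. That is enough because RoPE applies a plain 2-D rotation to each element pair, and the gradient of a rotation is the rotation by the negated angle, so the backward shader can be the forward shader with the sine flipped. A minimal sketch of that identity in plain C++ (illustrative only; the real kernel is a GLSL compute shader):

```cpp
#include <cmath>
#include <cstdio>

// Rotate one (x0, x1) pair by theta; with is_back set, rotate by -theta,
// which is the inverse rotation (and, for an orthogonal map, the gradient).
static void rope_pair(float x0, float x1, float theta, bool is_back,
                      float * y0, float * y1) {
    const float c = cosf(theta);
    const float s = is_back ? -sinf(theta) : sinf(theta);
    *y0 = x0 * c - x1 * s;
    *y1 = x0 * s + x1 * c;
}

int main() {
    float f0, f1, b0, b1;
    rope_pair(0.5f, -1.0f, 0.3f, /*is_back=*/false, &f0, &f1);
    rope_pair(f0, f1, 0.3f, /*is_back=*/true, &b0, &b1); // undoes the rotation
    printf("roundtrip: %.3f %.3f\n", b0, b1);            // prints 0.500 -1.000
    return 0;
}
```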
@@ -5316,6 +5317,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
         {
             const int mode = ((const int32_t *) dst->op_params)[2];
             const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
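The backward case can simply fall through to the forward one here because pipeline selection depends only on the RoPE variant encoded in `mode`, not on the direction of the rotation. A standalone sketch of the decoding (flag values copied from ggml.h; the vision check as an equality test follows the backend's convention):

```cpp
#include <cstdio>

#define GGML_ROPE_TYPE_NEOX   2  // values as defined in ggml.h
#define GGML_ROPE_TYPE_MROPE  8
#define GGML_ROPE_TYPE_VISION 24

int main() {
    const int  mode      = GGML_ROPE_TYPE_NEOX;      // from dst->op_params[2]
    const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
    printf("neox=%d mrope=%d vision=%d\n", is_neox, is_mrope, is_vision);
    return 0;
}
```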
@@ -5644,6 +5646,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         } break;
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
         elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
         break;
     case GGML_OP_GET_ROWS:
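The two ops also share a dispatch shape: a grid of rows times row length, since every output element is touched exactly once in either direction. A sketch of the element-count computation, using ggml's convention that ne[0] is the row length and ggml_nrows is the product of the higher dimensions:

```cpp
#include <cstdint>
#include <cstdio>

struct grid { uint32_t x, y, z; };

// Mirror of the elements = { nrows, ne00, 1 } line above, with ggml_nrows
// expanded to its definition for a 4-D ggml tensor.
static grid rope_grid(const int64_t ne[4]) {
    const uint32_t nrows = (uint32_t)(ne[1] * ne[2] * ne[3]);
    return { nrows, (uint32_t)ne[0], 1 };
}

int main() {
    const int64_t ne[4] = { 128, 32, 8, 1 };           // row length, higher dims
    const grid g = rope_grid(ne);
    printf("dispatch: %u x %u x %u\n", g.x, g.y, g.z); // 256 x 128 x 1
    return 0;
}
```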
@@ -5737,7 +5740,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    } else if (op == GGML_OP_ROPE) {
+    } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
         if (use_src2) {
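The widened condition matters for descriptor setup: Vulkan requires every binding in the set to reference a valid buffer, so when the optional frequency-factor tensor (src2) is absent, some existing buffer is bound in its place and the shader never reads it (the `src2 != nullptr` push constant further down tells it so). A self-contained analogue, with stand-in types for the vk_subbuffer machinery:

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for vk_subbuffer: which buffer, plus an offset/size window.
struct subbuf { int buffer_id; uint64_t offset, size; };

// If src2 exists, bind it; otherwise bind any valid buffer (here: src0's)
// just to keep the descriptor set complete. A push-constant flag is what
// actually tells the shader whether the data is meaningful.
static subbuf bind_src2(bool use_src2, subbuf z, subbuf x) {
    return use_src2 ? z : subbuf{ x.buffer_id, 0, x.size };
}

int main() {
    const subbuf x{ /*buffer_id=*/0, 0, 4096 };
    const subbuf z{ /*buffer_id=*/1, 256, 512 };
    const subbuf with    = bind_src2(true,  z, x);
    const subbuf without = bind_src2(false, z, x);
    printf("src2 bound to buffer %d, fallback to buffer %d\n",
           with.buffer_id, without.buffer_id);
    return 0;
}
```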
@@ -6176,7 +6179,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }
 
-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
     // const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -6204,7 +6207,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
         src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
-        sections[0], sections[1], sections[2], sections[3],
+        sections[0], sections[1], sections[2], sections[3], backprop
     }, dryrun);
 }
 
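The new `backprop` parameter rides into the shader through the `is_back` field added to the push constants at the top of this diff; in the brace-initializer above, the `bool` converts to 0 or 1. A trimmed, compilable mirror of just that tail of the struct:

```cpp
#include <cstdint>
#include <cstdio>

// Last fields of vk_op_rope_push_constants as shown in the first hunk;
// the preceding fields are elided here for brevity.
struct rope_pc_tail {
    int32_t  sections[4];
    uint32_t is_back;
};

int main() {
    const bool backprop = true;                         // GGML_OP_ROPE_BACK path
    const rope_pc_tail pc = { {0, 0, 0, 0}, backprop }; // bool -> 0/1
    printf("is_back = %u\n", pc.is_back);               // prints 1
    return 0;
}
```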
@@ -7117,6 +7120,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_MUL_MAT:
    case GGML_OP_MUL_MAT_ID:
     case GGML_OP_ARGSORT:
@@ -7170,6 +7174,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
@@ -7285,7 +7290,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
         break;
     case GGML_OP_ROPE:
-        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+
+        break;
+    case GGML_OP_ROPE_BACK:
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);
 
         break;
     case GGML_OP_ARGSORT:
@@ -7399,6 +7408,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
     case GGML_OP_RESHAPE:
     case GGML_OP_VIEW:
     case GGML_OP_PERMUTE:
@@ -8301,6 +8311,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_REPEAT:
             return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -8847,7 +8858,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         }
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
         tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_ROPE) {
+    } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
         const int n_dims = ((int32_t *) tensor->op_params)[1];
         const int mode = ((int32_t *) tensor->op_params)[2];
         //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
@@ -8860,9 +8871,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const float beta_slow = ((float *) tensor->op_params)[10];
         if (mode & GGML_ROPE_TYPE_MROPE) {
             int32_t *sections = ((int32_t *) tensor->op_params) + 11;
-            tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_multi_back(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
         } else {
-            tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_ext_back(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
         }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
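For result checking, the CPU reference clone is built with the `_back` variants of the public API, which take exactly the same parameters as their forward counterparts (as the calls above rely on). A hedged sketch of invoking one directly; the parameter values are common defaults, not ones mandated by this diff:

```cpp
#include "ggml.h"

// Build the RoPE backward node for an incoming gradient dy at positions pos.
// Signature mirrors ggml_rope_ext; context setup is elided.
static ggml_tensor * rope_grad(ggml_context * ctx, ggml_tensor * dy,
                               ggml_tensor * pos, int n_dims, int mode,
                               int n_ctx_orig, float freq_base, float freq_scale) {
    return ggml_rope_ext_back(ctx, dy, pos, /*freq_factors=*/nullptr,
                              n_dims, mode, n_ctx_orig, freq_base, freq_scale,
                              /*ext_factor=*/0.0f, /*attn_factor=*/1.0f,
                              /*beta_fast=*/32.0f, /*beta_slow=*/1.0f);
}
```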