Skip to content

Commit 8f5dedc

Browse files
vulkan: implement GGML_OP_ROPE_BACK
1 parent 0f2bbe6 commit 8f5dedc

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ struct vk_op_rope_push_constants {
500500
uint32_t s1;
501501
uint32_t s2;
502502
int32_t sections[4];
503+
uint32_t is_back;
503504
};
504505

505506
struct vk_op_soft_max_push_constants {
@@ -5316,6 +5317,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53165317
}
53175318
return nullptr;
53185319
case GGML_OP_ROPE:
5320+
case GGML_OP_ROPE_BACK:
53195321
{
53205322
const int mode = ((const int32_t *) dst->op_params)[2];
53215323
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
@@ -5644,6 +5646,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
56445646
} break;
56455647
case GGML_OP_DIAG_MASK_INF:
56465648
case GGML_OP_ROPE:
5649+
case GGML_OP_ROPE_BACK:
56475650
elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
56485651
break;
56495652
case GGML_OP_GET_ROWS:
@@ -5737,7 +5740,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
57375740

57385741
ggml_vk_sync_buffers(subctx);
57395742
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
5740-
} else if (op == GGML_OP_ROPE) {
5743+
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
57415744
// Empty src2 is possible in rope, but the shader needs a buffer
57425745
vk_subbuffer subbuf_z;
57435746
if (use_src2) {
@@ -6176,7 +6179,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
61766179
}, dryrun);
61776180
}
61786181

6179-
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
6182+
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
61806183
const int n_dims = ((int32_t *) dst->op_params)[1];
61816184
const int mode = ((int32_t *) dst->op_params)[2];
61826185
// const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -6204,7 +6207,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
62046207
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
62056208
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
62066209
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
6207-
sections[0], sections[1], sections[2], sections[3],
6210+
sections[0], sections[1], sections[2], sections[3], backprop
62086211
}, dryrun);
62096212
}
62106213

@@ -7117,6 +7120,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71177120
case GGML_OP_DIAG_MASK_INF:
71187121
case GGML_OP_SOFT_MAX:
71197122
case GGML_OP_ROPE:
7123+
case GGML_OP_ROPE_BACK:
71207124
case GGML_OP_MUL_MAT:
71217125
case GGML_OP_MUL_MAT_ID:
71227126
case GGML_OP_ARGSORT:
@@ -7170,6 +7174,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71707174
case GGML_OP_DIAG_MASK_INF:
71717175
case GGML_OP_SOFT_MAX:
71727176
case GGML_OP_ROPE:
7177+
case GGML_OP_ROPE_BACK:
71737178
case GGML_OP_ARGSORT:
71747179
case GGML_OP_SUM_ROWS:
71757180
case GGML_OP_IM2COL:
@@ -7285,7 +7290,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72857290

72867291
break;
72877292
case GGML_OP_ROPE:
7288-
ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
7293+
ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
7294+
7295+
break;
7296+
case GGML_OP_ROPE_BACK:
7297+
ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);
72897298

72907299
break;
72917300
case GGML_OP_ARGSORT:
@@ -7399,6 +7408,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
73997408
case GGML_OP_DIAG_MASK_INF:
74007409
case GGML_OP_SOFT_MAX:
74017410
case GGML_OP_ROPE:
7411+
case GGML_OP_ROPE_BACK:
74027412
case GGML_OP_RESHAPE:
74037413
case GGML_OP_VIEW:
74047414
case GGML_OP_PERMUTE:
@@ -8301,6 +8311,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
83018311
case GGML_OP_REPEAT:
83028312
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
83038313
case GGML_OP_ROPE:
8314+
case GGML_OP_ROPE_BACK:
83048315
case GGML_OP_NONE:
83058316
case GGML_OP_RESHAPE:
83068317
case GGML_OP_VIEW:
@@ -8847,7 +8858,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
88478858
}
88488859
} else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
88498860
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
8850-
} else if (tensor->op == GGML_OP_ROPE) {
8861+
} else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
88518862
const int n_dims = ((int32_t *) tensor->op_params)[1];
88528863
const int mode = ((int32_t *) tensor->op_params)[2];
88538864
//const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
@@ -8860,9 +8871,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
88608871
const float beta_slow = ((float *) tensor->op_params)[10];
88618872
if (mode & GGML_ROPE_TYPE_MROPE) {
88628873
int32_t *sections = ((int32_t *) tensor->op_params) + 11;
8863-
tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8874+
if (tensor->op == GGML_OP_ROPE) {
8875+
tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8876+
} else {
8877+
tensor_clone = ggml_rope_multi_back(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8878+
}
88648879
} else {
8865-
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8880+
if (tensor->op == GGML_OP_ROPE) {
8881+
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8882+
} else {
8883+
tensor_clone = ggml_rope_ext_back(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8884+
}
88668885
}
88678886
} else if (tensor->op == GGML_OP_UNARY) {
88688887
switch (ggml_get_unary_op(tensor)) {

ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ layout (push_constant) uniform parameter {
2929
uint s1;
3030
uint s2;
3131
int sections[4];
32+
uint is_back;
3233
} p;
3334

3435
float rope_yarn_ramp(const float low, const float high, const uint i0) {
@@ -48,6 +49,10 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
4849
// Get n-d magnitude scaling corrected for interpolation
4950
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
5051
}
52+
// Backpropagation uses inverted rotation
53+
if (p.is_back != 0) {
54+
theta = -theta;
55+
}
5156
cos_theta = cos(theta) * mscale;
5257
sin_theta = sin(theta) * mscale;
5358
}

0 commit comments

Comments
 (0)