Commit 1215ed7

CUDA: Implemented row flattening for non-glm RoPE (ggml-org#2468)
1 parent 2dbf518 commit 1215ed7

1 file changed: +15 -8 lines


ggml-cuda.cu

+15 -8
@@ -3150,7 +3150,8 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 }
 
 // rope == RoPE == rotary positional embedding
-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
+static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
+                                const float p_delta, const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
 
     if (col >= ncols) {
@@ -3160,7 +3161,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int i = row*ncols + col;
 
-    const float theta = p*powf(theta_scale, col/2);
+    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -3764,12 +3765,13 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }
 
-static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
+static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
+                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
     const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
+    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -4465,6 +4467,7 @@ inline void ggml_cuda_op_rope(
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
     const int64_t i01_diff = i01_high - i01_low;
 
     const int n_past = ((int32_t *) dst->op_params)[0];
@@ -4478,17 +4481,18 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
 
-    bool is_glm = mode & 4;
+    const bool is_glm = mode & 4;
 
     // compute
     if (is_glm) {
+        const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
        const float id_p = min(p, n_ctx - 2.f);
        const float block_p = max(p - (n_ctx - 2.f), 0.f);
        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
     } else {
-        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
+        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     }
 
     (void) src1;
@@ -5103,7 +5107,10 @@ void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml
 
 void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results
+
+    const int mode = ((int32_t *) dst->op_params)[2];
+    const bool is_glm = mode & 4;
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
 }
 
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
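
For readers following the change: with flattening, rope_f32 is launched once over all rows of the batch, and each row recovers its sequence position from its row index via p0 + p_delta * (row / p_delta_rows). The snippet below is a minimal CPU sketch of that computation, not part of the commit; rope_rows_ref is a hypothetical name, and the final pair rotation is assumed to follow the standard RoPE form (it is unchanged by this commit and not shown in the hunks above).

// Hypothetical CPU reference for the flattened non-GLM RoPE path (illustration only).
// Parameter meanings mirror rope_f32_cuda above:
//   p0           = (mode & 1) == 0 ? n_past*freq_scale : 0  (scaled starting position)
//   p_delta      = freq_scale                               (position increment per token)
//   p_delta_rows = ne01                                     (rows sharing one position)
//   theta_scale  = powf(freq_base, -2.0f/n_dims)
#include <math.h>

static void rope_rows_ref(const float * x, float * dst, int ncols, int nrows,
                          float p0, float p_delta, int p_delta_rows, float theta_scale) {
    for (int row = 0; row < nrows; ++row) {
        // consecutive blocks of p_delta_rows rows belong to the same token,
        // so integer division recovers the token offset within the flattened batch
        const float p = p0 + p_delta * (row / p_delta_rows);
        for (int col = 0; col < ncols; col += 2) {
            const float theta     = p * powf(theta_scale, col/2);
            const float sin_theta = sinf(theta);
            const float cos_theta = cosf(theta);

            const float x0 = x[row*ncols + col + 0];
            const float x1 = x[row*ncols + col + 1];

            // standard RoPE pair rotation (assumed; the rotation itself lies outside the hunks above)
            dst[row*ncols + col + 0] = x0*cos_theta - x1*sin_theta;
            dst[row*ncols + col + 1] = x0*sin_theta + x1*cos_theta;
        }
    }
}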
