Commit cdaa126

Author: noemotiovon (committed)

[CANN]Fix the ROPE precision issue

Signed-off-by: noemotiovon <[email protected]>

1 parent 9d14689 commit cdaa126

1 file changed: +52 −64 lines changed

ggml/src/ggml-cann/aclnn_ops.cpp (+52 −64)
@@ -145,23 +145,6 @@ static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     GGML_CANN_CALL_ACLNN_OP(Cast, acl_src, cast_data_type, acl_dst);
 }
 
-/**
- * @brief Casts the elements of a tensor to a specified data type using the CANN backend.
- *
- * @details This function performs a type conversion on the elements of the input tensor `acl_src`
- *          and stores the results in the destination tensor `acl_dst`. The conversion type is
- *          determined based on the `dst` tensor's data type.
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_src The source tensor whose elements will be cast.
- * @param acl_dst The destination tensor that will store the casted elements.
- * @param dst The ggml tensor specifying the target data type.
- */
-static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
-                       aclTensor* acl_dst, ggml_tensor* dst) {
-    aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
-}
-
 void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
     GGML_ASSERT(ggml_can_repeat(src, dst));
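With this hunk the `ggml_tensor*`-based wrapper is gone, and every caller passes the target ACL type explicitly. A minimal sketch of the surviving overload and the updated call pattern, inferred from the retained `GGML_CANN_CALL_ACLNN_OP(Cast, ...)` body above (illustrative, not a verbatim copy of the file):

    // Remaining overload: the cast type is now an explicit aclDataType
    // parameter instead of being derived from a ggml tensor.
    static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                           aclTensor* acl_dst, aclDataType cast_data_type) {
        GGML_CANN_CALL_ACLNN_OP(Cast, acl_src, cast_data_type, acl_dst);
    }

    // Call sites spell out the type instead of passing the ggml tensor:
    //   before: aclnn_cast(ctx, acl_src, acl_dst, dst);
    //   after : aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type));

The remaining hunks in ggml_cann_dup, ggml_cann_im2col, ggml_cann_get_rows and ggml_cann_mul_mat_quant below are the mechanical call-site updates for this signature change.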
@@ -768,7 +751,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         if (dst->type == src0->type) {
             cann_copy(ctx, acl_src, acl_dst);
         } else {
-            aclnn_cast(ctx, acl_src, acl_dst, dst);
+            aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
         }
     } else {
         if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
@@ -793,7 +776,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 ggml_type_size(dst->type), src0->ne, src_trans_nb,
                 GGML_MAX_DIMS);
 
-            aclnn_cast(ctx, acl_src, src_trans_tensor, dst);
+            aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
             size_t cpy_size = ggml_nbytes(dst);
             ACL_CHECK(aclrtMemcpyAsync(
                 dst->data, cpy_size, src_trans_buffer, cpy_size,
@@ -815,7 +798,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 ggml_type_size(dst->type), src0->ne, src_trans_nb,
                 GGML_MAX_DIMS);
 
-            aclnn_cast(ctx, acl_src, src_trans_tensor, dst);
+            aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
 
             size_t cpy_size = ggml_nbytes(dst);
             ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src_trans_buffer,
@@ -1159,7 +1142,7 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             tmp_cast_buffer, ggml_cann_type_mapping(dst->type),
             ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb,
             GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
-        aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, dst);
+        aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, ggml_cann_type_mapping(dst->type));
     }
 
     // post-processing
@@ -1734,7 +1717,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             aclTensor* src_trans_tensor = ggml_cann_create_tensor(
                 src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
                 src0->ne, src_trans_nb, GGML_MAX_DIMS);
-            aclnn_cast(ctx, acl_src0, src_trans_tensor, dst);
+            aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
             aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne,
                                src_trans_nb, src1, dst);
             ACL_CHECK(aclDestroyTensor(acl_src0));
@@ -2075,7 +2058,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
         output_buffer, ACL_FLOAT16, output_elem_size, output_cast_ne,
         output_cast_nb, GGML_MAX_DIMS);
     aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-    aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, dst);
+    aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
 
     ACL_CHECK(aclDestroyTensor(acl_output_tensor));
     ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
@@ -2162,7 +2145,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    // theta_scale arange, [0,1,...,ne0/2]
+    // theta_scale arange, [0,1,...,ne00/2 - 1]
     int64_t theta_scale_length = ne00 / 2;
     ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(),
                                                theta_scale_length * sizeof(float_t));
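The corrected comment reflects that the cache holds ne00/2 entries indexed 0 through ne00/2 - 1. As a host-side reference for what that buffer represents, here is a hypothetical sketch, assuming ggml's usual RoPE schedule theta_i = freq_base^(-2i/ne00); the helper name and the host-side framing are illustrative, since the real cache is computed on device in aclnn_cache_init:

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Hypothetical reference for the theta_scale cache: the buffer is first
    // filled with the arange 0, 1, ..., ne00/2 - 1 (ne00/2 entries, hence the
    // corrected comment), and each index i then yields the frequency
    // theta_scale^i, where theta_scale = freq_base^(-2/ne00).
    std::vector<float> theta_scale_cache(int64_t ne00, float freq_base) {
        const float theta_scale = std::pow(freq_base, -2.0f / (float)ne00);
        std::vector<float> cache(ne00 / 2);
        float power = 1.0f;                    // theta_scale^0 for index 0
        for (int64_t i = 0; i < ne00 / 2; i++) {
            cache[i] = power;                  // theta_scale^i
            power *= theta_scale;
        }
        return cache;
    }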
@@ -2291,7 +2274,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: use ascendc
     // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
-    // ggml_tensor* src2 = dst->src[2];  // freq_factors, not used now.
 
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2345,7 +2327,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
-                     theta_scale, freq_scale, attn_factor, is_neox);
+                    theta_scale, freq_scale, attn_factor, is_neox);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
@@ -2522,46 +2504,52 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     return;
 #endif
 
-    // src0 == GGML_TYPE_F16
-    // TODO: optimization this `if` code
-    if (src0->type == GGML_TYPE_F16) {
-        ggml_cann_pool_alloc sin_final_allocator(
-            ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type));
-        ggml_cann_pool_alloc cos_final_allocator(
-            ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type));
-        void* sin_final_buffer = sin_final_allocator.get();
-        void* cos_final_buffer = cos_final_allocator.get();
-
-        int64_t sin_final_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
-        size_t sin_final_nb[GGML_MAX_DIMS];
-        sin_final_nb[0] = ggml_type_size(src0->type);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            sin_final_nb[i] = sin_final_nb[i - 1] * sin_final_ne[i - 1];
+    // ggml_mode = 0 --> aclnn_model = 1
+    int64_t acl_mode = mode == 0 ? 1 : mode;
+
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor,
+                                    acl_sin_reshape_tensor, acl_mode, acl_dst);
+            break;
         }
-        aclTensor* acl_sin_final_tensor = ggml_cann_create_tensor(
-            sin_final_buffer, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), sin_final_ne, sin_final_nb,
-            GGML_MAX_DIMS);
-        aclTensor* acl_cos_final_tensor = ggml_cann_create_tensor(
-            cos_final_buffer, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), sin_final_ne, sin_final_nb,
-            GGML_MAX_DIMS);
-
-        aclnn_cast(ctx, acl_sin_reshape_tensor, acl_sin_final_tensor, dst);
-        aclnn_cast(ctx, acl_cos_reshape_tensor, acl_cos_final_tensor, dst);
-        ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
-        acl_sin_reshape_tensor = acl_sin_final_tensor;
-        acl_cos_reshape_tensor = acl_cos_final_tensor;
-    }
-
-    int acl_mode = mode;
-    if (mode == 0) {
-        acl_mode = 1;
+        case GGML_TYPE_F16: {
+            ggml_cann_pool_alloc src_trans_allocator(
+                ctx.pool(), ggml_nelements(src0) * sizeof(float));
+            void* src_trans_buffer = src_trans_allocator.get();
+            ggml_cann_pool_alloc dst_trans_allocator(
+                ctx.pool(), ggml_nelements(dst) * sizeof(float));
+            void* dst_trans_buffer = dst_trans_allocator.get();
+
+            size_t src_trans_nb[GGML_MAX_DIMS];
+            src_trans_nb[0] = sizeof(float);
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+            }
+
+            aclTensor* acl_src_trans_tensor = ggml_cann_create_tensor(
+                src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb,
+                GGML_MAX_DIMS);
+
+            aclTensor* acl_dst_trans_tensor = ggml_cann_create_tensor(
+                dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb,
+                GGML_MAX_DIMS);
+
+            aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT);
+
+            GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src_trans_tensor, acl_cos_reshape_tensor,
+                                    acl_sin_reshape_tensor, acl_mode, acl_dst_trans_tensor);
+
+            aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16);
+
+            ACL_CHECK(aclDestroyTensor(acl_src_trans_tensor));
+            ACL_CHECK(aclDestroyTensor(acl_dst_trans_tensor));
+            break;
+        }
+        default:
+            GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE");
+            break;
     }
-
-    GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor,
-                            acl_sin_reshape_tensor, acl_mode, acl_dst);
     ACL_CHECK(aclDestroyTensor(acl_src));
     ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
     ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
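This switch is the heart of the precision fix: the old code downcast the F32 sin/cos tables to F16 and ran RotaryPositionEmbedding at half precision for F16 tensors, while the new code upcasts the F16 source to F32, rotates in F32, and casts only the final result back to F16. Below is a rough, self-contained illustration of why the order of rounding matters. It is a simulation, not the backend code: fp16 rounding is approximated by mantissa truncation, only fp16's normal range is handled, and the activation values are arbitrary.

    // rope_fp16_demo.cpp -- illustration only, not the backend code.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Crude fp16 rounding: keep 10 mantissa bits, round half up.
    // Ignores fp16's exponent range, subnormals and inf/nan.
    static float to_fp16(float v) {
        uint32_t b;
        std::memcpy(&b, &v, sizeof b);
        b = (b + (1u << 12)) & ~((1u << 13) - 1u);
        std::memcpy(&v, &b, sizeof v);
        return v;
    }

    int main() {
        const float x0 = to_fp16(0.57f);   // an fp16 activation pair, as in
        const float x1 = to_fp16(-1.21f);  // the F16 ROPE path
        double worst_old = 0.0, worst_new = 0.0;
        for (int p = 0; p < 4096; p++) {
            const float theta = 0.001f * (float)p;
            const double ref  = (double)x0 * std::cos((double)theta)
                              - (double)x1 * std::sin((double)theta);
            // old behaviour: sin/cos downcast to fp16 before the rotation
            const float y_old = to_fp16(x0 * to_fp16(std::cos(theta))
                                      - x1 * to_fp16(std::sin(theta)));
            // new behaviour: rotate in fp32, cast only the result to fp16
            const float y_new = to_fp16(x0 * std::cos(theta)
                                      - x1 * std::sin(theta));
            worst_old = std::max(worst_old, std::fabs((double)y_old - ref));
            worst_new = std::max(worst_new, std::fabs((double)y_new - ref));
        }
        std::printf("max |error|, fp16 sin/cos tables: %.3g\n", worst_old);
        std::printf("max |error|, fp32 rotation:       %.3g\n", worst_new);
        return 0;
    }

The trade-off the commit makes is two extra casts and two temporary F32 buffers per F16 tensor, in exchange for rotation coefficients that are never quantized to half precision.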
