@@ -410,13 +410,17 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
 
 #if !defined(__ARM_FEATURE_DOTPROD)
 
-inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
     const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
     const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
 
     return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
 }
 
+#else
+
+#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
 #endif
 
 #endif
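For context on the shim above: the native `vdotq_s32(acc, a, b)` intrinsic treats each `int8x16_t` input as four groups of four adjacent bytes and adds each group's dot product into the corresponding 32-bit lane of `acc`. The emulated path distributes the sixteen products across lanes differently (widening multiplies followed by pairwise additions), but both variants always agree on the across-lane total, and that total is the only thing the kernels below ever extract, via `vaddvq_s32` or a float reduction at the end. A minimal scalar sketch of the native per-lane semantics; the helper and harness are hypothetical, not part of the patch:

#include <stdint.h>
#include <stdio.h>

/* Scalar model of vdotq_s32: lane i of the result accumulates the dot
 * product of bytes 4*i..4*i+3 of a and b on top of acc[i]. */
static void vdotq_s32_scalar(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int lane = 0; lane < 4; ++lane) {
        for (int k = 0; k < 4; ++k) {
            acc[lane] += (int32_t)a[4*lane + k] * (int32_t)b[4*lane + k];
        }
    }
}

int main(void) {
    int32_t acc[4] = {0, 0, 0, 0};
    int8_t  a[16], b[16];
    for (int k = 0; k < 16; ++k) { a[k] = (int8_t)(k - 8); b[k] = (int8_t)(2*k); }
    vdotq_s32_scalar(acc, a, b);
    printf("%d %d %d %d\n", acc[0], acc[1], acc[2], acc[3]);
    return 0;
}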
@@ -2481,8 +2485,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
@@ -2769,8 +2773,8 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
@@ -2936,11 +2940,11 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3228,11 +3232,11 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@@ -3483,12 +3487,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
         const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
-                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                        ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                        ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
 
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
-                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                        ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                        ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3598,8 +3602,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         // We use this macro instead of a function call because for some reason
         // the code runs 2-3% slower, even if the function is declared inline
 #define MULTIPLY_ACCUM_WITH_SCALE(index)\
-    isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
-    isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+    isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+    isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
 
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
     q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
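On the macro above: its two statements are pasted verbatim at every use site and pick up `isum`, `vzero`, `q2bytes`, `q8bytes`, `aux`, and `is` from the enclosing scope, so no call boundary exists for the compiler to mishandle. A stand-alone sketch of the same pattern with scalar stand-ins for the NEON values; all names here are hypothetical, not from ggml.c:

#include <stdint.h>
#include <stdio.h>

/* Same shape as MULTIPLY_ACCUM_WITH_SCALE: a multi-statement macro that
 * reads and writes variables of the function it is expanded into. */
#define MULTIPLY_ACCUM_WITH_SCALE(index) \
    isum += q2[0] * q8[0] * aux[is + (index)]; \
    isum += q2[1] * q8[1] * aux[is + 1 + (index)];

int main(void) {
    int32_t isum = 0;
    const int32_t is = 0;
    const int32_t q2[2]  = {3, -2};
    const int32_t q8[2]  = {7, 5};
    const int32_t aux[4] = {1, 2, 3, 4};
    MULTIPLY_ACCUM_WITH_SCALE(0);   /* expands in place, no function call */
    MULTIPLY_ACCUM_WITH_SCALE(2);
    printf("isum = %d\n", isum);    /* prints isum = 24 */
    return 0;
}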
@@ -3973,10 +3977,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
         q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
 
-        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
-        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
-        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
-        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
+        isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
+        isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
+        isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
 
         sum += d * (isum1 + isum2);
     }
@@ -4256,10 +4260,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
         q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
 
         scale += 4;
 
@@ -4273,10 +4277,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
         q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
 
         scale += 4;
 
@@ -4757,10 +4761,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
         q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
 
         sum += d * isum;
 
@@ -5109,14 +5113,14 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
 
-            const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
             sumi1 += vaddvq_s32(p1) * scales[2*j+0];
 
             q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
             q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
-            const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
 
             sumi2 += vaddvq_s32(p2) * scales[2*j+1];
         }
@@ -5449,13 +5453,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
 
-        const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+        const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
         const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
 
         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
-        const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
+        const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
         const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
         sumf += d * (sumi1 + sumi2);
@@ -5722,8 +5726,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
             q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
             q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
 
-            sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
-            sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
+            sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+            sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
         }
 
         sumf += d * sumi - dmin * sumi_mins;
@@ -6112,10 +6116,10 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
         q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
 
-        int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
-        int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
-        int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
-        int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+        int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
 
         sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
     }
@@ -6399,10 +6403,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
             q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
             q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
 
-            isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
-                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
-                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
-                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
 
             scale += 4;
 
@@ -6426,10 +6430,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
             q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
             q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
 
-            isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
-                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
-                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
-                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
             scale += 4;
         }
         //sum += isum * d_all * y[i].d;
@@ -6816,10 +6820,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
         q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
 
         sum += isum * d_all * y[i].d;
 
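Taken together, the call sites above can now use `ggml_vdotq_s32` unconditionally: with `__ARM_FEATURE_DOTPROD` it is a plain alias for the hardware intrinsic, without it the NEON emulation kicks in. A hypothetical stand-alone harness (not part of the patch) that exercises both paths plus the chained-accumulator idiom from the q4_0/q4_1/q4_K/q5_K kernels; build once with, e.g., `-march=armv8.2-a+dotprod` and once with plain `-march=armv8-a`:

#include <arm_neon.h>
#include <stdint.h>
#include <stdio.h>

/* Same shim as in the first hunk of the patch. */
#if !defined(__ARM_FEATURE_DOTPROD)
static inline int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
}
#else
#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
#endif

int main(void) {
    int8_t a[32], b[32];
    int32_t ref = 0;
    for (int k = 0; k < 32; ++k) {
        a[k] = (int8_t)(k - 16);
        b[k] = (int8_t)(2*k - 31);
        ref += (int32_t)a[k] * (int32_t)b[k];
    }
    /* Chained accumulation, as in the kernels above: the inner call starts
     * from a zero accumulator, the outer call folds in the high 16 bytes. */
    const int32x4_t acc = ggml_vdotq_s32(
        ggml_vdotq_s32(vdupq_n_s32(0), vld1q_s8(a), vld1q_s8(b)),
        vld1q_s8(a + 16), vld1q_s8(b + 16));
    /* The per-lane layout differs between the two paths, but the horizontal
     * sum must match the scalar reference on both. */
    printf("simd=%d scalar=%d\n", vaddvq_s32(acc), ref);
    return 0;
}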