Skip to content

Commit 453ae05

Browse files
committed
ggml : add ggml_vdotq_s32 alias
ggml-ci
1 parent 9fbda71 commit 453ae05

File tree

1 file changed

+61
-57
lines changed

1 file changed

+61
-57
lines changed

ggml-quants.c

+61-57
Original file line numberDiff line numberDiff line change
@@ -410,13 +410,17 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
410410

411411
#if !defined(__ARM_FEATURE_DOTPROD)
412412

413-
inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
413+
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
414414
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
415415
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
416416

417417
return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
418418
}
419419

420+
#else
421+
422+
#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
423+
420424
#endif
421425

422426
#endif
@@ -2481,8 +2485,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
24812485
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
24822486

24832487
// dot product into int32x4_t
2484-
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
2485-
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
2488+
const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
2489+
const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
24862490

24872491
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
24882492
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
@@ -2769,8 +2773,8 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
27692773
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
27702774

27712775
// dot product into int32x4_t
2772-
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
2773-
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
2776+
const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
2777+
const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
27742778

27752779
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
27762780
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
@@ -2936,11 +2940,11 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
29362940
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
29372941

29382942
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
2939-
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
2940-
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2943+
ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
2944+
ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
29412945
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
2942-
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
2943-
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2946+
ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
2947+
ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
29442948
}
29452949

29462950
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3228,11 +3232,11 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
32283232
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
32293233

32303234
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3231-
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
3232-
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
3235+
ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
3236+
ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
32333237
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3234-
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
3235-
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
3238+
ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
3239+
ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
32363240
}
32373241

32383242
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@@ -3483,12 +3487,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
34833487
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
34843488

34853489
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3486-
vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3487-
vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3490+
ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3491+
ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
34883492

34893493
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3490-
vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3491-
vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
3494+
ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3495+
ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
34923496
}
34933497

34943498
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3598,8 +3602,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
35983602
// We use this macro instead of a function call because for some reason
35993603
// the code runs 2-3% slower, even if the function is declared inline
36003604
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
3601-
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
3602-
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
3605+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
3606+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
36033607

36043608
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
36053609
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
@@ -3973,10 +3977,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
39733977
q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
39743978
q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
39753979

3976-
isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
3977-
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
3978-
isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
3979-
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
3980+
isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
3981+
isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
3982+
isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
3983+
isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
39803984

39813985
sum += d * (isum1 + isum2);
39823986
}
@@ -4256,10 +4260,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
42564260
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
42574261
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
42584262

4259-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
4260-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
4261-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
4262-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
4263+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
4264+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
4265+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
4266+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
42634267

42644268
scale += 4;
42654269

@@ -4273,10 +4277,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
42734277
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
42744278
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
42754279

4276-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
4277-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
4278-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
4279-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
4280+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
4281+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
4282+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
4283+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
42804284

42814285
scale += 4;
42824286

@@ -4757,10 +4761,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
47574761
q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
47584762
q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
47594763

4760-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
4761-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
4762-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
4763-
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
4764+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
4765+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
4766+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
4767+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
47644768

47654769
sum += d * isum;
47664770

@@ -5109,14 +5113,14 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
51095113
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
51105114
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
51115115

5112-
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
5116+
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
51135117
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
51145118

51155119
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
51165120
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
51175121
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
51185122

5119-
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
5123+
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
51205124

51215125
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
51225126
}
@@ -5449,13 +5453,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
54495453
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
54505454
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
54515455

5452-
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
5456+
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
54535457
const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
54545458

54555459
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
54565460
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
54575461

5458-
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
5462+
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
54595463
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
54605464

54615465
sumf += d * (sumi1 + sumi2);
@@ -5722,8 +5726,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
57225726
q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
57235727
q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
57245728

5725-
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
5726-
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
5729+
sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
5730+
sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
57275731
}
57285732

57295733
sumf += d * sumi - dmin * sumi_mins;
@@ -6112,10 +6116,10 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
61126116
q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
61136117
q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
61146118

6115-
int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
6116-
int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
6117-
int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
6118-
int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
6119+
int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
6120+
int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
6121+
int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
6122+
int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
61196123

61206124
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
61216125
}
@@ -6399,10 +6403,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
63996403
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
64006404
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
64016405

6402-
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
6403-
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
6404-
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
6405-
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
6406+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
6407+
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
6408+
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
6409+
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
64066410

64076411
scale += 4;
64086412

@@ -6426,10 +6430,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
64266430
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
64276431
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
64286432

6429-
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
6430-
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
6431-
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
6432-
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
6433+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
6434+
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
6435+
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
6436+
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
64336437
scale += 4;
64346438
}
64356439
//sum += isum * d_all * y[i].d;
@@ -6816,10 +6820,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
68166820
q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
68176821
q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
68186822

6819-
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
6820-
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
6821-
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
6822-
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
6823+
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
6824+
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
6825+
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
6826+
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
68236827

68246828
sum += isum * d_all * y[i].d;
68256829

0 commit comments

Comments
 (0)