Skip to content

Commit 4b1c3c9

Browse files
authored
llamafile : use 64-bit integers in sgemm (ggml-org#6928)
1 parent bbe3c6e commit 4b1c3c9

File tree

2 files changed

+87
-89
lines changed

2 files changed

+87
-89
lines changed

sgemm.cpp

+83-87
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
#pragma GCC diagnostic ignored "-Wignored-attributes"
5151

5252
#include "sgemm.h"
53-
#include <algorithm>
5453
#include "ggml-impl.h"
5554
#include "ggml-quants.h"
5655

@@ -243,23 +242,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
243242
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
244243
class tinyBLAS {
245244
public:
246-
tinyBLAS(int k,
247-
const TA *A, int lda,
248-
const TB *B, int ldb,
249-
TC *C, int ldc,
245+
tinyBLAS(int64_t k,
246+
const TA *A, int64_t lda,
247+
const TB *B, int64_t ldb,
248+
TC *C, int64_t ldc,
250249
int ith, int nth)
251250
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
252251
}
253252

254-
void matmul(int m, int n, int task) {
253+
void matmul(int64_t m, int64_t n, int task) {
255254
if (task == GGML_TASK_TYPE_COMPUTE)
256255
mnpack(0, m, 0, n);
257256
}
258257

259258
private:
260-
NOINLINE void mnpack(int m0, int m, int n0, int n) {
261-
int mc, nc, mp, np;
262-
switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
259+
NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
260+
int64_t mc, nc, mp, np;
261+
switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
263262
#if VECTOR_REGISTERS == 32
264263
case 0x55:
265264
mc = 5;
@@ -409,38 +408,38 @@ class tinyBLAS {
409408
}
410409

411410
template <int RM, int RN>
412-
NOINLINE void gemm(int m0, int m, int n0, int n) {
413-
int ytiles = (m - m0) / RM;
414-
int xtiles = (n - n0) / RN;
415-
int tiles = xtiles * ytiles;
416-
int duty = (tiles + nth - 1) / nth;
417-
int start = duty * ith;
418-
int end = start + duty;
411+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
412+
int64_t ytiles = (m - m0) / RM;
413+
int64_t xtiles = (n - n0) / RN;
414+
int64_t tiles = xtiles * ytiles;
415+
int64_t duty = (tiles + nth - 1) / nth;
416+
int64_t start = duty * ith;
417+
int64_t end = start + duty;
419418
if (end > tiles)
420419
end = tiles;
421-
for (int job = start; job < end; ++job) {
422-
int ii = m0 + job / xtiles * RM;
423-
int jj = n0 + job % xtiles * RN;
420+
for (int64_t job = start; job < end; ++job) {
421+
int64_t ii = m0 + job / xtiles * RM;
422+
int64_t jj = n0 + job % xtiles * RN;
424423
D Cv[RN][RM] = {};
425-
for (int l = 0; l < k; l += KN)
426-
for (int j = 0; j < RN; ++j)
427-
for (int i = 0; i < RM; ++i)
424+
for (int64_t l = 0; l < k; l += KN)
425+
for (int64_t j = 0; j < RN; ++j)
426+
for (int64_t i = 0; i < RM; ++i)
428427
Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
429428
load<V>(B + ldb * (jj + j) + l),
430429
Cv[j][i]);
431-
for (int j = 0; j < RN; ++j)
432-
for (int i = 0; i < RM; ++i)
430+
for (int64_t j = 0; j < RN; ++j)
431+
for (int64_t i = 0; i < RM; ++i)
433432
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
434433
}
435434
}
436435

437436
const TA *const A;
438437
const TB *const B;
439438
TC *const C;
440-
const int k;
441-
const int lda;
442-
const int ldb;
443-
const int ldc;
439+
const int64_t k;
440+
const int64_t lda;
441+
const int64_t ldb;
442+
const int64_t ldc;
444443
const int ith;
445444
const int nth;
446445
};
@@ -452,23 +451,23 @@ class tinyBLAS {
452451
template <typename TA>
453452
class tinyBLAS_Q0_ARM {
454453
public:
455-
tinyBLAS_Q0_ARM(int k,
456-
const TA *A, int lda,
457-
const block_q8_0 *B, int ldb,
458-
float *C, int ldc,
454+
tinyBLAS_Q0_ARM(int64_t k,
455+
const TA *A, int64_t lda,
456+
const block_q8_0 *B, int64_t ldb,
457+
float *C, int64_t ldc,
459458
int ith, int nth)
460459
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
461460
}
462461

463-
void matmul(int m, int n, int task) {
462+
void matmul(int64_t m, int64_t n, int task) {
464463
if (task == GGML_TASK_TYPE_COMPUTE)
465464
mnpack(0, m, 0, n);
466465
}
467466

468467
private:
469-
NOINLINE void mnpack(int m0, int m, int n0, int n) {
470-
int mc, nc, mp, np;
471-
switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) {
468+
NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
469+
int64_t mc, nc, mp, np;
470+
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
472471
case 0x33:
473472
mc = 3;
474473
nc = 3;
@@ -524,22 +523,22 @@ class tinyBLAS_Q0_ARM {
524523
}
525524

526525
template <int RM, int RN>
527-
NOINLINE void gemm(int m0, int m, int n0, int n) {
528-
int ytiles = (m - m0) / RM;
529-
int xtiles = (n - n0) / RN;
530-
int tiles = xtiles * ytiles;
531-
int duty = (tiles + nth - 1) / nth;
532-
int start = duty * ith;
533-
int end = start + duty;
526+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
527+
int64_t ytiles = (m - m0) / RM;
528+
int64_t xtiles = (n - n0) / RN;
529+
int64_t tiles = xtiles * ytiles;
530+
int64_t duty = (tiles + nth - 1) / nth;
531+
int64_t start = duty * ith;
532+
int64_t end = start + duty;
534533
if (end > tiles)
535534
end = tiles;
536-
for (int job = start; job < end; ++job) {
537-
int ii = m0 + job / xtiles * RM;
538-
int jj = n0 + job % xtiles * RN;
535+
for (int64_t job = start; job < end; ++job) {
536+
int64_t ii = m0 + job / xtiles * RM;
537+
int64_t jj = n0 + job % xtiles * RN;
539538
float32x4_t Cv[RN][RM] = {};
540-
for (int l = 0; l < k; ++l)
541-
for (int j = 0; j < RN; ++j)
542-
for (int i = 0; i < RM; ++i)
539+
for (int64_t l = 0; l < k; ++l)
540+
for (int64_t j = 0; j < RN; ++j)
541+
for (int64_t i = 0; i < RM; ++i)
543542
Cv[j][i] = vmlaq_n_f32(Cv[j][i],
544543
vcvtq_f32_s32(vdotq_s32(
545544
vdotq_s32(vdupq_n_s32(0),
@@ -549,8 +548,8 @@ class tinyBLAS_Q0_ARM {
549548
load_hi(B + ldb * (jj + j) + l))),
550549
unhalf(A[lda * (ii + i) + l].d) *
551550
unhalf(B[ldb * (jj + j) + l].d));
552-
for (int j = 0; j < RN; ++j)
553-
for (int i = 0; i < RM; ++i)
551+
for (int64_t j = 0; j < RN; ++j)
552+
for (int64_t i = 0; i < RM; ++i)
554553
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
555554
}
556555
}
@@ -577,10 +576,10 @@ class tinyBLAS_Q0_ARM {
577576
const TA *const A;
578577
const block_q8_0 *const B;
579578
float *const C;
580-
const int k;
581-
const int lda;
582-
const int ldb;
583-
const int ldc;
579+
const int64_t k;
580+
const int64_t lda;
581+
const int64_t ldb;
582+
const int64_t ldc;
584583
const int ith;
585584
const int nth;
586585
};
@@ -590,23 +589,23 @@ class tinyBLAS_Q0_ARM {
590589
template <typename TA, typename TB, typename TC>
591590
class tinyBLAS_Q0_AVX2 {
592591
public:
593-
tinyBLAS_Q0_AVX2(int k,
594-
const TA *A, int lda,
595-
const TB *B, int ldb,
596-
TC *C, int ldc,
592+
tinyBLAS_Q0_AVX2(int64_t k,
593+
const TA *A, int64_t lda,
594+
const TB *B, int64_t ldb,
595+
TC *C, int64_t ldc,
597596
int ith, int nth)
598597
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
599598
}
600599

601-
void matmul(int m, int n, int task) {
600+
void matmul(int64_t m, int64_t n, int task) {
602601
if (task == GGML_TASK_TYPE_COMPUTE)
603602
mnpack(0, m, 0, n);
604603
}
605604

606605
private:
607-
void mnpack(int m0, int m, int n0, int n) {
608-
int mc, nc, mp, np;
609-
switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
606+
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
607+
int64_t mc, nc, mp, np;
608+
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
610609
#if VECTOR_REGISTERS == 32
611610
case 0x44:
612611
mc = 4;
@@ -714,31 +713,31 @@ class tinyBLAS_Q0_AVX2 {
714713
}
715714

716715
template <int RM, int RN>
717-
NOINLINE void gemm(int m0, int m, int n0, int n) {
718-
int ytiles = (m - m0) / RM;
719-
int xtiles = (n - n0) / RN;
720-
int tiles = xtiles * ytiles;
721-
int duty = (tiles + nth - 1) / nth;
722-
int start = duty * ith;
723-
int end = start + duty;
716+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
717+
int64_t ytiles = (m - m0) / RM;
718+
int64_t xtiles = (n - n0) / RN;
719+
int64_t tiles = xtiles * ytiles;
720+
int64_t duty = (tiles + nth - 1) / nth;
721+
int64_t start = duty * ith;
722+
int64_t end = start + duty;
724723
if (end > tiles)
725724
end = tiles;
726-
for (int job = start; job < end; ++job) {
727-
int ii = m0 + job / xtiles * RM;
728-
int jj = n0 + job % xtiles * RN;
725+
for (int64_t job = start; job < end; ++job) {
726+
int64_t ii = m0 + job / xtiles * RM;
727+
int64_t jj = n0 + job % xtiles * RN;
729728
__m256 Cv[RN][RM] = {};
730-
for (int l = 0; l < k; ++l)
731-
for (int j = 0; j < RN; ++j)
732-
for (int i = 0; i < RM; ++i)
729+
for (int64_t l = 0; l < k; ++l)
730+
for (int64_t j = 0; j < RN; ++j)
731+
for (int64_t i = 0; i < RM; ++i)
733732
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
734733
unhalf(B[ldb * (jj + j) + l].d)),
735734
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
736735
load(A + lda * (ii + i) + l)),
737736
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
738737
load(A + lda * (ii + i) + l))),
739738
Cv[j][i]);
740-
for (int j = 0; j < RN; ++j)
741-
for (int i = 0; i < RM; ++i)
739+
for (int64_t j = 0; j < RN; ++j)
740+
for (int64_t i = 0; i < RM; ++i)
742741
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
743742
}
744743
}
@@ -771,10 +770,10 @@ class tinyBLAS_Q0_AVX2 {
771770
const TA *const A;
772771
const TB *const B;
773772
TC *const C;
774-
const int k;
775-
const int lda;
776-
const int ldb;
777-
const int ldc;
773+
const int64_t k;
774+
const int64_t lda;
775+
const int64_t ldb;
776+
const int64_t ldc;
778777
const int ith;
779778
const int nth;
780779
};
@@ -813,8 +812,8 @@ class tinyBLAS_Q0_AVX2 {
813812
* @param Ctype is GGML data type of `C`
814813
* @return true if this function was able to service the matmul request
815814
*/
816-
bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
817-
int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
815+
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
816+
int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
818817

819818
assert(m >= 0);
820819
assert(n >= 0);
@@ -824,9 +823,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
824823
assert(ldc >= m);
825824
assert(nth > 0);
826825
assert(ith < nth);
827-
assert(1ll * lda * m <= 0x7fffffff);
828-
assert(1ll * ldb * n <= 0x7fffffff);
829-
assert(1ll * ldc * n <= 0x7fffffff);
830826

831827
if (Ctype != GGML_TYPE_F32)
832828
return false;

sgemm.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
#pragma once
2+
#include <stdint.h>
23
#include <stdbool.h>
34
#ifdef __cplusplus
45
extern "C" {
56
#endif
67

7-
bool llamafile_sgemm(int, int, int, const void *, int, const void *, int,
8-
void *, int, int, int, int, int, int, int);
8+
bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
9+
const void *, int64_t, void *, int64_t, int, int,
10+
int, int, int, int);
911

1012
#ifdef __cplusplus
1113
}

0 commit comments

Comments
 (0)