5050#pragma GCC diagnostic ignored "-Wignored-attributes"
5151
5252#include " sgemm.h"
53- #include < algorithm>
5453#include " ggml-impl.h"
5554#include " ggml-quants.h"
5655
@@ -243,23 +242,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
243242template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
244243class tinyBLAS {
245244 public:
// Constructor: copies the three operand pointers (A, B, C), the value k, the
// three leading dimensions (lda, ldb, ldc) and the thread index / thread count
// (ith, nth) into the same-named members. It performs no other work.
// Diff note: the '-' lines are the previous signature taking `int`; the '+'
// lines widen k, lda, ldb and ldc to int64_t. ith and nth remain `int`.
246- tinyBLAS (int k,
247- const TA *A, int lda,
248- const TB *B, int ldb,
249- TC *C, int ldc,
245+ tinyBLAS (int64_t k,
246+ const TA *A, int64_t lda,
247+ const TB *B, int64_t ldb,
248+ TC *C, int64_t ldc,
250249 int ith, int nth)
251250 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
252251 }
253252
// Public entry point. Calls mnpack(0, m, 0, n) only when `task` equals
// GGML_TASK_TYPE_COMPUTE; for any other task value the call does nothing.
// Diff note: m and n are widened from int ('-' line) to int64_t ('+' line).
254- void matmul (int m, int n, int task) {
253+ void matmul (int64_t m, int64_t n, int task) {
255254 if (task == GGML_TASK_TYPE_COMPUTE)
256255 mnpack (0 , m, 0 , n);
257256 }
258257
259258 private:
260- NOINLINE void mnpack (int m0, int m, int n0, int n) {
261- int mc, nc, mp, np;
262- switch ((std::min (m - m0, 5 ) << 4 ) | std::min (n - n0, 5 )) {
259+ NOINLINE void mnpack (int64_t m0, int64_t m, int64_t n0, int64_t n) {
260+ int64_t mc, nc, mp, np;
261+ switch ((MIN (m - m0, 5 ) << 4 ) | MIN (n - n0, 5 )) {
263262#if VECTOR_REGISTERS == 32
264263 case 0x55 :
265264 mc = 5 ;
@@ -409,38 +408,38 @@ class tinyBLAS {
409408 }
410409
// Computes this thread's share of the RM x RN output tiles covering the
// m-range [m0, m) and n-range [n0, n).
//
// The region is split into ytiles * xtiles tiles using truncating integer
// division, so any remainder smaller than RM or RN is not touched here
// (presumably handled by the caller, mnpack -- confirm there).
// Each thread gets a contiguous run of duty = ceil(tiles / nth) job indices
// starting at duty * ith, clamped so it never exceeds `tiles`.
//
// Diff note: '-' lines are the previous `int` version, '+' lines the int64_t
// replacement; the arithmetic is otherwise unchanged.
411410 template <int RM, int RN>
412- NOINLINE void gemm (int m0, int m, int n0, int n) {
413- int ytiles = (m - m0) / RM;
414- int xtiles = (n - n0) / RN;
415- int tiles = xtiles * ytiles;
416- int duty = (tiles + nth - 1 ) / nth;
417- int start = duty * ith;
418- int end = start + duty;
411+ NOINLINE void gemm (int64_t m0, int64_t m, int64_t n0, int64_t n) {
412+ int64_t ytiles = (m - m0) / RM;
413+ int64_t xtiles = (n - n0) / RN;
414+ int64_t tiles = xtiles * ytiles;
415+ int64_t duty = (tiles + nth - 1 ) / nth;
416+ int64_t start = duty * ith;
417+ int64_t end = start + duty;
// The last thread(s) may be assigned fewer than `duty` jobs (possibly none).
419418 if (end > tiles)
420419 end = tiles;
// job -> tile origin: job / xtiles selects the tile along m, job % xtiles the
// tile along n.
421- for (int job = start; job < end; ++job) {
422- int ii = m0 + job / xtiles * RM;
423- int jj = n0 + job % xtiles * RN;
420+ for (int64_t job = start; job < end; ++job) {
421+ int64_t ii = m0 + job / xtiles * RM;
422+ int64_t jj = n0 + job % xtiles * RN;
// One zero-initialized accumulator of type D per output element of the tile.
424423 D Cv[RN][RM] = {};
// Walk l from 0 up to k in steps of KN, feeding each accumulator through madd
// with load<V> of A at lda * (ii + i) + l and of B at ldb * (jj + j) + l.
425- for (int l = 0 ; l < k; l += KN)
426- for (int j = 0 ; j < RN; ++j)
427- for (int i = 0 ; i < RM; ++i)
424+ for (int64_t l = 0 ; l < k; l += KN)
425+ for (int64_t j = 0 ; j < RN; ++j)
426+ for (int64_t i = 0 ; i < RM; ++i)
428427 Cv[j][i] = madd (load<V>(A + lda * (ii + i) + l),
429428 load<V>(B + ldb * (jj + j) + l),
430429 Cv[j][i]);
// Reduce each accumulator with hsum and store it at C[ldc * (jj + j) + (ii + i)].
431- for (int j = 0 ; j < RN; ++j)
432- for (int i = 0 ; i < RM; ++i)
430+ for (int64_t j = 0 ; j < RN; ++j)
431+ for (int64_t i = 0 ; i < RM; ++i)
433432 C[ldc * (jj + j) + (ii + i)] = hsum (Cv[j][i]);
434433 }
435434 }
436435
437436 const TA *const A;
438437 const TB *const B;
439438 TC *const C;
440- const int k;
441- const int lda;
442- const int ldb;
443- const int ldc;
439+ const int64_t k;
440+ const int64_t lda;
441+ const int64_t ldb;
442+ const int64_t ldc;
444443 const int ith;
445444 const int nth;
446445};
@@ -452,23 +451,23 @@ class tinyBLAS {
452451template <typename TA>
453452class tinyBLAS_Q0_ARM {
454453 public:
// Constructor: copies the operand pointers (A of type TA, B of type
// block_q8_0, C of type float), the value k, the leading dimensions
// (lda, ldb, ldc) and the thread index / thread count (ith, nth) into the
// same-named members. It performs no other work.
// Diff note: the '-' lines are the previous signature taking `int`; the '+'
// lines widen k, lda, ldb and ldc to int64_t. ith and nth remain `int`.
455- tinyBLAS_Q0_ARM (int k,
456- const TA *A, int lda,
457- const block_q8_0 *B, int ldb,
458- float *C, int ldc,
454+ tinyBLAS_Q0_ARM (int64_t k,
455+ const TA *A, int64_t lda,
456+ const block_q8_0 *B, int64_t ldb,
457+ float *C, int64_t ldc,
459458 int ith, int nth)
460459 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
461460 }
462461
// Public entry point. Calls mnpack(0, m, 0, n) only when `task` equals
// GGML_TASK_TYPE_COMPUTE; for any other task value the call does nothing.
// Diff note: m and n are widened from int ('-' line) to int64_t ('+' line).
463- void matmul (int m, int n, int task) {
462+ void matmul (int64_t m, int64_t n, int task) {
464463 if (task == GGML_TASK_TYPE_COMPUTE)
465464 mnpack (0 , m, 0 , n);
466465 }
467466
468467 private:
469- NOINLINE void mnpack (int m0, int m, int n0, int n) {
470- int mc, nc, mp, np;
471- switch ((std::min (m - m0, 3 ) << 4 ) | std::min (n - n0, 3 )) {
468+ NOINLINE void mnpack (int64_t m0, int64_t m, int64_t n0, int64_t n) {
469+ int64_t mc, nc, mp, np;
470+ switch ((MIN (m - m0, 3 ) << 4 ) | MIN (n - n0, 3ll )) {
472471 case 0x33 :
473472 mc = 3 ;
474473 nc = 3 ;
@@ -524,22 +523,22 @@ class tinyBLAS_Q0_ARM {
524523 }
525524
526525 template <int RM, int RN>
527- NOINLINE void gemm (int m0, int m, int n0, int n) {
528- int ytiles = (m - m0) / RM;
529- int xtiles = (n - n0) / RN;
530- int tiles = xtiles * ytiles;
531- int duty = (tiles + nth - 1 ) / nth;
532- int start = duty * ith;
533- int end = start + duty;
526+ NOINLINE void gemm (int64_t m0, int64_t m, int64_t n0, int64_t n) {
527+ int64_t ytiles = (m - m0) / RM;
528+ int64_t xtiles = (n - n0) / RN;
529+ int64_t tiles = xtiles * ytiles;
530+ int64_t duty = (tiles + nth - 1 ) / nth;
531+ int64_t start = duty * ith;
532+ int64_t end = start + duty;
534533 if (end > tiles)
535534 end = tiles;
536- for (int job = start; job < end; ++job) {
537- int ii = m0 + job / xtiles * RM;
538- int jj = n0 + job % xtiles * RN;
535+ for (int64_t job = start; job < end; ++job) {
536+ int64_t ii = m0 + job / xtiles * RM;
537+ int64_t jj = n0 + job % xtiles * RN;
539538 float32x4_t Cv[RN][RM] = {};
540- for (int l = 0 ; l < k; ++l)
541- for (int j = 0 ; j < RN; ++j)
542- for (int i = 0 ; i < RM; ++i)
539+ for (int64_t l = 0 ; l < k; ++l)
540+ for (int64_t j = 0 ; j < RN; ++j)
541+ for (int64_t i = 0 ; i < RM; ++i)
543542 Cv[j][i] = vmlaq_n_f32 (Cv[j][i],
544543 vcvtq_f32_s32 (vdotq_s32 (
545544 vdotq_s32 (vdupq_n_s32 (0 ),
@@ -549,8 +548,8 @@ class tinyBLAS_Q0_ARM {
549548 load_hi (B + ldb * (jj + j) + l))),
550549 unhalf (A[lda * (ii + i) + l].d ) *
551550 unhalf (B[ldb * (jj + j) + l].d ));
552- for (int j = 0 ; j < RN; ++j)
553- for (int i = 0 ; i < RM; ++i)
551+ for (int64_t j = 0 ; j < RN; ++j)
552+ for (int64_t i = 0 ; i < RM; ++i)
554553 C[ldc * (jj + j) + (ii + i)] = hsum (Cv[j][i]);
555554 }
556555 }
@@ -577,10 +576,10 @@ class tinyBLAS_Q0_ARM {
577576 const TA *const A;
578577 const block_q8_0 *const B;
579578 float *const C;
580- const int k;
581- const int lda;
582- const int ldb;
583- const int ldc;
579+ const int64_t k;
580+ const int64_t lda;
581+ const int64_t ldb;
582+ const int64_t ldc;
584583 const int ith;
585584 const int nth;
586585};
@@ -590,23 +589,23 @@ class tinyBLAS_Q0_ARM {
590589template <typename TA, typename TB, typename TC>
591590class tinyBLAS_Q0_AVX2 {
592591 public:
// Constructor: copies the three operand pointers (A, B, C), the value k, the
// three leading dimensions (lda, ldb, ldc) and the thread index / thread count
// (ith, nth) into the same-named members. It performs no other work.
// Diff note: the '-' lines are the previous signature taking `int`; the '+'
// lines widen k, lda, ldb and ldc to int64_t. ith and nth remain `int`.
593- tinyBLAS_Q0_AVX2 (int k,
594- const TA *A, int lda,
595- const TB *B, int ldb,
596- TC *C, int ldc,
592+ tinyBLAS_Q0_AVX2 (int64_t k,
593+ const TA *A, int64_t lda,
594+ const TB *B, int64_t ldb,
595+ TC *C, int64_t ldc,
597596 int ith, int nth)
598597 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
599598 }
600599
// Public entry point. Calls mnpack(0, m, 0, n) only when `task` equals
// GGML_TASK_TYPE_COMPUTE; for any other task value the call does nothing.
// Diff note: m and n are widened from int ('-' line) to int64_t ('+' line).
601- void matmul (int m, int n, int task) {
600+ void matmul (int64_t m, int64_t n, int task) {
602601 if (task == GGML_TASK_TYPE_COMPUTE)
603602 mnpack (0 , m, 0 , n);
604603 }
605604
606605 private:
607- void mnpack (int m0, int m, int n0, int n) {
608- int mc, nc, mp, np;
609- switch ((std::min (m - m0, 4 ) << 4 ) | std::min (n - n0, 4 )) {
606+ void mnpack (int64_t m0, int64_t m, int64_t n0, int64_t n) {
607+ int64_t mc, nc, mp, np;
608+ switch ((MIN (m - m0, 4 ) << 4 ) | MIN (n - n0, 4 )) {
610609#if VECTOR_REGISTERS == 32
611610 case 0x44 :
612611 mc = 4 ;
@@ -714,31 +713,31 @@ class tinyBLAS_Q0_AVX2 {
714713 }
715714
// Computes this thread's share of the RM x RN output tiles covering the
// m-range [m0, m) and n-range [n0, n).
//
// The region is split into ytiles * xtiles tiles using truncating integer
// division, so any remainder smaller than RM or RN is not touched here
// (presumably handled by the caller, mnpack -- confirm there).
// Each thread gets a contiguous run of duty = ceil(tiles / nth) job indices
// starting at duty * ith, clamped so it never exceeds `tiles`.
//
// Diff note: '-' lines are the previous `int` version, '+' lines the int64_t
// replacement; the arithmetic is otherwise unchanged.
716715 template <int RM, int RN>
717- NOINLINE void gemm (int m0, int m, int n0, int n) {
718- int ytiles = (m - m0) / RM;
719- int xtiles = (n - n0) / RN;
720- int tiles = xtiles * ytiles;
721- int duty = (tiles + nth - 1 ) / nth;
722- int start = duty * ith;
723- int end = start + duty;
716+ NOINLINE void gemm (int64_t m0, int64_t m, int64_t n0, int64_t n) {
717+ int64_t ytiles = (m - m0) / RM;
718+ int64_t xtiles = (n - n0) / RN;
719+ int64_t tiles = xtiles * ytiles;
720+ int64_t duty = (tiles + nth - 1 ) / nth;
721+ int64_t start = duty * ith;
722+ int64_t end = start + duty;
// The last thread(s) may be assigned fewer than `duty` jobs (possibly none).
724723 if (end > tiles)
725724 end = tiles;
// job -> tile origin: job / xtiles selects the tile along m, job % xtiles the
// tile along n.
726- for (int job = start; job < end; ++job) {
727- int ii = m0 + job / xtiles * RM;
728- int jj = n0 + job % xtiles * RN;
725+ for (int64_t job = start; job < end; ++job) {
726+ int64_t ii = m0 + job / xtiles * RM;
727+ int64_t jj = n0 + job % xtiles * RN;
// One zero-initialized __m256 accumulator per output element of the tile.
729728 __m256 Cv[RN][RM] = {};
// For every block index l in [0, k): broadcast the product of the two blocks'
// `d` fields (converted with unhalf) and pass it to madd together with updot
// of the two blocks' loaded contents and the running accumulator.
// The first updot operand is load(A) signed by itself and the second is
// load(B) signed by load(A). Per the Intel definition of _mm256_sign_epi8 that
// is |a| and b carrying a's sign -- presumably so updot can use an
// unsigned-by-signed dot product; confirm against updot's definition.
730- for (int l = 0 ; l < k; ++l)
731- for (int j = 0 ; j < RN; ++j)
732- for (int i = 0 ; i < RM; ++i)
729+ for (int64_t l = 0 ; l < k; ++l)
730+ for (int64_t j = 0 ; j < RN; ++j)
731+ for (int64_t i = 0 ; i < RM; ++i)
733732 Cv[j][i] = madd (_mm256_set1_ps (unhalf (A[lda * (ii + i) + l].d ) *
734733 unhalf (B[ldb * (jj + j) + l].d )),
735734 updot (_mm256_sign_epi8 (load (A + lda * (ii + i) + l),
736735 load (A + lda * (ii + i) + l)),
737736 _mm256_sign_epi8 (load (B + ldb * (jj + j) + l),
738737 load (A + lda * (ii + i) + l))),
739738 Cv[j][i]);
// Reduce each accumulator with hsum and store it at C[ldc * (jj + j) + (ii + i)].
740- for (int j = 0 ; j < RN; ++j)
741- for (int i = 0 ; i < RM; ++i)
739+ for (int64_t j = 0 ; j < RN; ++j)
740+ for (int64_t i = 0 ; i < RM; ++i)
742741 C[ldc * (jj + j) + (ii + i)] = hsum (Cv[j][i]);
743742 }
744743 }
@@ -771,10 +770,10 @@ class tinyBLAS_Q0_AVX2 {
771770 const TA *const A;
772771 const TB *const B;
773772 TC *const C;
774- const int k;
775- const int lda;
776- const int ldb;
777- const int ldc;
773+ const int64_t k;
774+ const int64_t lda;
775+ const int64_t ldb;
776+ const int64_t ldc;
778777 const int ith;
779778 const int nth;
780779};
@@ -813,8 +812,8 @@ class tinyBLAS_Q0_AVX2 {
813812 * @param Ctype is GGML data type of `C`
814813 * @return true if this function was able to service the matmul request
815814 */
816- bool llamafile_sgemm (int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
817- int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
815+ bool llamafile_sgemm (int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
816+ int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
818817
819818 assert (m >= 0 );
820819 assert (n >= 0 );
@@ -824,9 +823,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
824823 assert (ldc >= m);
825824 assert (nth > 0 );
826825 assert (ith < nth);
827- assert (1ll * lda * m <= 0x7fffffff );
828- assert (1ll * ldb * n <= 0x7fffffff );
829- assert (1ll * ldc * n <= 0x7fffffff );
830826
831827 if (Ctype != GGML_TYPE_F32)
832828 return false ;
0 commit comments