50
50
#pragma GCC diagnostic ignored "-Wignored-attributes"
51
51
52
52
#include " sgemm.h"
53
- #include < algorithm>
54
53
#include " ggml-impl.h"
55
54
#include " ggml-quants.h"
56
55
@@ -243,23 +242,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
243
242
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
244
243
class tinyBLAS {
245
244
public:
246
- tinyBLAS (int k,
247
- const TA *A, int lda,
248
- const TB *B, int ldb,
249
- TC *C, int ldc,
245
+ tinyBLAS (int64_t k,
246
+ const TA *A, int64_t lda,
247
+ const TB *B, int64_t ldb,
248
+ TC *C, int64_t ldc,
250
249
int ith, int nth)
251
250
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
252
251
}
253
252
254
- void matmul (int m, int n, int task) {
253
+ void matmul (int64_t m, int64_t n, int task) {
255
254
if (task == GGML_TASK_TYPE_COMPUTE)
256
255
mnpack (0 , m, 0 , n);
257
256
}
258
257
259
258
private:
260
- NOINLINE void mnpack (int m0, int m, int n0, int n) {
261
- int mc, nc, mp, np;
262
- switch ((std::min (m - m0, 5 ) << 4 ) | std::min (n - n0, 5 )) {
259
+ NOINLINE void mnpack (int64_t m0, int64_t m, int64_t n0, int64_t n) {
260
+ int64_t mc, nc, mp, np;
261
+ switch ((MIN (m - m0, 5 ) << 4 ) | MIN (n - n0, 5 )) {
263
262
#if VECTOR_REGISTERS == 32
264
263
case 0x55 :
265
264
mc = 5 ;
@@ -409,38 +408,38 @@ class tinyBLAS {
409
408
}
410
409
411
410
template <int RM, int RN>
412
- NOINLINE void gemm (int m0, int m, int n0, int n) {
413
- int ytiles = (m - m0) / RM;
414
- int xtiles = (n - n0) / RN;
415
- int tiles = xtiles * ytiles;
416
- int duty = (tiles + nth - 1 ) / nth;
417
- int start = duty * ith;
418
- int end = start + duty;
411
+ NOINLINE void gemm (int64_t m0, int64_t m, int64_t n0, int64_t n) {
412
+ int64_t ytiles = (m - m0) / RM;
413
+ int64_t xtiles = (n - n0) / RN;
414
+ int64_t tiles = xtiles * ytiles;
415
+ int64_t duty = (tiles + nth - 1 ) / nth;
416
+ int64_t start = duty * ith;
417
+ int64_t end = start + duty;
419
418
if (end > tiles)
420
419
end = tiles;
421
- for (int job = start; job < end; ++job) {
422
- int ii = m0 + job / xtiles * RM;
423
- int jj = n0 + job % xtiles * RN;
420
+ for (int64_t job = start; job < end; ++job) {
421
+ int64_t ii = m0 + job / xtiles * RM;
422
+ int64_t jj = n0 + job % xtiles * RN;
424
423
D Cv[RN][RM] = {};
425
- for (int l = 0 ; l < k; l += KN)
426
- for (int j = 0 ; j < RN; ++j)
427
- for (int i = 0 ; i < RM; ++i)
424
+ for (int64_t l = 0 ; l < k; l += KN)
425
+ for (int64_t j = 0 ; j < RN; ++j)
426
+ for (int64_t i = 0 ; i < RM; ++i)
428
427
Cv[j][i] = madd (load<V>(A + lda * (ii + i) + l),
429
428
load<V>(B + ldb * (jj + j) + l),
430
429
Cv[j][i]);
431
- for (int j = 0 ; j < RN; ++j)
432
- for (int i = 0 ; i < RM; ++i)
430
+ for (int64_t j = 0 ; j < RN; ++j)
431
+ for (int64_t i = 0 ; i < RM; ++i)
433
432
C[ldc * (jj + j) + (ii + i)] = hsum (Cv[j][i]);
434
433
}
435
434
}
436
435
437
436
const TA *const A;
438
437
const TB *const B;
439
438
TC *const C;
440
- const int k;
441
- const int lda;
442
- const int ldb;
443
- const int ldc;
439
+ const int64_t k;
440
+ const int64_t lda;
441
+ const int64_t ldb;
442
+ const int64_t ldc;
444
443
const int ith;
445
444
const int nth;
446
445
};
@@ -452,23 +451,23 @@ class tinyBLAS {
452
451
template <typename TA>
453
452
class tinyBLAS_Q0_ARM {
454
453
public:
455
- tinyBLAS_Q0_ARM (int k,
456
- const TA *A, int lda,
457
- const block_q8_0 *B, int ldb,
458
- float *C, int ldc,
454
+ tinyBLAS_Q0_ARM (int64_t k,
455
+ const TA *A, int64_t lda,
456
+ const block_q8_0 *B, int64_t ldb,
457
+ float *C, int64_t ldc,
459
458
int ith, int nth)
460
459
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
461
460
}
462
461
463
- void matmul (int m, int n, int task) {
462
+ void matmul (int64_t m, int64_t n, int task) {
464
463
if (task == GGML_TASK_TYPE_COMPUTE)
465
464
mnpack (0 , m, 0 , n);
466
465
}
467
466
468
467
private:
469
- NOINLINE void mnpack (int m0, int m, int n0, int n) {
470
- int mc, nc, mp, np;
471
- switch ((std::min (m - m0, 3 ) << 4 ) | std::min (n - n0, 3 )) {
468
+ NOINLINE void mnpack (int64_t m0, int64_t m, int64_t n0, int64_t n) {
469
+ int64_t mc, nc, mp, np;
470
+ switch ((MIN (m - m0, 3 ) << 4 ) | MIN (n - n0, 3ll )) {
472
471
case 0x33 :
473
472
mc = 3 ;
474
473
nc = 3 ;
@@ -524,22 +523,22 @@ class tinyBLAS_Q0_ARM {
524
523
}
525
524
526
525
template <int RM, int RN>
527
- NOINLINE void gemm (int m0, int m, int n0, int n) {
528
- int ytiles = (m - m0) / RM;
529
- int xtiles = (n - n0) / RN;
530
- int tiles = xtiles * ytiles;
531
- int duty = (tiles + nth - 1 ) / nth;
532
- int start = duty * ith;
533
- int end = start + duty;
526
+ NOINLINE void gemm (int64_t m0, int64_t m, int64_t n0, int64_t n) {
527
+ int64_t ytiles = (m - m0) / RM;
528
+ int64_t xtiles = (n - n0) / RN;
529
+ int64_t tiles = xtiles * ytiles;
530
+ int64_t duty = (tiles + nth - 1 ) / nth;
531
+ int64_t start = duty * ith;
532
+ int64_t end = start + duty;
534
533
if (end > tiles)
535
534
end = tiles;
536
- for (int job = start; job < end; ++job) {
537
- int ii = m0 + job / xtiles * RM;
538
- int jj = n0 + job % xtiles * RN;
535
+ for (int64_t job = start; job < end; ++job) {
536
+ int64_t ii = m0 + job / xtiles * RM;
537
+ int64_t jj = n0 + job % xtiles * RN;
539
538
float32x4_t Cv[RN][RM] = {};
540
- for (int l = 0 ; l < k; ++l)
541
- for (int j = 0 ; j < RN; ++j)
542
- for (int i = 0 ; i < RM; ++i)
539
+ for (int64_t l = 0 ; l < k; ++l)
540
+ for (int64_t j = 0 ; j < RN; ++j)
541
+ for (int64_t i = 0 ; i < RM; ++i)
543
542
Cv[j][i] = vmlaq_n_f32 (Cv[j][i],
544
543
vcvtq_f32_s32 (vdotq_s32 (
545
544
vdotq_s32 (vdupq_n_s32 (0 ),
@@ -549,8 +548,8 @@ class tinyBLAS_Q0_ARM {
549
548
load_hi (B + ldb * (jj + j) + l))),
550
549
unhalf (A[lda * (ii + i) + l].d ) *
551
550
unhalf (B[ldb * (jj + j) + l].d ));
552
- for (int j = 0 ; j < RN; ++j)
553
- for (int i = 0 ; i < RM; ++i)
551
+ for (int64_t j = 0 ; j < RN; ++j)
552
+ for (int64_t i = 0 ; i < RM; ++i)
554
553
C[ldc * (jj + j) + (ii + i)] = hsum (Cv[j][i]);
555
554
}
556
555
}
@@ -577,10 +576,10 @@ class tinyBLAS_Q0_ARM {
577
576
const TA *const A;
578
577
const block_q8_0 *const B;
579
578
float *const C;
580
- const int k;
581
- const int lda;
582
- const int ldb;
583
- const int ldc;
579
+ const int64_t k;
580
+ const int64_t lda;
581
+ const int64_t ldb;
582
+ const int64_t ldc;
584
583
const int ith;
585
584
const int nth;
586
585
};
@@ -590,23 +589,23 @@ class tinyBLAS_Q0_ARM {
590
589
template <typename TA, typename TB, typename TC>
591
590
class tinyBLAS_Q0_AVX2 {
592
591
public:
593
- tinyBLAS_Q0_AVX2 (int k,
594
- const TA *A, int lda,
595
- const TB *B, int ldb,
596
- TC *C, int ldc,
592
+ tinyBLAS_Q0_AVX2 (int64_t k,
593
+ const TA *A, int64_t lda,
594
+ const TB *B, int64_t ldb,
595
+ TC *C, int64_t ldc,
597
596
int ith, int nth)
598
597
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
599
598
}
600
599
601
- void matmul (int m, int n, int task) {
600
+ void matmul (int64_t m, int64_t n, int task) {
602
601
if (task == GGML_TASK_TYPE_COMPUTE)
603
602
mnpack (0 , m, 0 , n);
604
603
}
605
604
606
605
private:
607
- void mnpack (int m0, int m, int n0, int n) {
608
- int mc, nc, mp, np;
609
- switch ((std::min (m - m0, 4 ) << 4 ) | std::min (n - n0, 4 )) {
606
+ void mnpack (int64_t m0, int64_t m, int64_t n0, int64_t n) {
607
+ int64_t mc, nc, mp, np;
608
+ switch ((MIN (m - m0, 4 ) << 4 ) | MIN (n - n0, 4 )) {
610
609
#if VECTOR_REGISTERS == 32
611
610
case 0x44 :
612
611
mc = 4 ;
@@ -714,31 +713,31 @@ class tinyBLAS_Q0_AVX2 {
714
713
}
715
714
716
715
template <int RM, int RN>
717
- NOINLINE void gemm (int m0, int m, int n0, int n) {
718
- int ytiles = (m - m0) / RM;
719
- int xtiles = (n - n0) / RN;
720
- int tiles = xtiles * ytiles;
721
- int duty = (tiles + nth - 1 ) / nth;
722
- int start = duty * ith;
723
- int end = start + duty;
716
+ NOINLINE void gemm (int64_t m0, int64_t m, int64_t n0, int64_t n) {
717
+ int64_t ytiles = (m - m0) / RM;
718
+ int64_t xtiles = (n - n0) / RN;
719
+ int64_t tiles = xtiles * ytiles;
720
+ int64_t duty = (tiles + nth - 1 ) / nth;
721
+ int64_t start = duty * ith;
722
+ int64_t end = start + duty;
724
723
if (end > tiles)
725
724
end = tiles;
726
- for (int job = start; job < end; ++job) {
727
- int ii = m0 + job / xtiles * RM;
728
- int jj = n0 + job % xtiles * RN;
725
+ for (int64_t job = start; job < end; ++job) {
726
+ int64_t ii = m0 + job / xtiles * RM;
727
+ int64_t jj = n0 + job % xtiles * RN;
729
728
__m256 Cv[RN][RM] = {};
730
- for (int l = 0 ; l < k; ++l)
731
- for (int j = 0 ; j < RN; ++j)
732
- for (int i = 0 ; i < RM; ++i)
729
+ for (int64_t l = 0 ; l < k; ++l)
730
+ for (int64_t j = 0 ; j < RN; ++j)
731
+ for (int64_t i = 0 ; i < RM; ++i)
733
732
Cv[j][i] = madd (_mm256_set1_ps (unhalf (A[lda * (ii + i) + l].d ) *
734
733
unhalf (B[ldb * (jj + j) + l].d )),
735
734
updot (_mm256_sign_epi8 (load (A + lda * (ii + i) + l),
736
735
load (A + lda * (ii + i) + l)),
737
736
_mm256_sign_epi8 (load (B + ldb * (jj + j) + l),
738
737
load (A + lda * (ii + i) + l))),
739
738
Cv[j][i]);
740
- for (int j = 0 ; j < RN; ++j)
741
- for (int i = 0 ; i < RM; ++i)
739
+ for (int64_t j = 0 ; j < RN; ++j)
740
+ for (int64_t i = 0 ; i < RM; ++i)
742
741
C[ldc * (jj + j) + (ii + i)] = hsum (Cv[j][i]);
743
742
}
744
743
}
@@ -771,10 +770,10 @@ class tinyBLAS_Q0_AVX2 {
771
770
const TA *const A;
772
771
const TB *const B;
773
772
TC *const C;
774
- const int k;
775
- const int lda;
776
- const int ldb;
777
- const int ldc;
773
+ const int64_t k;
774
+ const int64_t lda;
775
+ const int64_t ldb;
776
+ const int64_t ldc;
778
777
const int ith;
779
778
const int nth;
780
779
};
@@ -813,8 +812,8 @@ class tinyBLAS_Q0_AVX2 {
813
812
* @param Ctype is GGML data type of `C`
814
813
* @return true if this function was able to service the matmul request
815
814
*/
816
- bool llamafile_sgemm (int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
817
- int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
815
+ bool llamafile_sgemm (int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
816
+ int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
818
817
819
818
assert (m >= 0 );
820
819
assert (n >= 0 );
@@ -824,9 +823,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
824
823
assert (ldc >= m);
825
824
assert (nth > 0 );
826
825
assert (ith < nth);
827
- assert (1ll * lda * m <= 0x7fffffff );
828
- assert (1ll * ldb * n <= 0x7fffffff );
829
- assert (1ll * ldc * n <= 0x7fffffff );
830
826
831
827
if (Ctype != GGML_TYPE_F32)
832
828
return false ;
0 commit comments