
Commit 0001ec0

amritahs-ibm authored and ggerganov committed
llamafile : ppc64le GEMV forwarding for FP32. (llama/12594)
This patch enables usage of MMA when one of the dimensions of the matrix (i.e. either M or N) is 1. This is useful for token generation, where N < 2. The concept of 'GEMV Forwarding' is used: when one of the matrices has a single row/column, its elements are broadcast instead of being prepacked by the packing routine. This change results in a 5% - 15% improvement in total speed (i.e. all tokens / total time) across various batch sizes, compared with the corresponding dot product implementation. The patch was tested with FP32 models of Meta-Llama-3-8B, Mistral-7B and Llama-2-7B-chat-hf on an IBM POWER10 machine.

Signed-off-by: Amrita H S <[email protected]>
1 parent 5bad2e5 commit 0001ec0

File tree

1 file changed: +16 -2 lines changed


ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 16 additions & 2 deletions
@@ -2680,13 +2680,25 @@ class tinyBLAS_PPC {
        __builtin_mma_xxsetaccz(&acc_0);
        vec_t vec_A[4] {0}, vec_B[4] = {0};
        for (int l=0; l<k; l+=4) {
-           if (RN >= 4 && RM == 1) {
+           /* 'GEMV Forwarding' concept is used in first two conditional loops.
+            * when one of the matrix has a single row/column, the elements are
+            * broadcasted, instead of using packing routine to prepack the
+            * matrix elements.
+            */
+           if (RM == 1) {
                TA* a = const_cast<TA*>(A+(ii)*lda+l);
-               packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+               packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
                vec_A[0] = (vec_t)vec_xl(0,a);
                vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
                vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
                vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+           } else if (RN == 1) {
+               packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+               TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+               vec_B[0] = (vec_t)vec_xl(0,b);
+               vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
+               vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
+               vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
            } else {
                packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
                packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
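
The following is a minimal, scalar C++ sketch of the 'GEMV Forwarding' idea used in the kernel change above; it is not the actual MMA code. The helper names pack_tile and gemm_tile are illustrative stand-ins for packTranspose and the tinyBLAS_PPC micro-kernel: when the tile has a single row (RM == 1) or a single column (RN == 1), that operand is read directly instead of going through the packing routine, while the other operand is still packed.

// gemv_forwarding_sketch.cpp -- illustrative only, not part of ggml.
#include <cstdio>
#include <vector>

// Stand-in for packTranspose: copy a rows x k block of a row-major matrix
// (leading dimension ld) into a contiguous rows x k buffer.
static void pack_tile(const float *src, int ld, int rows, int k, float *dst) {
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < k; ++c)
            dst[r * k + c] = src[r * ld + c];
}

// C(RM x RN) += A(RM x K) * B(RN x K)^T, all matrices row-major.
static void gemm_tile(const float *A, int lda, const float *B, int ldb,
                      float *C, int ldc, int RM, int RN, int K) {
    std::vector<float> packA(RM * K), packB(RN * K);
    if (RM == 1) {
        // GEMV forwarding: A is a single row, so it is used as-is (the MMA
        // kernel broadcasts it with vec_xl/vec_splats); only B is packed.
        pack_tile(B, ldb, RN, K, packB.data());
        for (int j = 0; j < RN; ++j)
            for (int l = 0; l < K; ++l)
                C[j] += A[l] * packB[j * K + l];
    } else if (RN == 1) {
        // GEMV forwarding: B is a single row, used as-is; only A is packed.
        pack_tile(A, lda, RM, K, packA.data());
        for (int i = 0; i < RM; ++i)
            for (int l = 0; l < K; ++l)
                C[i * ldc] += packA[i * K + l] * B[l];
    } else {
        // General case: both operands go through the packing routine.
        pack_tile(A, lda, RM, K, packA.data());
        pack_tile(B, ldb, RN, K, packB.data());
        for (int i = 0; i < RM; ++i)
            for (int j = 0; j < RN; ++j)
                for (int l = 0; l < K; ++l)
                    C[i * ldc + j] += packA[i * K + l] * packB[j * K + l];
    }
}

int main() {
    const int K = 4;
    const float A[K]     = {1, 2, 3, 4};          // single row of A  (RM == 1)
    const float B[2 * K] = {1, 0, 0, 0,           // two rows of B    (RN == 2)
                            0, 1, 0, 0};
    float C[2] = {0, 0};
    gemm_tile(A, K, B, K, C, 2, /*RM=*/1, /*RN=*/2, K);
    std::printf("C = [%g, %g]\n", C[0], C[1]);    // expected: [1, 2]
    return 0;
}

In the real kernel the broadcast is done with vec_xl/vec_splats into MMA accumulators, but the effect is the same: the single-row operand skips packTranspose.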
@@ -2790,8 +2802,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
    assert(params->ith < params->nth);

    // only enable sgemm for prompt processing
+#if !defined(__MMA__)
    if (n < 2)
        return false;
+#endif

    if (Ctype != GGML_TYPE_F32)
        return false;
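
The second hunk compiles out the n < 2 early return on MMA builds, so token generation (n == 1) also reaches the MMA path instead of the caller's dot-product fallback. A hypothetical sketch of the effect is below; sgemm_accepts and the built_with_mma flag are illustrative stand-ins, not ggml API, and the preprocessor check is modeled as a runtime flag only so both build variants fit in one sketch.

// mma_gate_sketch.cpp -- illustrative stand-in for the compile-time gate above.
#include <cstdio>

// Mimics llamafile_sgemm's early return: returning false tells the caller to
// fall back to its default (dot-product based) matrix multiplication path.
static bool sgemm_accepts(long n, bool built_with_mma) {
    // In the real code this is #if !defined(__MMA__) around the n check.
    if (!built_with_mma && n < 2)
        return false;   // non-MMA builds: sgemm only for prompt processing
    return true;        // MMA builds: n == 1 is handled via GEMV forwarding
}

int main() {
    std::printf("n=1, MMA build:     %s\n", sgemm_accepts(1, true)  ? "sgemm" : "fallback");
    std::printf("n=1, non-MMA build: %s\n", sgemm_accepts(1, false) ? "sgemm" : "fallback");
    std::printf("n=8, any build:     %s\n", sgemm_accepts(8, false) ? "sgemm" : "fallback");
    return 0;
}

Note that the real function still returns false for other reasons (e.g. Ctype != GGML_TYPE_F32); the sketch only models the n check.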
