diff --git a/nntrainer/tensor/blas_neon.cpp b/nntrainer/tensor/blas_neon.cpp index 82840b1fdd..63280b7c50 100644 --- a/nntrainer/tensor/blas_neon.cpp +++ b/nntrainer/tensor/blas_neon.cpp @@ -13,6 +13,7 @@ */ #include +#include #include #include namespace nntrainer::neon { @@ -700,194 +701,13 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows, } void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, - uint32_t rows, uint32_t cols, float alpha, + uint32_t M, uint32_t K, float alpha, float beta) { - if (cols % 8 == 0 && cols >= OMP_THRESHOLD) - sgemv_transpose_neon_fp16_multithread(A, X, Y, rows, cols, alpha, beta); - else - sgemv_transpose_neon_fp16_general(A, X, Y, rows, cols, alpha, beta); -} - -void sgemv_transpose_neon_fp16_multithread(const __fp16 *A, const __fp16 *X, - __fp16 *Y, uint32_t rows, - uint32_t cols, float alpha, - float beta) { - float Y32[cols]; - const int batch = 20; + float Y32[K]; unsigned int idx = 0; size_t NEON_NUM_THREADS = get_gemv_num_threads(); - for (unsigned int idx = 0; idx < cols; idx += 8) { - float32x4_t y0_3_32 = vcvt_f32_f16(vld1_f16(&Y[idx])); - float32x4_t y4_7_32 = vcvt_f32_f16(vld1_f16(&Y[idx + 4])); - - y0_3_32 = vmulq_n_f32(y0_3_32, beta); - y4_7_32 = vmulq_n_f32(y4_7_32, beta); - - vst1q_f32(&Y32[idx], y0_3_32); - vst1q_f32(&Y32[idx + 4], y4_7_32); - } - if (rows / 16 >= batch) { - for (unsigned int i = 0; i < rows; i += 16) { - __fp16 x = alpha * (X[i]); - __fp16 x2 = alpha * (X[i + 1]); - __fp16 x3 = alpha * (X[i + 2]); - __fp16 x4 = alpha * (X[i + 3]); - __fp16 x5 = alpha * (X[i + 4]); - __fp16 x6 = alpha * (X[i + 5]); - __fp16 x7 = alpha * (X[i + 6]); - __fp16 x8 = alpha * (X[i + 7]); - __fp16 x9 = alpha * (X[i + 8]); - __fp16 x10 = alpha * (X[i + 9]); - __fp16 x11 = alpha * (X[i + 10]); - __fp16 x12 = alpha * (X[i + 11]); - __fp16 x13 = alpha * (X[i + 12]); - __fp16 x14 = alpha * (X[i + 13]); - __fp16 x15 = alpha * (X[i + 14]); - __fp16 x16 = alpha * (X[i + 15]); -#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) - for (int idx = 0; idx < cols; idx += 8) { - float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * cols + idx]), x2); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 2) * cols + idx]), x3); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 3) * cols + idx]), x4); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 4) * cols + idx]), x5); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 5) * cols + idx]), x6); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 6) * cols + idx]), x7); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 7) * cols + idx]), x8); - - float16x8_t w2vec0_7_f16 = - vmulq_n_f16(vld1q_f16(&A[(i + 8) * cols + idx]), x9); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 9) * cols + idx]), x10); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 10) * cols + idx]), x11); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 11) * cols + idx]), x12); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 12) * cols + idx]), x13); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 13) * cols + idx]), x14); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 14) * cols + idx]), x15); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 15) * cols + idx]), x16); - - float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), - 
vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); - y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); - float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), - vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); - y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); - - vst1q_f32(&Y32[idx], y0_3); - vst1q_f32(&Y32[idx + 4], y4_7); - } - } - } else if (rows / 8 >= batch) { - for (unsigned int i = 0; i < rows; i += 8) { - __fp16 x = alpha * X[i]; - __fp16 x2 = alpha * X[i + 1]; - __fp16 x3 = alpha * X[i + 2]; - __fp16 x4 = alpha * X[i + 3]; - __fp16 x5 = alpha * X[i + 4]; - __fp16 x6 = alpha * X[i + 5]; - __fp16 x7 = alpha * X[i + 6]; - __fp16 x8 = alpha * X[i + 7]; - -#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) - for (unsigned int idx = 0; idx < cols; idx += 8) { - float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * cols + idx]), x2); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 2) * cols + idx]), x3); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 3) * cols + idx]), x4); - - float16x8_t w2vec0_7_f16 = - vmulq_n_f16(vld1q_f16(&A[(i + 4) * cols + idx]), x5); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 5) * cols + idx]), x6); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 6) * cols + idx]), x7); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 7) * cols + idx]), x8); - - float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), - vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); - y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); - float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), - vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); - y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); - - vst1q_f32(&Y32[idx], y0_3); - vst1q_f32(&Y32[idx + 4], y4_7); - } - } - } else if (rows / 4 >= batch) { - for (unsigned int i = 0; i < rows; i += 4) { - __fp16 x = alpha * (X[i]); - __fp16 x2 = alpha * (X[i + 1]); - __fp16 x3 = alpha * (X[i + 2]); - __fp16 x4 = alpha * (X[i + 3]); - -#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) - for (unsigned int idx = 0; idx < cols; idx += 8) { - float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * cols + idx]), x2); - float16x8_t w2vec0_7_f16 = - vmulq_n_f16(vld1q_f16(&A[(i + 2) * cols + idx]), x3); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 3) * cols + idx]), x4); - - float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), - vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); - y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); - float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), - vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); - y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); - - vst1q_f32(&Y32[idx], y0_3); - vst1q_f32(&Y32[idx + 4], y4_7); - } - } - } else { - for (unsigned int i = 0; i < rows; ++i) { - __fp16 x = alpha * (X[i]); - idx = 0; -#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) - for (unsigned int idx = 0; idx < cols; idx += 8) { - float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x); - float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), - vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); - float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), - vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); - - vst1q_f32(&Y32[idx], y0_3); - vst1q_f32(&Y32[idx + 4], y4_7); - } 
- } - } - scopy_neon_fp32_to_fp16(cols, Y32, Y); - return; -} - -void sgemv_transpose_neon_fp16_general(const __fp16 *A, const __fp16 *X, - __fp16 *Y, uint32_t rows, uint32_t cols, - float alpha, float beta) { - float Y32[cols]; - const int batch = 20; - unsigned int idx = 0; - for (; cols - idx >= 8; idx += 8) { + for (; K - idx >= 8; idx += 8) { float32x4_t y0_3_32 = vcvt_f32_f16(vld1_f16(&Y[idx])); float32x4_t y4_7_32 = vcvt_f32_f16(vld1_f16(&Y[idx + 4])); @@ -897,403 +717,255 @@ void sgemv_transpose_neon_fp16_general(const __fp16 *A, const __fp16 *X, vst1q_f32(&Y32[idx], y0_3_32); vst1q_f32(&Y32[idx + 4], y4_7_32); } - for (; cols - idx >= 4; idx += 4) { + for (; K - idx >= 4; idx += 4) { float32x4_t y0_3_32 = vcvt_f32_f16(vld1_f16(&Y[idx])); y0_3_32 = vmulq_n_f32(y0_3_32, beta); vst1q_f32(&Y32[idx], y0_3_32); } - for (; cols - idx >= 1; idx += 1) { + for (; K - idx >= 1; idx += 1) { Y32[idx] = beta * Y[idx]; } - if (rows % 16 == 0 && rows / 16 >= batch && cols % 4 == 0) { - for (unsigned int i = 0; i < rows; i += 16) { - __fp16 x = alpha * (X[i]); - __fp16 x2 = alpha * (X[i + 1]); - __fp16 x3 = alpha * (X[i + 2]); - __fp16 x4 = alpha * (X[i + 3]); - __fp16 x5 = alpha * (X[i + 4]); - __fp16 x6 = alpha * (X[i + 5]); - __fp16 x7 = alpha * (X[i + 6]); - __fp16 x8 = alpha * (X[i + 7]); - __fp16 x9 = alpha * (X[i + 8]); - __fp16 x10 = alpha * (X[i + 9]); - __fp16 x11 = alpha * (X[i + 10]); - __fp16 x12 = alpha * (X[i + 11]); - __fp16 x13 = alpha * (X[i + 12]); - __fp16 x14 = alpha * (X[i + 13]); - __fp16 x15 = alpha * (X[i + 14]); - __fp16 x16 = alpha * (X[i + 15]); - - idx = 0; - for (; cols - idx >= 8; idx += 8) { - float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * cols + idx]), x2); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 2) * cols + idx]), x3); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 3) * cols + idx]), x4); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 4) * cols + idx]), x5); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 5) * cols + idx]), x6); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 6) * cols + idx]), x7); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 7) * cols + idx]), x8); - - float16x8_t w2vec0_7_f16 = - vmulq_n_f16(vld1q_f16(&A[(i + 8) * cols + idx]), x9); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 9) * cols + idx]), x10); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 10) * cols + idx]), x11); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 11) * cols + idx]), x12); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 12) * cols + idx]), x13); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 13) * cols + idx]), x14); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 14) * cols + idx]), x15); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 15) * cols + idx]), x16); - - float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), - vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); - y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); - float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), - vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); - y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); - - vst1q_f32(&Y32[idx], y0_3); - vst1q_f32(&Y32[idx + 4], y4_7); - } + unsigned int i = 0; + for (; M - i >= 8; i += 8) { + __fp16 x[8]; + vst1q_f16(&x[0], vmulq_n_f16(vld1q_f16(&A[i]), 
alpha)); + for (unsigned int idx = 0; idx < K - 8; idx += 8) { + float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * K + idx]), x[0]); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * K + idx]), x[1]); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 2) * K + idx]), x[2]); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 3) * K + idx]), x[3]); + + float16x8_t w2vec0_7_f16 = + vmulq_n_f16(vld1q_f16(&A[(i + 4) * K + idx]), x[4]); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 5) * K + idx]), x[5]); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 6) * K + idx]), x[6]); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 7) * K + idx]), x[7]); + + float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), + vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); + y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); + float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), + vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); + y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); + + vst1q_f32(&Y32[idx], y0_3); + vst1q_f32(&Y32[idx + 4], y4_7); + } - for (; cols - idx >= 4; idx += 4) { - - float32x4_t y0_3 = vld1q_f32(&Y32[idx]); - - y0_3 = vfmaq_n_f32(y0_3, vcvt_f32_f16(vld1_f16(&A[i * cols + idx])), x); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 1) * cols + idx])), x2); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 2) * cols + idx])), x3); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 3) * cols + idx])), x4); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 4) * cols + idx])), x5); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 5) * cols + idx])), x6); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 6) * cols + idx])), x7); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 7) * cols + idx])), x8); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 8) * cols + idx])), x9); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 9) * cols + idx])), x10); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 10) * cols + idx])), x11); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 11) * cols + idx])), x12); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 12) * cols + idx])), x13); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 13) * cols + idx])), x14); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 14) * cols + idx])), x15); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 15) * cols + idx])), x16); - - vst1q_f32(&Y32[idx], y0_3); - } + if (K % 8 == 0) { + unsigned int idx = K - 8; + float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * K + idx]), x[0]); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * K + idx]), x[1]); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 2) * K + idx]), x[2]); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 3) * K + idx]), x[3]); + + float16x8_t w2vec0_7_f16 = + vmulq_n_f16(vld1q_f16(&A[(i + 4) * K + idx]), x[4]); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 5) * K + idx]), x[5]); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 6) * K + idx]), x[6]); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 7) * K + idx]), x[7]); + + float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), + vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); + y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); + float32x4_t y4_7 = 
vaddq_f32(vld1q_f32(&Y32[idx + 4]), + vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); + y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); + + vst1q_f32(&Y32[idx], y0_3); + vst1q_f32(&Y32[idx + 4], y4_7); + } - if (cols - idx >= 1) { - float y0_3_0[4]; - - float v0[4], v1[4], v2[4], v3[4]; - float v4[4], v5[4], v6[4], v7[4]; - float v8[4], v9[4], v10[4], v11[4]; - float v12[4], v13[4], v14[4], v15[4]; - - unsigned int k = 0; - for (; k < cols - idx; ++k) { - - y0_3_0[k] = Y32[idx + k]; - - v0[k] = A[i * cols + idx + k]; - v1[k] = A[(i + 1) * cols + idx + k]; - v2[k] = A[(i + 2) * cols + idx + k]; - v3[k] = A[(i + 3) * cols + idx + k]; - v4[k] = A[(i + 4) * cols + idx + k]; - v5[k] = A[(i + 5) * cols + idx + k]; - v6[k] = A[(i + 6) * cols + idx + k]; - v7[k] = A[(i + 7) * cols + idx + k]; - v8[k] = A[(i + 8) * cols + idx + k]; - v9[k] = A[(i + 9) * cols + idx + k]; - v10[k] = A[(i + 10) * cols + idx + k]; - v11[k] = A[(i + 11) * cols + idx + k]; - v12[k] = A[(i + 12) * cols + idx + k]; - v13[k] = A[(i + 13) * cols + idx + k]; - v15[k] = A[(i + 15) * cols + idx + k]; - } - for (; k < 4; ++k) { - y0_3_0[k] = 0; + else if (K % 8 != 0) { + unsigned int idx = 8 * (K / 8); - v0[k] = v1[k] = v2[k] = v3[k] = 0; - v4[k] = v5[k] = v6[k] = v7[k] = 0; - v8[k] = v9[k] = v10[k] = v11[k] = 0; - v12[k] = v13[k] = v14[k] = v15[k] = 0; - } + float y0_7[8]; + float v0[8], v1[8], v2[8], v3[8]; + float v4[8], v5[8], v6[8], v7[8]; - float32x4_t y0_3 = vld1q_f32(y0_3_0); - float32x4_t y0_3_tmp = vmovq_n_f32(0); - - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v0), x); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v1), x2); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v2), x3); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v3), x4); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v4), x5); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v5), x6); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v6), x7); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v7), x8); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v8), x9); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v9), x10); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v10), x11); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v11), x12); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v12), x13); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v13), x14); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v14), x15); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v15), x16); - - for (unsigned int k = 0; k < cols - idx; ++k) { - Y32[idx + k] = y0_3[k]; - } - } - } - } else if (rows % 8 == 0 && rows / 8 >= batch) { - for (unsigned int i = 0; i < rows; i += 8) { - __fp16 x = alpha * X[i]; - __fp16 x2 = alpha * X[i + 1]; - __fp16 x3 = alpha * X[i + 2]; - __fp16 x4 = alpha * X[i + 3]; - __fp16 x5 = alpha * X[i + 4]; - __fp16 x6 = alpha * X[i + 5]; - __fp16 x7 = alpha * X[i + 6]; - __fp16 x8 = alpha * X[i + 7]; - - idx = 0; - for (; cols - idx >= 8; idx += 8) { - float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * cols + idx]), x2); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 2) * cols + idx]), x3); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 3) * cols + idx]), x4); - - float16x8_t w2vec0_7_f16 = - vmulq_n_f16(vld1q_f16(&A[(i + 4) * cols + idx]), x5); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 5) * cols + idx]), x6); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 6) * cols + idx]), x7); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 7) * cols + idx]), x8); - - float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), - 
vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); - y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); - float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), - vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); - y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); - - vst1q_f32(&Y32[idx], y0_3); - vst1q_f32(&Y32[idx + 4], y4_7); + unsigned int k = 0; + for (; k < K - idx; ++k) { + y0_7[k] = Y32[idx + k]; + + v0[k] = A[i * K + idx + k]; + v1[k] = A[(i + 1) * K + idx + k]; + v2[k] = A[(i + 2) * K + idx + k]; + v3[k] = A[(i + 3) * K + idx + k]; + v4[k] = A[(i + 4) * K + idx + k]; + v5[k] = A[(i + 5) * K + idx + k]; + v6[k] = A[(i + 6) * K + idx + k]; + v7[k] = A[(i + 7) * K + idx + k]; } - - for (; cols - idx >= 4; idx += 4) { - - float32x4_t y0_3 = vld1q_f32(&Y32[idx]); - y0_3 = vfmaq_n_f32(y0_3, vcvt_f32_f16(vld1_f16(&A[i * cols + idx])), x); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 1) * cols + idx])), x2); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 2) * cols + idx])), x3); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 3) * cols + idx])), x4); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 4) * cols + idx])), x5); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 5) * cols + idx])), x6); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 6) * cols + idx])), x7); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 7) * cols + idx])), x8); - vst1q_f32(&Y32[idx], y0_3); + for (; k < 8; ++k) { + y0_7[k] = 0; + v0[k] = v1[k] = v2[k] = v3[k] = 0; + v4[k] = v5[k] = v6[k] = v7[k] = 0; } - if (cols - idx >= 1) { - float y0_3_0[4]; - - float v0[4], v1[4], v2[4], v3[4]; - float v4[4], v5[4], v6[4], v7[4]; - - unsigned int k = 0; - for (; k < cols - idx; ++k) { - y0_3_0[k] = Y32[idx + k]; - - v0[k] = A[i * cols + idx + k]; - v1[k] = A[(i + 1) * cols + idx + k]; - v2[k] = A[(i + 2) * cols + idx + k]; - v3[k] = A[(i + 3) * cols + idx + k]; - v4[k] = A[(i + 4) * cols + idx + k]; - v5[k] = A[(i + 5) * cols + idx + k]; - v6[k] = A[(i + 6) * cols + idx + k]; - v7[k] = A[(i + 7) * cols + idx + k]; - } - for (; k < 4; ++k) { - y0_3_0[k] = 0; - v0[k] = v1[k] = v2[k] = v3[k] = 0; - v4[k] = v5[k] = v6[k] = v7[k] = 0; - } - - float32x4_t y0_3 = vld1q_f32(y0_3_0); - - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v0), x); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v1), x2); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v2), x3); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v3), x4); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v4), x5); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v5), x6); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v6), x7); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v7), x8); - - for (unsigned int k = 0; k < cols - idx; ++k) { + float32x4_t y0_3 = vld1q_f32(&y0_7[0]); + float32x4_t y4_7 = vld1q_f32(&y0_7[4]); + + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v0[0]), x[0]); + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v1[0]), x[1]); + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v2[0]), x[2]); + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v3[0]), x[3]); + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v4[0]), x[4]); + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v5[0]), x[5]); + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v6[0]), x[6]); + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v7[0]), x[7]); + + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v0[4]), x[0]); + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v1[4]), x[1]); + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v2[4]), x[2]); + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v3[4]), x[3]); + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v4[4]), x[4]); + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v5[4]), x[5]); + y4_7 = 
vfmaq_n_f32(y4_7, vld1q_f32(&v6[4]), x[6]); + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v7[4]), x[7]); + + for (unsigned int k = 0; k < K - idx; ++k) { + if (k < 4) Y32[idx + k] = y0_3[k]; - } + else + Y32[idx + k] = y4_7[k]; } } - } else if (rows % 4 == 0 && rows / 4 >= batch) { - for (unsigned int i = 0; i < rows; i += 4) { - __fp16 x = alpha * (X[i]); - __fp16 x2 = alpha * (X[i + 1]); - __fp16 x3 = alpha * (X[i + 2]); - __fp16 x4 = alpha * (X[i + 3]); - - idx = 0; - for (; cols - idx >= 8; idx += 8) { - float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x); - wvec0_7_f16 = - vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * cols + idx]), x2); - float16x8_t w2vec0_7_f16 = - vmulq_n_f16(vld1q_f16(&A[(i + 2) * cols + idx]), x3); - w2vec0_7_f16 = - vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 3) * cols + idx]), x4); - - float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), - vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); - y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); - float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), - vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); - y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); - - vst1q_f32(&Y32[idx], y0_3); - vst1q_f32(&Y32[idx + 4], y4_7); - } - - for (; cols - idx >= 4; idx += 4) { + } + for (; M - i >= 4; i += 4) { + __fp16 x[4]; + vst1_f16(&x[0], vmul_n_f16(vld1_f16(&A[i]), alpha)); + for (unsigned int idx = 0; idx < K - 8; idx += 8) { + float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * K + idx]), x[0]); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * K + idx]), x[1]); + float16x8_t w2vec0_7_f16 = + vmulq_n_f16(vld1q_f16(&A[(i + 2) * K + idx]), x[2]); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 3) * K + idx]), x[3]); + + float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), + vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); + y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); + float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), + vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); + y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); + + vst1q_f32(&Y32[idx], y0_3); + vst1q_f32(&Y32[idx + 4], y4_7); + } - float32x4_t y0_3 = vld1q_f32(&Y32[idx]); + if (K % 8 == 0) { + unsigned int idx = K - 8; + float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * K + idx]), x[0]); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * K + idx]), x[1]); + float16x8_t w2vec0_7_f16 = + vmulq_n_f16(vld1q_f16(&A[(i + 2) * K + idx]), x[2]); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 3) * K + idx]), x[3]); + + float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), + vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); + y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); + float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), + vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); + y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); + + vst1q_f32(&Y32[idx], y0_3); + vst1q_f32(&Y32[idx + 4], y4_7); + } - y0_3 = vfmaq_n_f32(y0_3, vcvt_f32_f16(vld1_f16(&A[i * cols + idx])), x); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 1) * cols + idx])), x2); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 2) * cols + idx])), x3); - y0_3 = vfmaq_n_f32( - y0_3, vcvt_f32_f16(vld1_f16(&A[(i + 3) * cols + idx])), x4); + else if (K % 8 != 0) { + unsigned int idx = 8 * (K / 8); + float y0_3_0[8]; + float v0[8], v1[8], v2[8], v3[8]; + unsigned int k = 0; + for (; k < K - idx; ++k) { + y0_3_0[k] = Y32[idx + k]; - vst1q_f32(&Y32[idx], 
y0_3); + v0[k] = A[i * K + idx + k]; + v1[k] = A[(i + 1) * K + idx + k]; + v2[k] = A[(i + 2) * K + idx + k]; + v3[k] = A[(i + 3) * K + idx + k]; + } + for (; k < 8; ++k) { + y0_3_0[k] = 0; + v0[k] = v1[k] = v2[k] = v3[k] = 0; } - if (cols - idx >= 1) { - float y0_3_0[4]; - - float v0[4], v1[4], v2[4], v3[4]; - unsigned int k = 0; - for (; k < cols - idx; ++k) { - y0_3_0[k] = Y32[idx + k]; - - v0[k] = A[i * cols + idx + k]; - v1[k] = A[(i + 1) * cols + idx + k]; - v2[k] = A[(i + 2) * cols + idx + k]; - v3[k] = A[(i + 3) * cols + idx + k]; - } - for (; k < 4; ++k) { - y0_3_0[k] = 0; - v0[k] = v1[k] = v2[k] = v3[k] = 0; - } + float32x4_t y0_3 = vld1q_f32(&y0_3_0[0]); + float32x4_t y4_7 = vld1q_f32(&y0_3_0[4]); - float32x4_t y0_3 = vld1q_f32(y0_3_0); + // we can separate mul and accum for faster compute. + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v0[0]), x[0]); + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v1[0]), x[1]); + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v2[0]), x[2]); + y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(&v3[0]), x[3]); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v0), x); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v1), x2); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v2), x3); - y0_3 = vfmaq_n_f32(y0_3, vld1q_f32(v3), x4); + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v0[4]), x[0]); + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v1[4]), x[1]); + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v2[4]), x[2]); + y4_7 = vfmaq_n_f32(y4_7, vld1q_f32(&v3[4]), x[3]); - for (unsigned int k = 0; k < cols - idx; ++k) { + for (unsigned int k = 0; k < K - idx; ++k) { + if (k < 4) Y32[idx + k] = y0_3[k]; - } + else + Y32[idx + k] = y4_7[k]; } } - } else { - for (unsigned int i = 0; i < rows; ++i) { - __fp16 x = alpha * (X[i]); - idx = 0; - for (; cols - idx >= 8; idx += 8) { - - float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x); - float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), - vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); - float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), - vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); - - vst1q_f32(&Y32[idx], y0_3); - vst1q_f32(&Y32[idx + 4], y4_7); - } - - for (; cols - idx >= 4; idx += 4) { + } + for (; i < M; ++i) { + __fp16 x = alpha * (X[i]); + for (unsigned int idx = 0; idx < K - 8; idx += 8) { - float32x4_t y0_3 = vld1q_f32(&Y32[idx]); - y0_3 = vfmaq_n_f32(y0_3, vcvt_f32_f16(vld1_f16(&A[i * cols + idx])), x); + float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * K + idx]), x); + float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), + vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); + float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), + vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); - vst1q_f32(&Y32[idx], y0_3); + vst1q_f32(&Y32[idx], y0_3); + vst1q_f32(&Y32[idx + 4], y4_7); + } + if (K % 8 == 0) { + unsigned int idx = K - 8; + float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * K + idx]), x); + float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), + vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); + float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), + vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); + + vst1q_f32(&Y32[idx], y0_3); + vst1q_f32(&Y32[idx + 4], y4_7); + } else if (K % 8 != 0) { + unsigned int idx = 8 * (K / 8); + float v0[8]; + for (unsigned int j = 0; j < K - idx; ++j) { + v0[j] = A[i * K + idx + j]; + } + for (unsigned int j = K - idx; j < 8; ++j) { + v0[j] = 0; } - if (cols != idx) { - float y0_3[4]; - float wvec0_3[4]; - for (int j = 0; j < cols - idx; ++j) { - y0_3[j] = Y32[idx + j]; - wvec0_3[j] = A[i * cols + idx + j]; - } - for (int j = cols - idx; j < 4; ++j) { - y0_3[j] = 0; - wvec0_3[j] 
= 0; - } - - float32x4_t y0_3_32 = vld1q_f32(y0_3); - y0_3_32 = vfmaq_n_f32(y0_3_32, vld1q_f32(wvec0_3), x); - - for (int j = 0; j < cols - idx; ++j) { - Y32[idx + j] = y0_3_32[j]; - } + for (int j = 0; j < K - idx; ++j) { + Y32[idx + j] = Y32[idx + j] + v0[j] * x; } } } - scopy_neon_fp32_to_fp16(cols, Y32, Y); + scopy_neon_fp32_to_fp16(K, Y32, Y); return; } @@ -1771,18 +1443,9 @@ void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, free(C32); } -void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C32, +void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C, uint32_t M, uint32_t N, uint32_t K, float alpha, float beta) { - if (N % 8 == 0 && N >= OMP_THRESHOLD) - sgemm_neon_fp16_noTrans_multithread(A, B, C32, M, N, K, alpha, beta); - else - sgemm_neon_fp16_noTrans_general(A, B, C32, M, N, K, alpha, beta); -} - -void sgemm_neon_fp16_noTrans_multithread(const __fp16 *A, const __fp16 *B, - float *C, uint32_t M, uint32_t N, - uint32_t K, float alpha, float beta) { size_t NEON_NUM_THREADS = get_gemm_num_threads(); unsigned int k = 0; __fp16 a[16]; @@ -1791,7 +1454,7 @@ void sgemm_neon_fp16_noTrans_multithread(const __fp16 *A, const __fp16 *B, vst1q_f16(&a[0], vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha)); vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha)); #pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) - for (unsigned int n = 0; n < N; n += 8) { + for (unsigned int n = 0; n < N - 8; n += 8) { float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); @@ -1817,16 +1480,8 @@ void sgemm_neon_fp16_noTrans_multithread(const __fp16 *A, const __fp16 *B, vst1q_f32(&C[m * N + n], c0_7_low_32); vst1q_f32(&C[m * N + n + 4], c0_7_high_32); } - } - } - - for (; (K - k) >= 8; k += 8) { - for (unsigned int m = 0; m < M; m++) { - vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha)); - -#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) - for (unsigned int n = 0; n < N; n += 8) { - + if (N % 8 == 0) { + unsigned int n = N - 8; float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); @@ -1835,6 +1490,14 @@ void sgemm_neon_fp16_noTrans_multithread(const __fp16 *A, const __fp16 *B, b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 8) * N + n]), a[8]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 9) * N + n]), a[9]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 10) * N + n]), a[10]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 11) * N + n]), a[11]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 12) * N + n]), a[12]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 13) * N + n]), a[13]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 14) * N + n]), a[14]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 15) * N + n]), a[15]); float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7_0))); @@ -1844,143 +1507,167 @@ void sgemm_neon_fp16_noTrans_multithread(const __fp16 *A, const __fp16 *B, vst1q_f32(&C[m * N + n], c0_7_low_32); vst1q_f32(&C[m * N + n + 4], 
c0_7_high_32); } - } - } - - for (; (K - k) >= 4; k += 4) { - for (unsigned int m = 0; m < M; m++) { - vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha)); - -#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) - for (unsigned int n = 0; n < N; n += 8) { - - float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); - b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); - float16x8_t b0_7_2 = vmulq_n_f16(vld1q_f16(&B[(k + 2) * N + n]), a[2]); - b0_7_2 = vfmaq_n_f16(b0_7_2, vld1q_f16(&B[(k + 3) * N + n]), a[3]); - - float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), - vcvt_f32_f16(vget_low_f16(b0_7_0))); - float32x4_t c0_7_high_32 = vaddq_f32( - vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); - - c0_7_low_32 = - vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_2))); - c0_7_high_32 = - vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_2))); - - vst1q_f32(&C[m * N + n], c0_7_low_32); - vst1q_f32(&C[m * N + n + 4], c0_7_high_32); - } - } - } - - // remaining K values - for (; k < K; k++) { - for (unsigned int m = 0; m < M; m++) { - __fp16 a0 = alpha * A[m * K + k]; - -#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) - for (unsigned int n = 0; n < N; n += 8) { + // remaining N values :: should do for N and this time's 16-K + else if (N % 8 != 0) { + unsigned int n = 8 * (N / 8); + __fp16 valsB_0[8]; + __fp16 valsB_1[8]; + __fp16 valsB_2[8]; + __fp16 valsB_3[8]; + __fp16 valsB_4[8]; + __fp16 valsB_5[8]; + __fp16 valsB_6[8]; + __fp16 valsB_7[8]; + __fp16 valsB_8[8]; + __fp16 valsB_9[8]; + __fp16 valsB_10[8]; + __fp16 valsB_11[8]; + __fp16 valsB_12[8]; + __fp16 valsB_13[8]; + __fp16 valsB_14[8]; + __fp16 valsB_15[8]; + float valsC[8]; + for (unsigned int idx = n; idx < N; idx++) { + valsB_0[idx - n] = B[k * N + idx]; + valsB_1[idx - n] = B[(k + 1) * N + idx]; + valsB_2[idx - n] = B[(k + 2) * N + idx]; + valsB_3[idx - n] = B[(k + 3) * N + idx]; + valsB_4[idx - n] = B[(k + 4) * N + idx]; + valsB_5[idx - n] = B[(k + 5) * N + idx]; + valsB_6[idx - n] = B[(k + 6) * N + idx]; + valsB_7[idx - n] = B[(k + 7) * N + idx]; + valsB_8[idx - n] = B[(k + 8) * N + idx]; + valsB_9[idx - n] = B[(k + 9) * N + idx]; + valsB_10[idx - n] = B[(k + 10) * N + idx]; + valsB_11[idx - n] = B[(k + 11) * N + idx]; + valsB_12[idx - n] = B[(k + 12) * N + idx]; + valsB_13[idx - n] = B[(k + 13) * N + idx]; + valsB_14[idx - n] = B[(k + 14) * N + idx]; + valsB_15[idx - n] = B[(k + 15) * N + idx]; + valsC[idx - n] = C[m * N + idx]; + } - float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0); + float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_8), a[8]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_9), a[9]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_10), a[10]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_11), a[11]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_12), a[12]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_13), a[13]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_14), a[14]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_15), a[15]); float32x4_t c0_7_low_32 = - vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7))); + vaddq_f32(vld1q_f32(valsC), 
vcvt_f32_f16(vget_low_f16(b))); - float32x4_t c0_7_high_32 = vaddq_f32(vld1q_f32(&C[m * N + n + 4]), - vcvt_f32_f16(vget_high_f16(b0_7))); + float32x4_t c0_7_high_32 = + vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b))); - vst1q_f32(&C[m * N + n], c0_7_low_32); - vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + vst1q_f32(valsC, c0_7_low_32); + vst1q_f32(valsC + 4, c0_7_high_32); + + for (unsigned int idx = n; idx < N; idx++) { + C[m * N + idx] = valsC[idx - n]; + } } } } -} - -void sgemm_neon_fp16_noTrans_general(const __fp16 *A, const __fp16 *B, float *C, - uint32_t M, uint32_t N, uint32_t K, - float alpha, float beta) { - unsigned int k = 0, n = 0; - __fp16 a[16]; - for (; (K - k) >= 16; k += 16) { + for (; (K - k) >= 8; k += 8) { for (unsigned int m = 0; m < M; m++) { vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha)); - vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha)); - - for (n = 0; (N - n) >= 8; n += 8) { +#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) + for (unsigned int n = 0; n < N - 8; n += 8) { float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]); - float16x8_t b0_7_4 = vmulq_n_f16(vld1q_f16(&B[(k + 4) * N + n]), a[4]); - b0_7_4 = vfmaq_n_f16(b0_7_4, vld1q_f16(&B[(k + 5) * N + n]), a[5]); - b0_7_4 = vfmaq_n_f16(b0_7_4, vld1q_f16(&B[(k + 6) * N + n]), a[6]); - b0_7_4 = vfmaq_n_f16(b0_7_4, vld1q_f16(&B[(k + 7) * N + n]), a[7]); - float16x8_t b0_7_8 = vmulq_n_f16(vld1q_f16(&B[(k + 8) * N + n]), a[8]); - b0_7_8 = vfmaq_n_f16(b0_7_8, vld1q_f16(&B[(k + 9) * N + n]), a[9]); - b0_7_8 = vfmaq_n_f16(b0_7_8, vld1q_f16(&B[(k + 10) * N + n]), a[10]); - b0_7_8 = vfmaq_n_f16(b0_7_8, vld1q_f16(&B[(k + 11) * N + n]), a[11]); - float16x8_t b0_7_12 = - vmulq_n_f16(vld1q_f16(&B[(k + 12) * N + n]), a[12]); - b0_7_12 = vfmaq_n_f16(b0_7_12, vld1q_f16(&B[(k + 13) * N + n]), a[13]); - b0_7_12 = vfmaq_n_f16(b0_7_12, vld1q_f16(&B[(k + 14) * N + n]), a[14]); - b0_7_12 = vfmaq_n_f16(b0_7_12, vld1q_f16(&B[(k + 15) * N + n]), a[15]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]); float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7_0))); float32x4_t c0_7_high_32 = vaddq_f32( vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); - c0_7_low_32 = - vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_4))); - c0_7_high_32 = - vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_4))); - c0_7_low_32 = - vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_8))); - c0_7_high_32 = - vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_8))); - c0_7_low_32 = - vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_12))); - c0_7_high_32 = - vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_12))); - vst1q_f32(&C[m * N + n], c0_7_low_32); vst1q_f32(&C[m * N + n + 4], c0_7_high_32); } - } - } - - for (; (K - k) >= 8; k += 8) { - for (unsigned int m = 0; m < M; m++) { - vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha)); - - for (n = 0; (N - n) >= 8; n += 8) { - + if (N % 8 == 0) { + unsigned int n = N - 8; float16x8_t b0_7_0 = 
vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]); - float16x8_t b0_7_4 = vmulq_n_f16(vld1q_f16(&B[(k + 4) * N + n]), a[4]); - b0_7_4 = vfmaq_n_f16(b0_7_4, vld1q_f16(&B[(k + 5) * N + n]), a[5]); - b0_7_4 = vfmaq_n_f16(b0_7_4, vld1q_f16(&B[(k + 6) * N + n]), a[6]); - b0_7_4 = vfmaq_n_f16(b0_7_4, vld1q_f16(&B[(k + 7) * N + n]), a[7]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]); float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7_0))); float32x4_t c0_7_high_32 = vaddq_f32( vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); - c0_7_low_32 = - vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_4))); - c0_7_high_32 = - vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_4))); - vst1q_f32(&C[m * N + n], c0_7_low_32); vst1q_f32(&C[m * N + n + 4], c0_7_high_32); } + // remaining N values :: should do for N and this time's 8-K + else if (N % 8 != 0) { + unsigned int n = 8 * (N / 8); + __fp16 valsB_0[8]; + __fp16 valsB_1[8]; + __fp16 valsB_2[8]; + __fp16 valsB_3[8]; + __fp16 valsB_4[8]; + __fp16 valsB_5[8]; + __fp16 valsB_6[8]; + __fp16 valsB_7[8]; + float valsC[8]; + for (unsigned int idx = n; idx < N; idx++) { + valsB_0[idx - n] = B[k * N + idx]; + valsB_1[idx - n] = B[(k + 1) * N + idx]; + valsB_2[idx - n] = B[(k + 2) * N + idx]; + valsB_3[idx - n] = B[(k + 3) * N + idx]; + valsB_4[idx - n] = B[(k + 4) * N + idx]; + valsB_5[idx - n] = B[(k + 5) * N + idx]; + valsB_6[idx - n] = B[(k + 6) * N + idx]; + valsB_7[idx - n] = B[(k + 7) * N + idx]; + valsC[idx - n] = C[m * N + idx]; + } + + float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]); + + float32x4_t c0_7_low_32 = + vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b))); + + float32x4_t c0_7_high_32 = + vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b))); + + vst1q_f32(valsC, c0_7_low_32); + vst1q_f32(valsC + 4, c0_7_high_32); + + for (unsigned int idx = n; idx < N; idx++) { + C[m * N + idx] = valsC[idx - n]; + } + } } } @@ -1988,7 +1675,8 @@ void sgemm_neon_fp16_noTrans_general(const __fp16 *A, const __fp16 *B, float *C, for (unsigned int m = 0; m < M; m++) { vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha)); - for (n = 0; (N - n) >= 8; n += 8) { +#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) + for (unsigned int n = 0; n < N - 8; n += 8) { float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); @@ -2008,15 +1696,65 @@ void sgemm_neon_fp16_noTrans_general(const __fp16 *A, const __fp16 *B, float *C, vst1q_f32(&C[m * N + n], c0_7_low_32); vst1q_f32(&C[m * N + n + 4], c0_7_high_32); } + if (N % 8 == 0) { + unsigned int n = N - 8; + float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); + b0_7_0 = 
vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]); + + float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), + vcvt_f32_f16(vget_low_f16(b0_7_0))); + float32x4_t c0_7_high_32 = vaddq_f32( + vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); + + vst1q_f32(&C[m * N + n], c0_7_low_32); + vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + } + // remaining N values :: should do for N and this time's 4-K + else if (N % 8 != 0) { + unsigned int n = 8 * (N / 8); + __fp16 valsB_0[8]; + __fp16 valsB_1[8]; + __fp16 valsB_2[8]; + __fp16 valsB_3[8]; + float valsC[8]; + for (unsigned int idx = n; idx < N; idx++) { + valsB_0[idx - n] = B[k * N + idx]; + valsB_1[idx - n] = B[(k + 1) * N + idx]; + valsB_2[idx - n] = B[(k + 2) * N + idx]; + valsB_3[idx - n] = B[(k + 3) * N + idx]; + valsC[idx - n] = C[m * N + idx]; + } + + float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]); + b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]); + + float32x4_t c0_7_low_32 = + vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b))); + + float32x4_t c0_7_high_32 = + vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b))); + + vst1q_f32(valsC, c0_7_low_32); + vst1q_f32(valsC + 4, c0_7_high_32); + + for (unsigned int idx = n; idx < N; idx++) { + C[m * N + idx] = valsC[idx - n]; + } + } } } + // remaining K values for (; k < K; k++) { for (unsigned int m = 0; m < M; m++) { __fp16 a0 = alpha * A[m * K + k]; - for (n = 0; (N - n) >= 8; n += 8) { - +#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) + for (unsigned int n = 0; n < N - 8; n += 8) { float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0); float32x4_t c0_7_low_32 = @@ -2028,16 +1766,23 @@ void sgemm_neon_fp16_noTrans_general(const __fp16 *A, const __fp16 *B, float *C, vst1q_f32(&C[m * N + n], c0_7_low_32); vst1q_f32(&C[m * N + n + 4], c0_7_high_32); } - } - } + if (N % 8 == 0) { + unsigned int n = N - 8; + float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); - // remaining N values - if (n < N) { - __fp16 valsB[8]; - float valsC[8]; - for (k = 0; k < K; k++) { - for (unsigned int m = 0; m < M; m++) { - __fp16 a = alpha * A[m * K + k]; + float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), + vcvt_f32_f16(vget_low_f16(b0_7_0))); + float32x4_t c0_7_high_32 = vaddq_f32( + vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); + + vst1q_f32(&C[m * N + n], c0_7_low_32); + vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + } + // remaining N values :: should do for N and this time's 4-K + else if (N % 8 != 0) { + unsigned int n = 8 * (N / 8); + __fp16 valsB[8]; + float valsC[8]; for (unsigned int idx = n; idx < N; idx++) { valsB[idx - n] = B[k * N + idx]; @@ -2045,7 +1790,7 @@ void sgemm_neon_fp16_noTrans_general(const __fp16 *A, const __fp16 *B, float *C, valsC[idx - n] = C[m * N + idx]; } - float16x8_t b = vmulq_n_f16(vld1q_f16(valsB), a); + float16x8_t b = vmulq_n_f16(vld1q_f16(valsB), a0); float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b))); diff --git a/nntrainer/tensor/blas_neon.h b/nntrainer/tensor/blas_neon.h index 488b6b0d17..f8e48502e0 100644 --- a/nntrainer/tensor/blas_neon.h +++ b/nntrainer/tensor/blas_neon.h @@ -151,40 +151,6 @@ void elementwise_vector_addition_neon_fp16(const unsigned N, const __fp16 *X, void 
sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y,
                               uint32_t rows, uint32_t cols, float alpha,
                               float beta);
-
-/**
- * @brief transposed sgemv computation with neon
- *  Y = alpha*transpose(A)*X
- * + beta*Y
- * @param[in] A __fp16 * for Matrix A
- * @param[in] X __fp16 * for Vector X
- * @param[in] Y __fp16 * for Vector Y
- * @param[in] rows number of A's row
- * @param[in] cols number of A's columns
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void sgemv_transpose_neon_fp16_multithread(const __fp16 *A, const __fp16 *X,
-                                           __fp16 *Y, uint32_t rows,
-                                           uint32_t cols, float alpha,
-                                           float beta);
-
-/**
- * @brief transposed sgemv computation with neon
- *  Y = alpha*transpose(A)*X
- * + beta*Y
- * @param[in] A __fp16 * for Matrix A
- * @param[in] X __fp16 * for Vector X
- * @param[in] Y __fp16 * for Vector Y
- * @param[in] rows number of A's row
- * @param[in] cols number of A's columns
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void sgemv_transpose_neon_fp16_general(const __fp16 *A, const __fp16 *X,
-                                       __fp16 *Y, uint32_t rows, uint32_t cols,
-                                       float alpha, float beta);
-
 /**
  * @brief saxpy computation with neon: Y = alpha*X + Y
  * @param[in] N number of elements in Y
@@ -302,37 +268,6 @@ void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
 void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C,
                              uint32_t M, uint32_t N, uint32_t K, float alpha,
                              float beta);
-
-/**
- * @brief sgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * where op(X) is one of X or X**T
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C float * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void sgemm_neon_fp16_noTrans_multithread(const __fp16 *A, const __fp16 *B,
-                                         float *C, uint32_t M, uint32_t N,
-                                         uint32_t K, float alpha, float beta);
-/**
- * @brief sgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * where op(X) is one of X or X**T
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C float * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void sgemm_neon_fp16_noTrans_general(const __fp16 *A, const __fp16 *B, float *C,
-                                     uint32_t M, uint32_t N, uint32_t K,
-                                     float alpha, float beta);
 /**
  * @brief sgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
  * where op(X) is one of X or X**T
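
The hunks above fold the separate *_multithread and *_general variants into single sgemv_transpose_neon_fp16 and sgemm_neon_fp16_noTrans kernels that handle the K and N remainders inline. As a reviewing aid, here is a minimal scalar reference for the transposed GEMV path, Y = alpha * A^T * X + beta * Y with A stored row-major as M x K; it is not part of the patch, and the function name and the float accumulation buffer are illustrative assumptions that merely mirror the Y32 buffer used by the NEON kernel:

// Scalar reference for Y = alpha * A^T * X + beta * Y (A: M x K, row-major,
// X of length M, Y of length K). Accumulates in float, like the kernel's Y32
// buffer, before converting back to __fp16. AArch64-only (__fp16 extension).
#include <cstdint>
#include <vector>

static void sgemv_transpose_ref_fp16(const __fp16 *A, const __fp16 *X,
                                     __fp16 *Y, uint32_t M, uint32_t K,
                                     float alpha, float beta) {
  std::vector<float> Y32(K);
  for (uint32_t k = 0; k < K; ++k)
    Y32[k] = beta * static_cast<float>(Y[k]); // beta * Y, accumulated in float
  for (uint32_t i = 0; i < M; ++i) {
    const float x = alpha * static_cast<float>(X[i]); // alpha * X[i]
    for (uint32_t k = 0; k < K; ++k)
      Y32[k] += static_cast<float>(A[i * K + k]) * x; // row i of A scaled by x
  }
  for (uint32_t k = 0; k < K; ++k)
    Y[k] = static_cast<__fp16>(Y32[k]); // downconvert, as scopy_neon_fp32_to_fp16 does
}

Comparing this reference against the NEON kernel for shapes with M % 8 != 0 and K % 8 != 0 exercises the new leftover branches directly.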