[ neon ] Use OpenMP in HGEMM and HGEMV in neon
- In certain cases, multithreading yields a computational speedup
- Multithreading settings may differ per target (a sketch of the assumed threading helpers follows the change summary below)

**Self evaluation:**
1. Build test: [X] Passed [ ] Failed [ ] Skipped
2. Run test: [X] Passed [ ] Failed [ ] Skipped

Signed-off-by: skykongkong8 <[email protected]>
skykongkong8 committed Jan 18, 2024
1 parent 8f478cf commit ccfe651
Showing 3 changed files with 364 additions and 10 deletions.
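
For context, the new code paths depend on three symbols from omp_setting.h — OMP_THRESHOLD, get_gemv_num_threads(), and get_gemm_num_threads() — whose definitions are not part of this diff. A minimal sketch of what such a header could look like follows; the threshold value and thread-count policy are illustrative assumptions, not the project's actual settings.

#include <algorithm>
#include <cstddef>
#include <omp.h>

// Hypothetical illustration only: the real omp_setting.h is not shown in
// this commit, and the threshold below is a placeholder value.
constexpr unsigned int OMP_THRESHOLD = 120;

inline size_t get_gemv_num_threads() {
  // GEMV is largely memory-bound, so cap it at a small thread count.
  return std::min<size_t>(4, omp_get_max_threads());
}

inline size_t get_gemm_num_threads() {
  // GEMM is compute-bound; let it use every core OpenMP reports.
  return static_cast<size_t>(omp_get_max_threads());
}

Note that the dispatchers in the diff only take the threaded path when the problem shape also matches the SIMD blocking, so these helpers govern just the thread count and the size cutoff.
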
243 changes: 234 additions & 9 deletions nntrainer/tensor/blas_neon.cpp
@@ -13,7 +13,9 @@
*/

#include <blas_neon.h>
#include <iostream>
#include <nntrainer_error.h>
#include <omp_setting.h>
namespace nntrainer::neon {

void sgemv_neon(const float *A, const float *X, float *Y, uint32_t rows,
@@ -701,6 +703,110 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y,
uint32_t rows, uint32_t cols, float alpha,
float beta) {
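// Take the multithreaded path only when the shape matches the SIMD blocking
// (rows % 16 == 0, cols % 8 == 0) and there are enough columns
// (cols >= OMP_THRESHOLD) to amortize the OpenMP fork/join overhead.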
if (rows % 16 == 0 && cols % 8 == 0 && cols >= OMP_THRESHOLD)
sgemv_transpose_neon_fp16_multithread(A, X, Y, rows, cols, alpha, beta);
else
sgemv_transpose_neon_fp16_general(A, X, Y, rows, cols, alpha, beta);
}

void sgemv_transpose_neon_fp16_multithread(const __fp16 *A, const __fp16 *X,
__fp16 *Y, uint32_t rows,
uint32_t cols, float alpha,
float beta) {
float Y32[cols];
const int batch = 20;
unsigned int idx = 0;
size_t NEON_NUM_THREADS = get_gemv_num_threads();
for (; cols - idx >= 8; idx += 8) {
float32x4_t y0_3_32 = vcvt_f32_f16(vld1_f16(&Y[idx]));
float32x4_t y4_7_32 = vcvt_f32_f16(vld1_f16(&Y[idx + 4]));

y0_3_32 = vmulq_n_f32(y0_3_32, beta);
y4_7_32 = vmulq_n_f32(y4_7_32, beta);

vst1q_f32(&Y32[idx], y0_3_32);
vst1q_f32(&Y32[idx + 4], y4_7_32);
}
for (; cols - idx >= 4; idx += 4) {
float32x4_t y0_3_32 = vcvt_f32_f16(vld1_f16(&Y[idx]));
y0_3_32 = vmulq_n_f32(y0_3_32, beta);
vst1q_f32(&Y32[idx], y0_3_32);
}
for (; cols - idx >= 1; idx += 1) {
Y32[idx] = beta * Y[idx];
}
if (rows / 16 >= batch) {
for (unsigned int i = 0; i < rows; i += 16) {
__fp16 x = alpha * (X[i]);
__fp16 x2 = alpha * (X[i + 1]);
__fp16 x3 = alpha * (X[i + 2]);
__fp16 x4 = alpha * (X[i + 3]);
__fp16 x5 = alpha * (X[i + 4]);
__fp16 x6 = alpha * (X[i + 5]);
__fp16 x7 = alpha * (X[i + 6]);
__fp16 x8 = alpha * (X[i + 7]);
__fp16 x9 = alpha * (X[i + 8]);
__fp16 x10 = alpha * (X[i + 9]);
__fp16 x11 = alpha * (X[i + 10]);
__fp16 x12 = alpha * (X[i + 11]);
__fp16 x13 = alpha * (X[i + 12]);
__fp16 x14 = alpha * (X[i + 13]);
__fp16 x15 = alpha * (X[i + 14]);
__fp16 x16 = alpha * (X[i + 15]);
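// Parallelize across output columns: each iteration touches a disjoint
// Y32[idx .. idx+7] slice, so threads never write the same memory.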
#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS)
for (int idx = 0; idx < cols; idx += 8) {
float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * cols + idx]), x2);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 2) * cols + idx]), x3);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 3) * cols + idx]), x4);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 4) * cols + idx]), x5);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 5) * cols + idx]), x6);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 6) * cols + idx]), x7);
wvec0_7_f16 =
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 7) * cols + idx]), x8);

float16x8_t w2vec0_7_f16 =
vmulq_n_f16(vld1q_f16(&A[(i + 8) * cols + idx]), x9);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 9) * cols + idx]), x10);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 10) * cols + idx]), x11);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 11) * cols + idx]), x12);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 12) * cols + idx]), x13);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 13) * cols + idx]), x14);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 14) * cols + idx]), x15);
w2vec0_7_f16 =
vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 15) * cols + idx]), x16);

float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]),
vcvt_f32_f16(vget_low_f16(wvec0_7_f16)));
y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16)));
float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]),
vcvt_f32_f16(vget_high_f16(wvec0_7_f16)));
y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16)));

vst1q_f32(&Y32[idx], y0_3);
vst1q_f32(&Y32[idx + 4], y4_7);
}
}
}
scopy_neon_fp32_to_fp16(cols, Y32, Y);
return;
}

void sgemv_transpose_neon_fp16_general(const __fp16 *A, const __fp16 *X,
__fp16 *Y, uint32_t rows, uint32_t cols,
float alpha, float beta) {
float Y32[cols];
const int batch = 20;
unsigned int idx = 0;
@@ -1588,21 +1694,146 @@ void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
free(C32);
}

void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C,
void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C32,
uint32_t M, uint32_t N, uint32_t K, float alpha,
float beta) {
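// Use the multithreaded kernel only when N is a multiple of the 8-wide fp16
// vector and large enough (N >= OMP_THRESHOLD) to pay for thread start-up;
// otherwise fall back to the single-threaded kernel.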
if (N % 8 == 0 && N >= OMP_THRESHOLD)
sgemm_neon_fp16_noTrans_multithread(A, B, C32, M, N, K, alpha, beta);
else
sgemm_neon_fp16_noTrans_general(A, B, C32, M, N, K, alpha, beta);
}

void sgemm_neon_fp16_noTrans_multithread(const __fp16 *A, const __fp16 *B,
float *C, uint32_t M, uint32_t N,
uint32_t K, float alpha, float beta) {
size_t NEON_NUM_THREADS = get_gemm_num_threads();
unsigned int k = 0;
__fp16 a[16];
for (; (K - k) >= 16; k += 16) {
for (unsigned int m = 0; m < M; m++) {
vst1q_f16(&a[0], vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha));
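// Parallelize across the N dimension: every thread accumulates into a
// disjoint C[m*N + n .. n+7] block, so no synchronization is needed.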
#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS)
for (unsigned int n = 0; n < N; n += 8) {
float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 8) * N + n]), a[8]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 9) * N + n]), a[9]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 10) * N + n]), a[10]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 11) * N + n]), a[11]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 12) * N + n]), a[12]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 13) * N + n]), a[13]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 14) * N + n]), a[14]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 15) * N + n]), a[15]);

float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
vcvt_f32_f16(vget_low_f16(b0_7_0)));
float32x4_t c0_7_high_32 = vaddq_f32(
vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));

vst1q_f32(&C[m * N + n], c0_7_low_32);
vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
}
}
}

for (; (K - k) >= 8; k += 8) {
for (unsigned int m = 0; m < M; m++) {
vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));

#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS)
for (unsigned int n = 0; n < N; n += 8) {

float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);

float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
vcvt_f32_f16(vget_low_f16(b0_7_0)));
float32x4_t c0_7_high_32 = vaddq_f32(
vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));

vst1q_f32(&C[m * N + n], c0_7_low_32);
vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
}
}
}

for (; (K - k) >= 4; k += 4) {
for (unsigned int m = 0; m < M; m++) {
vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha));

#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS)
for (unsigned int n = 0; n < N; n += 8) {

float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
float16x8_t b0_7_2 = vmulq_n_f16(vld1q_f16(&B[(k + 2) * N + n]), a[2]);
b0_7_2 = vfmaq_n_f16(b0_7_2, vld1q_f16(&B[(k + 3) * N + n]), a[3]);

float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
vcvt_f32_f16(vget_low_f16(b0_7_0)));
float32x4_t c0_7_high_32 = vaddq_f32(
vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));

c0_7_low_32 =
vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_2)));
c0_7_high_32 =
vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_2)));

vst1q_f32(&C[m * N + n], c0_7_low_32);
vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
}
}
}

// remaining K values
for (; k < K; k++) {
for (unsigned int m = 0; m < M; m++) {
__fp16 a0 = alpha * A[m * K + k];

#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS)
for (unsigned int n = 0; n < N; n += 8) {

float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0);

float32x4_t c0_7_low_32 =
vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7)));

float32x4_t c0_7_high_32 = vaddq_f32(vld1q_f32(&C[m * N + n + 4]),
vcvt_f32_f16(vget_high_f16(b0_7)));

vst1q_f32(&C[m * N + n], c0_7_low_32);
vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
}
}
}
}

void sgemm_neon_fp16_noTrans_general(const __fp16 *A, const __fp16 *B, float *C,
uint32_t M, uint32_t N, uint32_t K,
float alpha, float beta) {

unsigned int k = 0, n = 0;
__fp16 a[16];
for (; (K - k) >= 16; k += 16) {
for (unsigned int m = 0; m < M; m++) {
// calculating A * alpha;
vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha));

for (n = 0; (N - n) >= 8; n += 8) {

// fp16 multiplications and partial accumulations
float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
@@ -1626,7 +1857,6 @@ void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C,
float32x4_t c0_7_high_32 = vaddq_f32(
vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));

// fp32 partial accumulations
c0_7_low_32 =
vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_4)));
c0_7_high_32 =
@@ -1648,12 +1878,10 @@ void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C,

for (; (K - k) >= 8; k += 8) {
for (unsigned int m = 0; m < M; m++) {
// calculating A * alpha;
vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));

for (n = 0; (N - n) >= 8; n += 8) {

// fp16 multiplications and partial accumulations
float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
@@ -1668,7 +1896,6 @@ void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C,
float32x4_t c0_7_high_32 = vaddq_f32(
vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));

// fp32 partial accumulations
c0_7_low_32 =
vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_4)));
c0_7_high_32 =
@@ -1682,7 +1909,6 @@ void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C,

for (; (K - k) >= 4; k += 4) {
for (unsigned int m = 0; m < M; m++) {
// calculating A * alpha;
vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha));

for (n = 0; (N - n) >= 8; n += 8) {
@@ -1708,7 +1934,6 @@ void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C,
}
}

// remaining K values
for (; k < K; k++) {
for (unsigned int m = 0; m < M; m++) {
__fp16 a0 = alpha * A[m * K + k];
