From 3b40a4137d49931d15ea4f0862c6ac52d7a35903 Mon Sep 17 00:00:00 2001 From: skykongkong8 Date: Thu, 18 Jan 2024 16:57:43 +0900 Subject: [PATCH] [ neon ] Apply use openMP in HGEMM, HGEMV in neon - In special occasion, we can enjoy computational profit with multithreading - Settings in multithreading might differ. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: skykongkong8 --- nntrainer/tensor/blas_neon.cpp | 243 +++++++++++++++++++++++++++++++-- nntrainer/tensor/blas_neon.h | 66 ++++++++- nntrainer/tensor/omp_setting.h | 65 +++++++++ 3 files changed, 364 insertions(+), 10 deletions(-) create mode 100644 nntrainer/tensor/omp_setting.h diff --git a/nntrainer/tensor/blas_neon.cpp b/nntrainer/tensor/blas_neon.cpp index 8bf24c7cd5..6a3779d3e5 100644 --- a/nntrainer/tensor/blas_neon.cpp +++ b/nntrainer/tensor/blas_neon.cpp @@ -13,7 +13,9 @@ */ #include +#include #include +#include namespace nntrainer::neon { void sgemv_neon(const float *A, const float *X, float *Y, uint32_t rows, @@ -701,6 +703,110 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows, void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows, uint32_t cols, float alpha, float beta) { + if (rows % 16 == 0 && cols % 8 == 0 && cols >= OMP_THRESHOLD) + sgemv_transpose_neon_fp16_multithread(A, X, Y, rows, cols, alpha, beta); + else + sgemv_transpose_neon_fp16_general(A, X, Y, rows, cols, alpha, beta); +} + +void sgemv_transpose_neon_fp16_multithread(const __fp16 *A, const __fp16 *X, + __fp16 *Y, uint32_t rows, + uint32_t cols, float alpha, + float beta) { + float Y32[cols]; + const int batch = 20; + unsigned int idx = 0; + size_t NEON_NUM_THREADS = get_gemv_num_threads(); + for (; cols - idx >= 8; idx += 8) { + float32x4_t y0_3_32 = vcvt_f32_f16(vld1_f16(&Y[idx])); + float32x4_t y4_7_32 = vcvt_f32_f16(vld1_f16(&Y[idx + 4])); + + y0_3_32 = vmulq_n_f32(y0_3_32, beta); + y4_7_32 = vmulq_n_f32(y4_7_32, beta); + + vst1q_f32(&Y32[idx], y0_3_32); + vst1q_f32(&Y32[idx + 4], y4_7_32); + } + for (; cols - idx >= 4; idx += 4) { + float32x4_t y0_3_32 = vcvt_f32_f16(vld1_f16(&Y[idx])); + y0_3_32 = vmulq_n_f32(y0_3_32, beta); + vst1q_f32(&Y32[idx], y0_3_32); + } + for (; cols - idx >= 1; idx += 1) { + Y32[idx] = beta * Y[idx]; + } + if (rows / 16 >= batch) { + for (unsigned int i = 0; i < rows; i += 16) { + __fp16 x = alpha * (X[i]); + __fp16 x2 = alpha * (X[i + 1]); + __fp16 x3 = alpha * (X[i + 2]); + __fp16 x4 = alpha * (X[i + 3]); + __fp16 x5 = alpha * (X[i + 4]); + __fp16 x6 = alpha * (X[i + 5]); + __fp16 x7 = alpha * (X[i + 6]); + __fp16 x8 = alpha * (X[i + 7]); + __fp16 x9 = alpha * (X[i + 8]); + __fp16 x10 = alpha * (X[i + 9]); + __fp16 x11 = alpha * (X[i + 10]); + __fp16 x12 = alpha * (X[i + 11]); + __fp16 x13 = alpha * (X[i + 12]); + __fp16 x14 = alpha * (X[i + 13]); + __fp16 x15 = alpha * (X[i + 14]); + __fp16 x16 = alpha * (X[i + 15]); +#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) + for (int idx = 0; idx < cols; idx += 8) { + float16x8_t wvec0_7_f16 = vmulq_n_f16(vld1q_f16(&A[i * cols + idx]), x); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 1) * cols + idx]), x2); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 2) * cols + idx]), x3); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 3) * cols + idx]), x4); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 4) * cols + idx]), x5); + wvec0_7_f16 = + 
vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 5) * cols + idx]), x6); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 6) * cols + idx]), x7); + wvec0_7_f16 = + vfmaq_n_f16(wvec0_7_f16, vld1q_f16(&A[(i + 7) * cols + idx]), x8); + + float16x8_t w2vec0_7_f16 = + vmulq_n_f16(vld1q_f16(&A[(i + 8) * cols + idx]), x9); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 9) * cols + idx]), x10); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 10) * cols + idx]), x11); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 11) * cols + idx]), x12); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 12) * cols + idx]), x13); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 13) * cols + idx]), x14); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 14) * cols + idx]), x15); + w2vec0_7_f16 = + vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 15) * cols + idx]), x16); + + float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]), + vcvt_f32_f16(vget_low_f16(wvec0_7_f16))); + y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16))); + float32x4_t y4_7 = vaddq_f32(vld1q_f32(&Y32[idx + 4]), + vcvt_f32_f16(vget_high_f16(wvec0_7_f16))); + y4_7 = vaddq_f32(y4_7, vcvt_f32_f16(vget_high_f16(w2vec0_7_f16))); + + vst1q_f32(&Y32[idx], y0_3); + vst1q_f32(&Y32[idx + 4], y4_7); + } + } + } + scopy_neon_fp32_to_fp16(cols, Y32, Y); + return; +} + +void sgemv_transpose_neon_fp16_general(const __fp16 *A, const __fp16 *X, + __fp16 *Y, uint32_t rows, uint32_t cols, + float alpha, float beta) { float Y32[cols]; const int batch = 20; unsigned int idx = 0; @@ -1588,21 +1694,146 @@ void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, free(C32); } -void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C, +void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C32, uint32_t M, uint32_t N, uint32_t K, float alpha, float beta) { + if (N % 8 == 0 && N >= OMP_THRESHOLD) + sgemm_neon_fp16_noTrans_multithread(A, B, C32, M, N, K, alpha, beta); + else + sgemm_neon_fp16_noTrans_general(A, B, C32, M, N, K, alpha, beta); +} + +void sgemm_neon_fp16_noTrans_multithread(const __fp16 *A, const __fp16 *B, + float *C, uint32_t M, uint32_t N, + uint32_t K, float alpha, float beta) { + size_t NEON_NUM_THREADS = get_gemm_num_threads(); + unsigned int k = 0; + __fp16 a[16]; + for (; (K - k) >= 16; k += 16) { + for (unsigned int m = 0; m < M; m++) { + vst1q_f16(&a[0], vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha)); + vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha)); +#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) + for (unsigned int n = 0; n < N; n += 8) { + float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 8) * N + n]), a[8]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 9) * N + n]), a[9]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 10) * N + n]), a[10]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 11) * N + n]), 
a[11]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 12) * N + n]), a[12]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 13) * N + n]), a[13]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 14) * N + n]), a[14]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 15) * N + n]), a[15]); + + float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), + vcvt_f32_f16(vget_low_f16(b0_7_0))); + float32x4_t c0_7_high_32 = vaddq_f32( + vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); + + vst1q_f32(&C[m * N + n], c0_7_low_32); + vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + } + } + } + + for (; (K - k) >= 8; k += 8) { + for (unsigned int m = 0; m < M; m++) { + vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha)); + +#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) + for (unsigned int n = 0; n < N; n += 8) { + + float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]); + + float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), + vcvt_f32_f16(vget_low_f16(b0_7_0))); + float32x4_t c0_7_high_32 = vaddq_f32( + vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); + + vst1q_f32(&C[m * N + n], c0_7_low_32); + vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + } + } + } + + for (; (K - k) >= 4; k += 4) { + for (unsigned int m = 0; m < M; m++) { + vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha)); + +#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) + for (unsigned int n = 0; n < N; n += 8) { + + float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); + b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); + float16x8_t b0_7_2 = vmulq_n_f16(vld1q_f16(&B[(k + 2) * N + n]), a[2]); + b0_7_2 = vfmaq_n_f16(b0_7_2, vld1q_f16(&B[(k + 3) * N + n]), a[3]); + + float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]), + vcvt_f32_f16(vget_low_f16(b0_7_0))); + float32x4_t c0_7_high_32 = vaddq_f32( + vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); + + c0_7_low_32 = + vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_2))); + c0_7_high_32 = + vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_2))); + + vst1q_f32(&C[m * N + n], c0_7_low_32); + vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + } + } + } + + // remaining K values + for (; k < K; k++) { + for (unsigned int m = 0; m < M; m++) { + __fp16 a0 = alpha * A[m * K + k]; + +#pragma omp parallel for schedule(guided) num_threads(NEON_NUM_THREADS) + for (unsigned int n = 0; n < N; n += 8) { + + float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0); + + float32x4_t c0_7_low_32 = + vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7))); + + float32x4_t c0_7_high_32 = vaddq_f32(vld1q_f32(&C[m * N + n + 4]), + vcvt_f32_f16(vget_high_f16(b0_7))); + + vst1q_f32(&C[m * N + n], c0_7_low_32); + vst1q_f32(&C[m * N + n + 4], c0_7_high_32); + } + } + } +} + +void sgemm_neon_fp16_noTrans_general(const __fp16 *A, const __fp16 *B, float *C, + uint32_t M, uint32_t N, uint32_t K, + float alpha, float beta) { unsigned int k = 
0, n = 0; __fp16 a[16]; for (; (K - k) >= 16; k += 16) { for (unsigned int m = 0; m < M; m++) { - // calculating A * alpha; vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha)); vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha)); for (n = 0; (N - n) >= 8; n += 8) { - // fp16 multiplications and partial accumulations float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); @@ -1626,7 +1857,6 @@ void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C, float32x4_t c0_7_high_32 = vaddq_f32( vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); - // fp32 partial accumulations c0_7_low_32 = vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_4))); c0_7_high_32 = @@ -1648,12 +1878,10 @@ void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C, for (; (K - k) >= 8; k += 8) { for (unsigned int m = 0; m < M; m++) { - // calculating A * alpha; vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha)); for (n = 0; (N - n) >= 8; n += 8) { - // fp16 multiplications and partial accumulations float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]); b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]); @@ -1668,7 +1896,6 @@ void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C, float32x4_t c0_7_high_32 = vaddq_f32( vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0))); - // fp32 partial accumulations c0_7_low_32 = vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_4))); c0_7_high_32 = @@ -1682,7 +1909,6 @@ void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C, for (; (K - k) >= 4; k += 4) { for (unsigned int m = 0; m < M; m++) { - // calculating A * alpha; vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha)); for (n = 0; (N - n) >= 8; n += 8) { @@ -1708,7 +1934,6 @@ void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C, } } - // remaining K values for (; k < K; k++) { for (unsigned int m = 0; m < M; m++) { __fp16 a0 = alpha * A[m * K + k]; diff --git a/nntrainer/tensor/blas_neon.h b/nntrainer/tensor/blas_neon.h index f9bd22d56f..488b6b0d17 100644 --- a/nntrainer/tensor/blas_neon.h +++ b/nntrainer/tensor/blas_neon.h @@ -152,6 +152,39 @@ void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows, uint32_t cols, float alpha, float beta); +/** + * @brief transposed sgemv computation with neon + * Y = alpha*transpose(A)*X + * + beta*Y + * @param[in] A __fp16 * for Matrix A + * @param[in] X __fp16 * for Vector X + * @param[in] Y __fp16 * for Vector Y + * @param[in] rows number of A's row + * @param[in] cols number of A's columns + * @param[in] alpha float number + * @param[in] beta float number + */ +void sgemv_transpose_neon_fp16_multithread(const __fp16 *A, const __fp16 *X, + __fp16 *Y, uint32_t rows, + uint32_t cols, float alpha, + float beta); + +/** + * @brief transposed sgemv computation with neon + * Y = alpha*transpose(A)*X + * + beta*Y + * @param[in] A __fp16 * for Matrix A + * @param[in] X __fp16 * for Vector X + * @param[in] Y __fp16 * for Vector Y + * @param[in] rows number of A's row + * @param[in] cols number of A's columns + * @param[in] alpha float number + * @param[in] beta float number + */ +void sgemv_transpose_neon_fp16_general(const __fp16 *A, const __fp16 *X, + __fp16 *Y, uint32_t rows, 
uint32_t cols, + float alpha, float beta); + /** * @brief saxpy computation with neon: Y = alpha*X + Y * @param[in] N number of elements in Y @@ -259,7 +292,7 @@ void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, * where op(X) is one of X or X**T * @param[in] A __fp16 * for Matrix A * @param[in] B __fp16 * for Matrix B - * @param[in] C __fp16 * for Matrix C + * @param[in] C float * for Matrix C * @param[in] M number of op(A)'s and C's row * @param[in] N number of op(B)'s and C's columns * @param[in] K number of op(A)'s and columns and op(B)'s rows @@ -269,6 +302,37 @@ void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, void sgemm_neon_fp16_noTrans(const __fp16 *A, const __fp16 *B, float *C, uint32_t M, uint32_t N, uint32_t K, float alpha, float beta); + +/** + * @brief sgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C, + * where op(X) is one of X or X**T + * @param[in] A __fp16 * for Matrix A + * @param[in] B __fp16 * for Matrix B + * @param[in] C float * for Matrix C + * @param[in] M number of op(A)'s and C's row + * @param[in] N number of op(B)'s and C's columns + * @param[in] K number of op(A)'s and columns and op(B)'s rows + * @param[in] alpha float number + * @param[in] beta float number + */ +void sgemm_neon_fp16_noTrans_multithread(const __fp16 *A, const __fp16 *B, + float *C, uint32_t M, uint32_t N, + uint32_t K, float alpha, float beta); +/** + * @brief sgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C, + * where op(X) is one of X or X**T + * @param[in] A __fp16 * for Matrix A + * @param[in] B __fp16 * for Matrix B + * @param[in] C float * for Matrix C + * @param[in] M number of op(A)'s and C's row + * @param[in] N number of op(B)'s and C's columns + * @param[in] K number of op(A)'s and columns and op(B)'s rows + * @param[in] alpha float number + * @param[in] beta float number + */ +void sgemm_neon_fp16_noTrans_general(const __fp16 *A, const __fp16 *B, float *C, + uint32_t M, uint32_t N, uint32_t K, + float alpha, float beta); /** * @brief sgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C, * where op(X) is one of X or X**T diff --git a/nntrainer/tensor/omp_setting.h b/nntrainer/tensor/omp_setting.h new file mode 100644 index 0000000000..f2e5ed3f02 --- /dev/null +++ b/nntrainer/tensor/omp_setting.h @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Sungsik Kong + * + * @file omp_setting.h + * @date 18 Jan 2024 + * @see https://github.com/nnstreamer/nntrainer + * https://arxiv.org/abs/1706.03762 + * @author Sungsik Kong + * @bug No known bugs except for NYI items + * @brief This file is for OpenMP setting + * + */ + +#include + +/// @note This variable should be optimized by user +/// @todo Must find a general solution to optimize the functionality of +/// multithreading : determining the combination of #threads and size of +/// (M x K) x (K x N) GEMM +const int OMP_THRESHOLD = 20000; +/** + * @brief Function for setting the number of threads to use for GEMM + * + * @return size_t& num_threads + */ +inline size_t &GEMM_NUM_THREADS() { + /// @note This variable should be optimized by user + static size_t num_threads = 4; + return num_threads; +} +/** + * @brief Set the gemm num threads object + * + * @param n num_threads to set + */ +inline void set_gemm_num_threads(size_t n) { GEMM_NUM_THREADS() = n; } +/** + * @brief Get the gemm num threads object + * + * @return size_t num_threads + */ +inline size_t get_gemm_num_threads() { return GEMM_NUM_THREADS(); } +/** + * 
@brief Function for setting the number of threads to use for GEMV + * + * @return size_t& num_threads + */ +inline size_t &GEMV_NUM_THREADS() { + /// @note This variable should be optimized by user + static size_t num_threads = 2; + return num_threads; +} +/** + * @brief Set the gemv num threads object + * + * @param n num_threads to set + */ +inline void set_gemv_num_threads(size_t n) { GEMV_NUM_THREADS() = n; } +/** + * @brief Get the gemv num threads object + * + * @return size_t num_threads + */ +inline size_t get_gemv_num_threads() { return GEMV_NUM_THREADS(); }
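
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new knobs in omp_setting.h are meant to combine with the dispatch conditions added in blas_neon.cpp. Only the function names, the OMP_THRESHOLD value (20000), the default thread counts (2 for GEMV, 4 for GEMM), and the dispatch conditions come from the patch; the include paths, the assumption that the declarations live in namespace nntrainer::neon (as in blas_neon.cpp), the aarch64 toolchain with __fp16 support, and the concrete shapes and thread counts are assumptions chosen for illustration.

#include <cstdint>
#include <vector>

#include "blas_neon.h"   // assumed include path; adjust to the build setup
#include "omp_setting.h" // assumed include path; adjust to the build setup

void fp16_gemv_transpose_example() {
  // Shapes chosen so sgemv_transpose_neon_fp16 dispatches to the
  // multithreaded kernel: rows % 16 == 0, cols % 8 == 0,
  // cols >= OMP_THRESHOLD (20000), and rows / 16 >= 20
  // (the kernel's internal `batch` constant).
  const uint32_t rows = 512;   // illustrative
  const uint32_t cols = 20480; // illustrative

  std::vector<__fp16> A(static_cast<size_t>(rows) * cols, __fp16(0.01f));
  std::vector<__fp16> X(rows, __fp16(1.0f)); // length rows for transpose(A)*X
  std::vector<__fp16> Y(cols, __fp16(0.0f)); // length cols

  // The parallel regions read these accessors at call time, so overrides
  // must be set before the kernel is invoked. Defaults are 2 GEMV threads
  // and 4 GEMM threads; optimal values are device-dependent and left to
  // the user to tune.
  set_gemv_num_threads(4); // illustrative value
  set_gemm_num_threads(8); // illustrative value

  // Y = 1.0f * transpose(A) * X + 0.0f * Y
  nntrainer::neon::sgemv_transpose_neon_fp16(A.data(), X.data(), Y.data(),
                                             rows, cols, 1.0f, 0.0f);
}

sgemm_neon_fp16_noTrans dispatches analogously: it takes the OpenMP path only when N % 8 == 0 and N >= OMP_THRESHOLD, and otherwise falls back to the single-threaded _general kernel, so the GEMM thread count only has an effect above that threshold.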