Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit ae5dcfb

Browse files
[BesTLA] Sync compiler's compatibility (#279)
* use xbyak as an external project. remove some warnings * runtime check compiler's ISAs * fix flags on gcc * add zero reg, remove warnings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove unused cmake files * relocate code * add ISA found to target * clang-format * fix mha compile bug. add bestla to python test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add correct instruction for avx512_fp16 * fix compile error of deprecated template --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a067e0c commit ae5dcfb

23 files changed

+264
-9578
lines changed

.github/workflows/unit-test-llmruntime.yml

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ on:
55
branches: [main]
66
paths:
77
- neural_speed/**
8+
- bestla/**
89
- tests/**
910
- .github/workflows/unit-test-llmruntime.yml
1011
- .github/workflows/unitTest/**

bestla/CMakeLists.txt

+25-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
cmake_minimum_required(VERSION 3.12)
2-
32
project(bestla LANGUAGES CXX VERSION 0.1.0)
3+
4+
include(cmake/FindSIMD.cmake)
5+
46
file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp)
5-
file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp)
67

78
option(BTLA_ENABLE_OPENMP "Compile OpenMP thread pool if OMP can be found" OFF)
89
# Toggles the SYCL code path: when ON, cmake/sycl.cmake is included and the BTLA_SYCL definition is set.
option(BTLA_SYCL "Compile SYCL support (requires the IntelSYCL package)" OFF)
@@ -22,23 +23,41 @@ option(BTLA_UT_NOASAN "Disable sanitize" OFF)
2223
option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF)
2324
option(BTLA_UT_OPENMP "Use OpenMP for UT tests" OFF)
2425

25-
26-
27-
26+
include(FetchContent)
27+
FetchContent_Declare(
28+
xbyak
29+
GIT_REPOSITORY https://github.com/herumi/xbyak.git
30+
GIT_TAG v7.06
31+
)
32+
FetchContent_MakeAvailable(xbyak)
2833

2934
add_library(${PROJECT_NAME} INTERFACE)
35+
target_link_libraries(${PROJECT_NAME} INTERFACE xbyak)
3036
add_library(neural_speed::${PROJECT_NAME} ALIAS ${PROJECT_NAME})
3137
target_include_directories(
3238
${PROJECT_NAME} INTERFACE
3339
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
3440
"$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>"
3541
)
42+
43+
function(add_isa_def ARG)
44+
if(${${ARG}_FOUND})
45+
target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_${ARG}_FOUND=1)
46+
else()
47+
target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_${ARG}_FOUND=0)
48+
endif()
49+
endfunction()
50+
51+
foreach (ISA ${ISA_SET})
52+
add_isa_def(${ISA})
53+
endforeach()
54+
3655
set(sycl_headers)
3756
set(sycl_libs)
3857
if(BTLA_SYCL)
3958
include(cmake/sycl.cmake)
4059
file(GLOB sycl_headers ${PROJECT_NAME}/sycl/*.h ${PROJECT_NAME}/sycl/*.hpp)
41-
add_compile_definitions(BTLA_SYCL)
60+
target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_SYCL)
4261
list(APPEND sycl_libs IntelSYCL::SYCL_CXX)
4362
add_compile_options(-march=native)
4463
#add_link_options(-fsycl-targets=spir64 -Xsycl-target-backend "-options -ze-opt-large-register-file")

bestla/bestla/bestla_device.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -496,10 +496,10 @@ class CpuRuntime {
496496

497497
inline void adjustPE(const BTLA_ISA isa, const float PE_) {
498498
// printf("Adjust:%d,%f\n",int(isa),PE_);
499-
PE[int(isa)] = PE[int(isa)] * PE_ * 0.7 + PE[int(isa)] * 0.3;
499+
PE[int(isa)] = PE[int(isa)] * PE_ * 0.7f + PE[int(isa)] * 0.3f;
500500
}
501501

502-
size_t mL2Cache, mL1Cache, mL2Cache_P = 0, mL1Cache_P = 0, mL2Cache_E = 0, mL1Cache_E = 0;
502+
size_t mL2Cache = 0, mL1Cache = 0, mL2Cache_P = 0, mL1Cache_P = 0, mL2Cache_E = 0, mL1Cache_E = 0;
503503
int P_core_num = 0, E_core_num = 0;
504504
bool mHybrid = false;
505505

@@ -530,8 +530,8 @@ class CpuRuntime {
530530
}
531531
}
532532
}
533-
float PE[int(BTLA_ISA::ISA_COUNT)];
534-
int maxThreads;
533+
float PE[int(BTLA_ISA::ISA_COUNT)] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f};
534+
int maxThreads = 0;
535535
};
536536
} // namespace device
537537
} // namespace bestla

bestla/bestla/bestla_prologue_b.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ class WeightKBlockNInteger {
384384
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
385385
if (i < rawnk_scale) {
386386
for (int j = 0; j < N; j++) {
387-
stor->template SPtr<utils::f8>()[i * stor->mNPad + j] = scales[j * rawnk_scale + i];
387+
stor->template SPtr<utils::f8>()[i * stor->mNPad + j] = static_cast<int>(scales[j * rawnk_scale + i]);
388388
}
389389
} else {
390390
std::memset(stor->template SPtr<utils::f8>() + i * stor->mNPad, 0, stor->mNPad * sizeof(utils::f8));
@@ -771,14 +771,15 @@ class WeightKBlockNInteger {
771771
if (wptr->mDType == BTLA_DTYPE::S4_CLIP) {
772772
if (wptr->SDtype() == BTLA_DTYPE::DQ8_BNB) {
773773
auto internal_n_offset = n_offset + i;
774+
int dq_offset = static_cast<int>(wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1);
774775
kernel::wrapper::DecompressDQKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T,
775776
BTLA_DTYPE::S4_CLIP>(
776777
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
777778
i * KPad / 2,
778779
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize,
779780
wptr->template SPtr<uint8_t>(), wptr->template DQPtr<float>(), k_offset / _GemmCore_T::PACK_ROW,
780781
internal_n_offset, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, wptr->mN, wptr->mDqBlockSize,
781-
wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1, tmpcache, cachesize);
782+
dq_offset, tmpcache, cachesize);
782783
} else {
783784
auto sptr = wptr->template SPtr<void>();
784785
kernel::wrapper::DecompressKBlockS4Fp<_GemmCore_T::PACK_ROW, _GemmCore_T::NTILE, _T>::template forward<ISA_T>(

bestla/bestla/bestla_utils.h

+10-35
Original file line numberDiff line numberDiff line change
@@ -60,41 +60,16 @@
6060

6161
// As long as the compiler supports the ISA, we will enable it.
6262
// Only the ISA you use in your project will be compiled.
63-
#ifdef __GNUC__
64-
#define CompileAVX512F() (__GNUC__ >= 6)
65-
#define CompileAVX512VNNI() (__GNUC__ >= 9)
66-
#define CompileAVX2() (__GNUC__ >= 5)
67-
#define CompileAVXVNNI() (__GNUC__ >= 11)
68-
#define CompileAMX() (__GNUC__ >= 11)
69-
#define CompileBF16() (__GNUC__ >= 11)
70-
#define CompileFP16() (__GNUC__ >= 13)
71-
#define CompileAMXBF16() (CompileAMX())
72-
#define CompileAMXINT8() (CompileAMX())
73-
#endif
74-
75-
#if defined(_MSC_VER) && !defined(__INTEL_LLVM_COMPILER)
76-
#define CompileAVX512F() _MSC_VER && (_MSC_VER >= 1911)
77-
#define CompileAVX512VNNI() _MSC_VER && (_MSC_VER >= 1930) // TODO(Yu) check the minimum version
78-
#define CompileAVX2() _MSC_VER && (_MSC_VER >= 1900)
79-
#define CompileAVXVNNI() _MSC_VER && (_MSC_VER >= 1930) // TODO(Yu) check the minimum version
80-
#define CompileAMX() _MSC_VER && (_MSC_VER >= 1930) // TODO(Yu) check the minimum version
81-
#define CompileBF16() _MSC_VER && (_MSC_VER >= 1938) // TODO(Yu) check the minimum version
82-
#define CompileFP16() _MSC_VER && (_MSC_VER >= 1938) // TODO(Yu) check the minimum version
83-
#define CompileAMXBF16() (CompileAMX())
84-
#define CompileAMXINT8() (CompileAMX())
85-
#endif
86-
87-
#if defined(_MSC_VER) && defined(__INTEL_LLVM_COMPILER)
88-
#define CompileAVX512F() defined(__AVX512F__)
89-
#define CompileAVX512VNNI() defined(__AVX512VNNI__)
90-
#define CompileAVX2() defined(__AVX2__) && defined(__F16C__) && defined(__FMA__)
91-
#define CompileAVXVNNI() defined(__AVXVNNI__)
92-
#define CompileAMX() defined(__AMX_TILE__)
93-
#define CompileBF16() defined(__AVX512BF16__)
94-
#define CompileFP16() defined(__AVX512FP16__)
95-
#define CompileAMXBF16() (CompileAMX())
96-
#define CompileAMXINT8() (CompileAMX())
97-
#endif
63+
#define CompileAVX512F() BTLA_AVX512_FOUND
64+
#define CompileAVX512VNNI() BTLA_AVX512_VNNI_FOUND
65+
#define CompileAVX2() BTLA_AVX2_FOUND
66+
#define CompileAVXVNNI() BTLA_AVX_VNNI_FOUND
67+
#define CompileBF16() BTLA_AVX512_BF16_FOUND
68+
#define CompileFP16() BTLA_AVX512_FP16_FOUND
69+
#define CompileAMXBF16() BTLA_AMX_BF16_FOUND
70+
#define CompileAMXFP16() BTLA_AMX_FP16_FOUND
71+
#define CompileAMXINT8() BTLA_AMX_INT8_FOUND
72+
#define CompileAMX() BTLA_AMX_BF16_FOUND
9873

9974
// called by launcher, time critical functions
10075
#define TLACALL \

bestla/bestla/bestla_wrapper.h

+2
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ class LauncherBase {
462462
} else {
463463
gemm(_param, _config);
464464
}
465+
bestla::kernel::wrapper::ZeroReg::forward();
465466
}
466467

467468
protected:
@@ -709,6 +710,7 @@ class LauncherIntKBlock {
709710
} else {
710711
gemm(_param, _config);
711712
}
713+
bestla::kernel::wrapper::ZeroReg::forward();
712714
}
713715

714716
protected:

bestla/bestla/kernel_avx2.h

+18-15
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ namespace avx2 {
3030
#pragma clang attribute push(__attribute__((target("avx,avx2,fma"))), apply_to = function)
3131
#endif
3232

33+
// Calls _mm256_zeroupper(), which zeroes the upper 128 bits of all YMM registers.
// NOTE(review): presumably invoked after AVX kernels to avoid AVX/SSE transition penalties - confirm with callers.
static inline void zero_reg() { _mm256_zeroupper(); }
34+
3335
static inline __m256i unpack_4bits(void* srcptr, __m256i mask) {
3436
auto raw_data = _mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr));
3537
auto ymm0 = _mm256_cvtepu8_epi16(raw_data);
@@ -539,7 +541,8 @@ static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr,
539541
vout_y = _mm256_sub_epi8(vout_y, vbias);
540542
_mm256_storeu_si256((__m256i*)(dstptr + i), vout_y);
541543
} else {
542-
ref::decompress_kblock_s4_s8<1, 1>(srcptr + i / 2, nullptr, dstptr + i, 0, 0, 0, 0, 1, elesize - i, nullptr, 0);
544+
ref::decompress_kblock_s4_s8<1, 1>(srcptr + i / 2, nullptr, dstptr + i, 0, 0, 0, 0, 1,
545+
static_cast<int>(elesize - i), nullptr, 0);
543546
}
544547
}
545548
return BTLA_CODE::Success;
@@ -732,15 +735,15 @@ static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* bit2ptr, int8_t* dstptr,
732735
size_t tmpsize) {
733736
int constexpr VBits = 256;
734737
int constexpr VElt = VBits / 8;
735-
int i = 0;
738+
size_t i = 0;
736739
uint64_t mask0 = 0x0303030303030303;
737740
auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0);
738741
auto vbias = _mm256_set1_epi8(2);
739742
auto vshift_y = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0);
740743
auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2,
741744
13, 9, 5, 1, 12, 8, 4, 0);
742745
auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0);
743-
int elt_pad = utils::padto_le(unpack_elt, VElt);
746+
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
744747
for (; i < elt_pad; i += VElt) {
745748
auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
746749
vout = _mm256_sub_epi8(vout, vbias);
@@ -981,7 +984,7 @@ static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8*
981984
size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
982985
int constexpr VBits = 256;
983986
int constexpr VElt = VBits / 8;
984-
int i = 0;
987+
size_t i = 0;
985988
uint64_t mask0 = 0x0303030303030303;
986989
auto vmask0 = _mm256_set_epi64x(*(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0, *(int64_t*)&mask0);
987990
auto vbias = _mm256_set1_epi8(4);
@@ -994,7 +997,7 @@ static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8*
994997
const __m256i bit1Mask = _mm256_set1_epi32(0x0F);
995998
const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
996999
const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2));
997-
int elt_pad = utils::padto_le(unpack_elt, VElt);
1000+
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
9981001
for (; i < elt_pad; i += VElt) {
9991002
auto vout = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
10001003
auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
@@ -1213,15 +1216,15 @@ static inline BTLA_CODE decompress_s1_s8(utils::bit1x8* bit1ptr, int8_t* dstptr,
12131216
size_t tmpsize) {
12141217
int constexpr VBits = 256;
12151218
int constexpr VElt = VBits / 8;
1216-
int i = 0;
1219+
size_t i = 0;
12171220
int constexpr FullRange = 1 << (1 - 1);
12181221
auto vbias = _mm256_set1_epi8(FullRange);
12191222

12201223
const __m256i highMask = _mm256_set1_epi8(0x04);
12211224
const __m256i bit1Mask = _mm256_set1_epi32(0x0F);
12221225
const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
12231226
const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2));
1224-
int elt_pad = utils::padto_le(unpack_elt, VElt);
1227+
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
12251228
for (; i < elt_pad; i += VElt) {
12261229
auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
12271230
vb1 = _mm256_srli_epi32(vb1, 2);
@@ -1460,7 +1463,7 @@ static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8*
14601463
size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
14611464
int constexpr VBits = 256;
14621465
int constexpr VElt = VBits / 8;
1463-
int i = 0;
1466+
size_t i = 0;
14641467
int constexpr FullRange = 1 << (5 - 1);
14651468
uint32_t mask = 0x0f0f0f0f;
14661469
auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
@@ -1470,7 +1473,7 @@ static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8*
14701473
const __m256i bit1Mask = _mm256_set1_epi32(0x0F);
14711474
const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
14721475
const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2));
1473-
int elt_pad = utils::padto_le(unpack_elt, VElt);
1476+
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
14741477
for (; i < elt_pad; i += VElt) {
14751478
auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
14761479
auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
@@ -1760,7 +1763,7 @@ static inline BTLA_CODE decompress_s7_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
17601763
int8_t* dstptr, size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
17611764
int constexpr VBits = 256;
17621765
int constexpr VElt = VBits / 8;
1763-
int i = 0;
1766+
size_t i = 0;
17641767
int constexpr FullRange = 1 << (7 - 1);
17651768
uint32_t mask = 0x0f0f0f0f;
17661769
auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
@@ -1777,7 +1780,7 @@ static inline BTLA_CODE decompress_s7_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
17771780
const __m256i bit1Mask = _mm256_set1_epi32(0x0F);
17781781
const __m256i bit1Shift_1 = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
17791782
const __m256i bit1Shift_2 = _mm256_set1_epi32((1 << 23) + (1 << 16) + (1 << 9) + (1 << 2));
1780-
int elt_pad = utils::padto_le(unpack_elt, VElt);
1783+
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
17811784
for (; i < elt_pad; i += VElt) {
17821785
auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
17831786
auto vb1 = unpack_1bits(bit1ptr + i / 8, bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
@@ -2035,7 +2038,7 @@ static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
20352038
size_t unpack_elt, int8_t* tmp, size_t tmpsize) {
20362039
int constexpr VBits = 256;
20372040
int constexpr VElt = VBits / 8;
2038-
int i = 0;
2041+
size_t i = 0;
20392042
int constexpr FullRange = 1 << (6 - 1);
20402043
uint32_t mask = 0x0f0f0f0f;
20412044
auto vmask = _mm256_set1_epi32(*reinterpret_cast<int*>(&mask));
@@ -2047,7 +2050,7 @@ static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
20472050
auto vsfhl_mask_y = _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0, 15, 11, 7, 3, 14, 10, 6, 2,
20482051
13, 9, 5, 1, 12, 8, 4, 0);
20492052
auto vorder_y = _mm256_set_epi32(1, 1, 1, 1, 0, 0, 0, 0);
2050-
int elt_pad = utils::padto_le(unpack_elt, VElt);
2053+
size_t elt_pad = utils::padto_le(unpack_elt, VElt);
20512054
for (; i < elt_pad; i += VElt) {
20522055
auto vout = unpack_4bits(bit4ptr + i / 2, vmask);
20532056
auto vb1 = unpack_2bits(bit2ptr + i / 4, vshift_y, vmask0, vsfhl_mask_y, vorder_y);
@@ -3474,8 +3477,8 @@ inline __m256 exp_ps_0_1(const __m256 x) {
34743477
static const auto log2e = _mm256_set1_ps(v_log2e);
34753478
static const auto half = _mm256_set1_ps(.5f);
34763479

3477-
static const auto upper_bound = _mm256_set1_ps(88.722838); // log(max_positive_float)
3478-
static const auto lower_bound = _mm256_set1_ps(-87.336549); // log(min_positive_float)
3480+
static const auto upper_bound = _mm256_set1_ps(88.722838f); // log(max_positive_float)
3481+
static const auto lower_bound = _mm256_set1_ps(-87.336549f); // log(min_positive_float)
34793482
__m256 x1 = _mm256_min_ps(x, upper_bound);
34803483
x1 = _mm256_max_ps(x1, lower_bound);
34813484

0 commit comments

Comments
 (0)