@@ -30,6 +30,8 @@ namespace avx2 {
3030#pragma clang attribute push(__attribute__((target("avx,avx2,fma"))), apply_to = function)
3131#endif
3232
// Thin wrapper whose entire body is a single call to _mm256_zeroupper(); it takes no
// arguments and returns nothing.
// Per Intel's intrinsics documentation, _mm256_zeroupper clears only the upper 128 bits
// of the YMM registers and leaves the lower halves untouched, so despite its name this
// helper does not zero whole registers.
// NOTE(review): presumably meant to be called at AVX <-> SSE transition points; no
// callers are visible in this hunk -- confirm against the call sites.
33+ static inline void zero_reg () { _mm256_zeroupper (); }
34+
3335static inline __m256i unpack_4bits (void * srcptr, __m256i mask) {
3436 auto raw_data = _mm_loadu_si128 (reinterpret_cast <__m128i*>(srcptr));
3537 auto ymm0 = _mm256_cvtepu8_epi16 (raw_data);
@@ -539,7 +541,8 @@ static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr,
539541 vout_y = _mm256_sub_epi8 (vout_y, vbias);
540542 _mm256_storeu_si256 ((__m256i*)(dstptr + i), vout_y);
541543 } else {
542- ref::decompress_kblock_s4_s8<1 , 1 >(srcptr + i / 2 , nullptr , dstptr + i, 0 , 0 , 0 , 0 , 1 , elesize - i, nullptr , 0 );
544+ ref::decompress_kblock_s4_s8<1 , 1 >(srcptr + i / 2 , nullptr , dstptr + i, 0 , 0 , 0 , 0 , 1 ,
545+ static_cast <int >(elesize - i), nullptr , 0 );
543546 }
544547 }
545548 return BTLA_CODE::Success;
@@ -732,15 +735,15 @@ static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* bit2ptr, int8_t* dstptr,
732735 size_t tmpsize) {
733736 int constexpr VBits = 256 ;
734737 int constexpr VElt = VBits / 8 ;
735- int i = 0 ;
738+ size_t i = 0 ;
736739 uint64_t mask0 = 0x0303030303030303 ;
737740 auto vmask0 = _mm256_set_epi64x (*(int64_t *)&mask0, *(int64_t *)&mask0, *(int64_t *)&mask0, *(int64_t *)&mask0);
738741 auto vbias = _mm256_set1_epi8 (2 );
739742 auto vshift_y = _mm256_set_epi32 (6 , 4 , 2 , 0 , 6 , 4 , 2 , 0 );
740743 auto vsfhl_mask_y = _mm256_set_epi8 (15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 , 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 , 15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 ,
741744 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 );
742745 auto vorder_y = _mm256_set_epi32 (1 , 1 , 1 , 1 , 0 , 0 , 0 , 0 );
743- int elt_pad = utils::padto_le (unpack_elt, VElt);
746+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
744747 for (; i < elt_pad; i += VElt) {
745748 auto vout = unpack_2bits (bit2ptr + i / 4 , vshift_y, vmask0, vsfhl_mask_y, vorder_y);
746749 vout = _mm256_sub_epi8 (vout, vbias);
@@ -981,7 +984,7 @@ static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8*
981984 size_t unpack_elt, int8_t * tmp, size_t tmpsize) {
982985 int constexpr VBits = 256 ;
983986 int constexpr VElt = VBits / 8 ;
984- int i = 0 ;
987+ size_t i = 0 ;
985988 uint64_t mask0 = 0x0303030303030303 ;
986989 auto vmask0 = _mm256_set_epi64x (*(int64_t *)&mask0, *(int64_t *)&mask0, *(int64_t *)&mask0, *(int64_t *)&mask0);
987990 auto vbias = _mm256_set1_epi8 (4 );
@@ -994,7 +997,7 @@ static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8*
994997 const __m256i bit1Mask = _mm256_set1_epi32 (0x0F );
995998 const __m256i bit1Shift_1 = _mm256_set_epi32 (28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 );
996999 const __m256i bit1Shift_2 = _mm256_set1_epi32 ((1 << 23 ) + (1 << 16 ) + (1 << 9 ) + (1 << 2 ));
997- int elt_pad = utils::padto_le (unpack_elt, VElt);
1000+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
9981001 for (; i < elt_pad; i += VElt) {
9991002 auto vout = unpack_2bits (bit2ptr + i / 4 , vshift_y, vmask0, vsfhl_mask_y, vorder_y);
10001003 auto vb1 = unpack_1bits (bit1ptr + i / 8 , bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
@@ -1213,15 +1216,15 @@ static inline BTLA_CODE decompress_s1_s8(utils::bit1x8* bit1ptr, int8_t* dstptr,
12131216 size_t tmpsize) {
12141217 int constexpr VBits = 256 ;
12151218 int constexpr VElt = VBits / 8 ;
1216- int i = 0 ;
1219+ size_t i = 0 ;
12171220 int constexpr FullRange = 1 << (1 - 1 );
12181221 auto vbias = _mm256_set1_epi8 (FullRange);
12191222
12201223 const __m256i highMask = _mm256_set1_epi8 (0x04 );
12211224 const __m256i bit1Mask = _mm256_set1_epi32 (0x0F );
12221225 const __m256i bit1Shift_1 = _mm256_set_epi32 (28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 );
12231226 const __m256i bit1Shift_2 = _mm256_set1_epi32 ((1 << 23 ) + (1 << 16 ) + (1 << 9 ) + (1 << 2 ));
1224- int elt_pad = utils::padto_le (unpack_elt, VElt);
1227+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
12251228 for (; i < elt_pad; i += VElt) {
12261229 auto vb1 = unpack_1bits (bit1ptr + i / 8 , bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
12271230 vb1 = _mm256_srli_epi32 (vb1, 2 );
@@ -1460,7 +1463,7 @@ static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8*
14601463 size_t unpack_elt, int8_t * tmp, size_t tmpsize) {
14611464 int constexpr VBits = 256 ;
14621465 int constexpr VElt = VBits / 8 ;
1463- int i = 0 ;
1466+ size_t i = 0 ;
14641467 int constexpr FullRange = 1 << (5 - 1 );
14651468 uint32_t mask = 0x0f0f0f0f ;
14661469 auto vmask = _mm256_set1_epi32 (*reinterpret_cast <int *>(&mask));
@@ -1470,7 +1473,7 @@ static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8*
14701473 const __m256i bit1Mask = _mm256_set1_epi32 (0x0F );
14711474 const __m256i bit1Shift_1 = _mm256_set_epi32 (28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 );
14721475 const __m256i bit1Shift_2 = _mm256_set1_epi32 ((1 << 23 ) + (1 << 16 ) + (1 << 9 ) + (1 << 2 ));
1473- int elt_pad = utils::padto_le (unpack_elt, VElt);
1476+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
14741477 for (; i < elt_pad; i += VElt) {
14751478 auto vout = unpack_4bits (bit4ptr + i / 2 , vmask);
14761479 auto vb1 = unpack_1bits (bit1ptr + i / 8 , bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
@@ -1760,7 +1763,7 @@ static inline BTLA_CODE decompress_s7_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
17601763 int8_t * dstptr, size_t unpack_elt, int8_t * tmp, size_t tmpsize) {
17611764 int constexpr VBits = 256 ;
17621765 int constexpr VElt = VBits / 8 ;
1763- int i = 0 ;
1766+ size_t i = 0 ;
17641767 int constexpr FullRange = 1 << (7 - 1 );
17651768 uint32_t mask = 0x0f0f0f0f ;
17661769 auto vmask = _mm256_set1_epi32 (*reinterpret_cast <int *>(&mask));
@@ -1777,7 +1780,7 @@ static inline BTLA_CODE decompress_s7_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
17771780 const __m256i bit1Mask = _mm256_set1_epi32 (0x0F );
17781781 const __m256i bit1Shift_1 = _mm256_set_epi32 (28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 );
17791782 const __m256i bit1Shift_2 = _mm256_set1_epi32 ((1 << 23 ) + (1 << 16 ) + (1 << 9 ) + (1 << 2 ));
1780- int elt_pad = utils::padto_le (unpack_elt, VElt);
1783+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
17811784 for (; i < elt_pad; i += VElt) {
17821785 auto vout = unpack_4bits (bit4ptr + i / 2 , vmask);
17831786 auto vb1 = unpack_1bits (bit1ptr + i / 8 , bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
@@ -2035,7 +2038,7 @@ static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
20352038 size_t unpack_elt, int8_t * tmp, size_t tmpsize) {
20362039 int constexpr VBits = 256 ;
20372040 int constexpr VElt = VBits / 8 ;
2038- int i = 0 ;
2041+ size_t i = 0 ;
20392042 int constexpr FullRange = 1 << (6 - 1 );
20402043 uint32_t mask = 0x0f0f0f0f ;
20412044 auto vmask = _mm256_set1_epi32 (*reinterpret_cast <int *>(&mask));
@@ -2047,7 +2050,7 @@ static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
20472050 auto vsfhl_mask_y = _mm256_set_epi8 (15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 , 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 , 15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 ,
20482051 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 );
20492052 auto vorder_y = _mm256_set_epi32 (1 , 1 , 1 , 1 , 0 , 0 , 0 , 0 );
2050- int elt_pad = utils::padto_le (unpack_elt, VElt);
2053+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
20512054 for (; i < elt_pad; i += VElt) {
20522055 auto vout = unpack_4bits (bit4ptr + i / 2 , vmask);
20532056 auto vb1 = unpack_2bits (bit2ptr + i / 4 , vshift_y, vmask0, vsfhl_mask_y, vorder_y);
@@ -3474,8 +3477,8 @@ inline __m256 exp_ps_0_1(const __m256 x) {
34743477 static const auto log2e = _mm256_set1_ps (v_log2e);
34753478 static const auto half = _mm256_set1_ps (.5f );
34763479
3477- static const auto upper_bound = _mm256_set1_ps (88.722838 ); // log(max_positive_float)
3478- static const auto lower_bound = _mm256_set1_ps (-87.336549 ); // log(min_positive_float)
3480+ static const auto upper_bound = _mm256_set1_ps (88 .722838f ); // log(max_positive_float)
3481+ static const auto lower_bound = _mm256_set1_ps (-87 .336549f ); // log(min_positive_float)
34793482 __m256 x1 = _mm256_min_ps (x, upper_bound);
34803483 x1 = _mm256_max_ps (x1, lower_bound);
34813484
0 commit comments