@@ -30,6 +30,8 @@ namespace avx2 {
30
30
#pragma clang attribute push(__attribute__((target("avx,avx2,fma"))), apply_to = function)
31
31
#endif
32
32
33
+ static inline void zero_reg () { _mm256_zeroupper (); }  // Thin wrapper: issues the _mm256_zeroupper() intrinsic (VZEROUPPER) and nothing else; takes no arguments and returns nothing.
34
+
33
35
static inline __m256i unpack_4bits (void * srcptr, __m256i mask) {
34
36
auto raw_data = _mm_loadu_si128 (reinterpret_cast <__m128i*>(srcptr));
35
37
auto ymm0 = _mm256_cvtepu8_epi16 (raw_data);
@@ -539,7 +541,8 @@ static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr,
539
541
vout_y = _mm256_sub_epi8 (vout_y, vbias);
540
542
_mm256_storeu_si256 ((__m256i*)(dstptr + i), vout_y);
541
543
} else {
542
- ref::decompress_kblock_s4_s8<1 , 1 >(srcptr + i / 2 , nullptr , dstptr + i, 0 , 0 , 0 , 0 , 1 , elesize - i, nullptr , 0 );
544
+ ref::decompress_kblock_s4_s8<1 , 1 >(srcptr + i / 2 , nullptr , dstptr + i, 0 , 0 , 0 , 0 , 1 ,
545
+ static_cast <int >(elesize - i), nullptr , 0 );
543
546
}
544
547
}
545
548
return BTLA_CODE::Success;
@@ -732,15 +735,15 @@ static inline BTLA_CODE decompress_s2_s8(utils::bit2x4* bit2ptr, int8_t* dstptr,
732
735
size_t tmpsize) {
733
736
int constexpr VBits = 256 ;
734
737
int constexpr VElt = VBits / 8 ;
735
- int i = 0 ;
738
+ size_t i = 0 ;
736
739
uint64_t mask0 = 0x0303030303030303 ;
737
740
auto vmask0 = _mm256_set_epi64x (*(int64_t *)&mask0, *(int64_t *)&mask0, *(int64_t *)&mask0, *(int64_t *)&mask0);
738
741
auto vbias = _mm256_set1_epi8 (2 );
739
742
auto vshift_y = _mm256_set_epi32 (6 , 4 , 2 , 0 , 6 , 4 , 2 , 0 );
740
743
auto vsfhl_mask_y = _mm256_set_epi8 (15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 , 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 , 15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 ,
741
744
13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 );
742
745
auto vorder_y = _mm256_set_epi32 (1 , 1 , 1 , 1 , 0 , 0 , 0 , 0 );
743
- int elt_pad = utils::padto_le (unpack_elt, VElt);
746
+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
744
747
for (; i < elt_pad; i += VElt) {
745
748
auto vout = unpack_2bits (bit2ptr + i / 4 , vshift_y, vmask0, vsfhl_mask_y, vorder_y);
746
749
vout = _mm256_sub_epi8 (vout, vbias);
@@ -981,7 +984,7 @@ static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8*
981
984
size_t unpack_elt, int8_t * tmp, size_t tmpsize) {
982
985
int constexpr VBits = 256 ;
983
986
int constexpr VElt = VBits / 8 ;
984
- int i = 0 ;
987
+ size_t i = 0 ;
985
988
uint64_t mask0 = 0x0303030303030303 ;
986
989
auto vmask0 = _mm256_set_epi64x (*(int64_t *)&mask0, *(int64_t *)&mask0, *(int64_t *)&mask0, *(int64_t *)&mask0);
987
990
auto vbias = _mm256_set1_epi8 (4 );
@@ -994,7 +997,7 @@ static inline BTLA_CODE decompress_s3_s8(utils::bit2x4* bit2ptr, utils::bit1x8*
994
997
const __m256i bit1Mask = _mm256_set1_epi32 (0x0F );
995
998
const __m256i bit1Shift_1 = _mm256_set_epi32 (28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 );
996
999
const __m256i bit1Shift_2 = _mm256_set1_epi32 ((1 << 23 ) + (1 << 16 ) + (1 << 9 ) + (1 << 2 ));
997
- int elt_pad = utils::padto_le (unpack_elt, VElt);
1000
+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
998
1001
for (; i < elt_pad; i += VElt) {
999
1002
auto vout = unpack_2bits (bit2ptr + i / 4 , vshift_y, vmask0, vsfhl_mask_y, vorder_y);
1000
1003
auto vb1 = unpack_1bits (bit1ptr + i / 8 , bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
@@ -1213,15 +1216,15 @@ static inline BTLA_CODE decompress_s1_s8(utils::bit1x8* bit1ptr, int8_t* dstptr,
1213
1216
size_t tmpsize) {
1214
1217
int constexpr VBits = 256 ;
1215
1218
int constexpr VElt = VBits / 8 ;
1216
- int i = 0 ;
1219
+ size_t i = 0 ;
1217
1220
int constexpr FullRange = 1 << (1 - 1 );
1218
1221
auto vbias = _mm256_set1_epi8 (FullRange);
1219
1222
1220
1223
const __m256i highMask = _mm256_set1_epi8 (0x04 );
1221
1224
const __m256i bit1Mask = _mm256_set1_epi32 (0x0F );
1222
1225
const __m256i bit1Shift_1 = _mm256_set_epi32 (28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 );
1223
1226
const __m256i bit1Shift_2 = _mm256_set1_epi32 ((1 << 23 ) + (1 << 16 ) + (1 << 9 ) + (1 << 2 ));
1224
- int elt_pad = utils::padto_le (unpack_elt, VElt);
1227
+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
1225
1228
for (; i < elt_pad; i += VElt) {
1226
1229
auto vb1 = unpack_1bits (bit1ptr + i / 8 , bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
1227
1230
vb1 = _mm256_srli_epi32 (vb1, 2 );
@@ -1460,7 +1463,7 @@ static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8*
1460
1463
size_t unpack_elt, int8_t * tmp, size_t tmpsize) {
1461
1464
int constexpr VBits = 256 ;
1462
1465
int constexpr VElt = VBits / 8 ;
1463
- int i = 0 ;
1466
+ size_t i = 0 ;
1464
1467
int constexpr FullRange = 1 << (5 - 1 );
1465
1468
uint32_t mask = 0x0f0f0f0f ;
1466
1469
auto vmask = _mm256_set1_epi32 (*reinterpret_cast <int *>(&mask));
@@ -1470,7 +1473,7 @@ static inline BTLA_CODE decompress_s5_s8(utils::bit4x2* bit4ptr, utils::bit1x8*
1470
1473
const __m256i bit1Mask = _mm256_set1_epi32 (0x0F );
1471
1474
const __m256i bit1Shift_1 = _mm256_set_epi32 (28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 );
1472
1475
const __m256i bit1Shift_2 = _mm256_set1_epi32 ((1 << 23 ) + (1 << 16 ) + (1 << 9 ) + (1 << 2 ));
1473
- int elt_pad = utils::padto_le (unpack_elt, VElt);
1476
+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
1474
1477
for (; i < elt_pad; i += VElt) {
1475
1478
auto vout = unpack_4bits (bit4ptr + i / 2 , vmask);
1476
1479
auto vb1 = unpack_1bits (bit1ptr + i / 8 , bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
@@ -1760,7 +1763,7 @@ static inline BTLA_CODE decompress_s7_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
1760
1763
int8_t * dstptr, size_t unpack_elt, int8_t * tmp, size_t tmpsize) {
1761
1764
int constexpr VBits = 256 ;
1762
1765
int constexpr VElt = VBits / 8 ;
1763
- int i = 0 ;
1766
+ size_t i = 0 ;
1764
1767
int constexpr FullRange = 1 << (7 - 1 );
1765
1768
uint32_t mask = 0x0f0f0f0f ;
1766
1769
auto vmask = _mm256_set1_epi32 (*reinterpret_cast <int *>(&mask));
@@ -1777,7 +1780,7 @@ static inline BTLA_CODE decompress_s7_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
1777
1780
const __m256i bit1Mask = _mm256_set1_epi32 (0x0F );
1778
1781
const __m256i bit1Shift_1 = _mm256_set_epi32 (28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 );
1779
1782
const __m256i bit1Shift_2 = _mm256_set1_epi32 ((1 << 23 ) + (1 << 16 ) + (1 << 9 ) + (1 << 2 ));
1780
- int elt_pad = utils::padto_le (unpack_elt, VElt);
1783
+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
1781
1784
for (; i < elt_pad; i += VElt) {
1782
1785
auto vout = unpack_4bits (bit4ptr + i / 2 , vmask);
1783
1786
auto vb1 = unpack_1bits (bit1ptr + i / 8 , bit1Shift_1, bit1Mask, bit1Shift_2, highMask);
@@ -2035,7 +2038,7 @@ static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
2035
2038
size_t unpack_elt, int8_t * tmp, size_t tmpsize) {
2036
2039
int constexpr VBits = 256 ;
2037
2040
int constexpr VElt = VBits / 8 ;
2038
- int i = 0 ;
2041
+ size_t i = 0 ;
2039
2042
int constexpr FullRange = 1 << (6 - 1 );
2040
2043
uint32_t mask = 0x0f0f0f0f ;
2041
2044
auto vmask = _mm256_set1_epi32 (*reinterpret_cast <int *>(&mask));
@@ -2047,7 +2050,7 @@ static inline BTLA_CODE decompress_s6_s8(utils::bit4x2* bit4ptr, utils::bit2x4*
2047
2050
auto vsfhl_mask_y = _mm256_set_epi8 (15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 , 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 , 15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 ,
2048
2051
13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 );
2049
2052
auto vorder_y = _mm256_set_epi32 (1 , 1 , 1 , 1 , 0 , 0 , 0 , 0 );
2050
- int elt_pad = utils::padto_le (unpack_elt, VElt);
2053
+ size_t elt_pad = utils::padto_le (unpack_elt, VElt);
2051
2054
for (; i < elt_pad; i += VElt) {
2052
2055
auto vout = unpack_4bits (bit4ptr + i / 2 , vmask);
2053
2056
auto vb1 = unpack_2bits (bit2ptr + i / 4 , vshift_y, vmask0, vsfhl_mask_y, vorder_y);
@@ -3474,8 +3477,8 @@ inline __m256 exp_ps_0_1(const __m256 x) {
3474
3477
static const auto log2e = _mm256_set1_ps (v_log2e);
3475
3478
static const auto half = _mm256_set1_ps (.5f );
3476
3479
3477
- static const auto upper_bound = _mm256_set1_ps (88.722838 ); // log(max_positive_float)
3478
- static const auto lower_bound = _mm256_set1_ps (-87.336549 ); // log(min_positive_float)
3480
+ static const auto upper_bound = _mm256_set1_ps (88 .722838f ); // log(max_positive_float)
3481
+ static const auto lower_bound = _mm256_set1_ps (-87 .336549f ); // log(min_positive_float)
3479
3482
__m256 x1 = _mm256_min_ps (x, upper_bound);
3480
3483
x1 = _mm256_max_ps (x1, lower_bound);
3481
3484
0 commit comments