
Commit 54aa8e7

Merge pull request #627 from xtensor-stack/feature/float-to-uint-conversion
Provide some conversion operators for float -> uint32
2 parents (26063cb + be1eb18), commit 54aa8e7

4 files changed: +83 -15 lines

include/xsimd/arch/xsimd_avx.hpp (+18)
@@ -491,6 +491,7 @@ namespace xsimd
                return get_half_complex_d<1>(self.real(), self.imag());
            }
        }
+
        // convert
        namespace detail
        {
@@ -499,11 +500,28 @@ namespace xsimd
            {
                return _mm256_cvtepi32_ps(self);
            }
+
+            template <class A>
+            inline batch<float, A> fast_cast(batch<uint32_t, A> const& v, batch<float, A> const&, requires_arch<avx>)
+            {
+                // see https://stackoverflow.com/questions/34066228/how-to-perform-uint32-float-conversion-with-sse
+                __m256i msk_lo = _mm256_set1_epi32(0xFFFF);
+                __m256 cnst65536f = _mm256_set1_ps(65536.0f);
+
+                __m256i v_lo = bitwise_and(batch<uint32_t, A>(v), batch<uint32_t, A>(msk_lo)); /* extract the 16 least significant bits of v */
+                __m256i v_hi = bitwise_rshift(batch<uint32_t, A>(v), 16, avx {}); /* 16 most significant bits of v */
+                __m256 v_lo_flt = _mm256_cvtepi32_ps(v_lo); /* No rounding */
+                __m256 v_hi_flt = _mm256_cvtepi32_ps(v_hi); /* No rounding */
+                v_hi_flt = _mm256_mul_ps(cnst65536f, v_hi_flt); /* No rounding */
+                return _mm256_add_ps(v_hi_flt, v_lo_flt); /* Rounding may occur here; mul and add may fuse to FMA on Haswell and newer */
+            }
+
            template <class A>
            inline batch<int32_t, A> fast_cast(batch<float, A> const& self, batch<int32_t, A> const&, requires_arch<avx>)
            {
                return _mm256_cvttps_epi32(self);
            }
+
        }

        // div
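
The new AVX overload works around the fact that _mm256_cvtepi32_ps converts signed int32 only, so any input at or above 2^31 would come out negative. Splitting the value into 16-bit halves keeps every intermediate exactly representable in a float and defers the single rounding step to the final add. A minimal scalar sketch of the same arithmetic, not part of the commit and using an illustrative function name:

#include <cassert>
#include <cstdint>

// Scalar model of the vectorized trick: convert the two 16-bit halves
// separately (each is exact in a float) and recombine with one multiply-add.
inline float uint32_to_float_reference(uint32_t v)
{
    float lo = static_cast<float>(v & 0xFFFFu); // 16 least significant bits, exact
    float hi = static_cast<float>(v >> 16);     // 16 most significant bits, exact
    return 65536.0f * hi + lo;                  // the only step that may round
}

int main()
{
    assert(uint32_to_float_reference(0u) == 0.0f);
    assert(uint32_to_float_reference(65536u) == 65536.0f);
    assert(uint32_to_float_reference(0x80000000u) == 2147483648.0f); // would be wrong with a signed convert
    return 0;
}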

include/xsimd/arch/xsimd_avx2.hpp (+20)
@@ -242,6 +242,26 @@ namespace xsimd
            __m256d tmp1 = _mm256_permute4x64_pd(self.imag(), _MM_SHUFFLE(3, 2, 2, 0));
            return _mm256_blend_pd(tmp0, tmp1, 10);
        }
+        // convert
+        namespace detail
+        {
+
+            template <class A>
+            inline batch<float, A> fast_cast(batch<uint32_t, A> const& v, batch<float, A> const&, requires_arch<avx2>)
+            {
+                // see https://stackoverflow.com/questions/34066228/how-to-perform-uint32-float-conversion-with-sse
+                __m256i msk_lo = _mm256_set1_epi32(0xFFFF);
+                __m256 cnst65536f = _mm256_set1_ps(65536.0f);
+
+                __m256i v_lo = _mm256_and_si256(v, msk_lo); /* extract the 16 least significant bits of v */
+                __m256i v_hi = _mm256_srli_epi32(v, 16); /* 16 most significant bits of v */
+                __m256 v_lo_flt = _mm256_cvtepi32_ps(v_lo); /* No rounding */
+                __m256 v_hi_flt = _mm256_cvtepi32_ps(v_hi); /* No rounding */
+                v_hi_flt = _mm256_mul_ps(cnst65536f, v_hi_flt); /* No rounding */
+                return _mm256_add_ps(v_hi_flt, v_lo_flt); /* Rounding may occur here; mul and add may fuse to FMA on Haswell and newer */
+            }
+
+        }

        // eq
        template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
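
The AVX2 variant is the same algorithm, but it can use the native 256-bit integer instructions (_mm256_and_si256, _mm256_srli_epi32) directly, whereas the plain AVX overload above has to route the integer steps through the generic bitwise_and and bitwise_rshift kernels, since AVX by itself lacks 256-bit integer and/shift instructions. These detail::fast_cast overloads are not called directly from user code; a hedged usage sketch, assuming xsimd::batch_cast is the public entry point that dispatches to them when a fast conversion exists for the selected architecture:

#include <cstdint>
#include <iostream>
#include "xsimd/xsimd.hpp"

int main()
{
    // Broadcast a value above INT32_MAX, the case a plain signed convert gets wrong.
    xsimd::batch<uint32_t> u(3000000000u);
    // Assumption: batch_cast selects the detail::fast_cast overload for the current arch.
    xsimd::batch<float> f = xsimd::batch_cast<float>(u);

    float out[xsimd::batch<float>::size];
    f.store_unaligned(out);
    std::cout << out[0] << '\n'; // expected: 3e+09
    return 0;
}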

include/xsimd/arch/xsimd_avx512f.hpp (+28 -15)
@@ -625,6 +625,34 @@ namespace xsimd
            return _mm512_roundscale_pd(self, _MM_FROUND_TO_POS_INF);
        }

+        // convert
+        namespace detail
+        {
+            template <class A>
+            inline batch<float, A> fast_cast(batch<int32_t, A> const& self, batch<float, A> const&, requires_arch<avx512f>)
+            {
+                return _mm512_cvtepi32_ps(self);
+            }
+
+            template <class A>
+            inline batch<int32_t, A> fast_cast(batch<float, A> const& self, batch<int32_t, A> const&, requires_arch<avx512f>)
+            {
+                return _mm512_cvttps_epi32(self);
+            }
+
+            template <class A>
+            inline batch<float, A> fast_cast(batch<uint32_t, A> const& self, batch<float, A> const&, requires_arch<avx512f>)
+            {
+                return _mm512_cvtepu32_ps(self);
+            }
+
+            template <class A>
+            batch<uint32_t, A> fast_cast(batch<float, A> const& self, batch<uint32_t, A> const&, requires_arch<avx512f>)
+            {
+                return _mm512_cvttps_epu32(self);
+            }
+        }
+
        namespace detail
        {
            // complex_low
@@ -656,21 +684,6 @@ namespace xsimd
            }
        }

-        // convert
-        namespace detail
-        {
-            template <class A>
-            inline batch<float, A> fast_cast(batch<int32_t, A> const& self, batch<float, A> const&, requires_arch<avx512f>)
-            {
-                return _mm512_cvtepi32_ps(self);
-            }
-            template <class A>
-            inline batch<int32_t, A> fast_cast(batch<float, A> const& self, batch<int32_t, A> const&, requires_arch<avx512f>)
-            {
-                return _mm512_cvttps_epi32(self);
-            }
-        }
-
        // div
        template <class A>
        inline batch<float, A> div(batch<float, A> const& self, batch<float, A> const& other, requires_arch<avx512f>)
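
On AVX512F no bit-splitting is needed: the ISA provides dedicated unsigned conversions (_mm512_cvtepu32_ps and _mm512_cvttps_epu32), which the new overloads wrap directly. The second hunk only removes the old int32-only convert block, which this commit relocates above the complex helpers. A standalone sketch using the same intrinsics outside of xsimd, assuming an AVX512F-capable machine and a build with -mavx512f:

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main()
{
    // INT32_MIN has the same bit pattern as the unsigned value 2^31 = 2147483648.
    __m512i u = _mm512_set1_epi32(INT32_MIN);
    __m512 f = _mm512_cvtepu32_ps(u);   // unsigned 32-bit -> float, one instruction
    __m512i r = _mm512_cvttps_epu32(f); // float -> unsigned 32-bit, truncating

    float fo[16];
    uint32_t io[16];
    _mm512_storeu_ps(fo, f);
    _mm512_storeu_si512(io, r);
    std::printf("%.1f %u\n", fo[0], io[0]); // expected: 2147483648.0 2147483648
    return 0;
}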

include/xsimd/arch/xsimd_sse2.hpp (+17)
@@ -483,11 +483,28 @@ namespace xsimd
            {
                return _mm_cvtepi32_ps(self);
            }
+
+            template <class A>
+            inline batch<float, A> fast_cast(batch<uint32_t, A> const& v, batch<float, A> const&, requires_arch<sse2>)
+            {
+                // see https://stackoverflow.com/questions/34066228/how-to-perform-uint32-float-conversion-with-sse
+                __m128i msk_lo = _mm_set1_epi32(0xFFFF);
+                __m128 cnst65536f = _mm_set1_ps(65536.0f);
+
+                __m128i v_lo = _mm_and_si128(v, msk_lo); /* extract the 16 least significant bits of v */
+                __m128i v_hi = _mm_srli_epi32(v, 16); /* 16 most significant bits of v */
+                __m128 v_lo_flt = _mm_cvtepi32_ps(v_lo); /* No rounding */
+                __m128 v_hi_flt = _mm_cvtepi32_ps(v_hi); /* No rounding */
+                v_hi_flt = _mm_mul_ps(cnst65536f, v_hi_flt); /* No rounding */
+                return _mm_add_ps(v_hi_flt, v_lo_flt); /* Rounding may occur here; mul and add may fuse to FMA on Haswell and newer */
+            }
+
            template <class A>
            inline batch<int32_t, A> fast_cast(batch<float, A> const& self, batch<int32_t, A> const&, requires_arch<sse2>)
            {
                return _mm_cvttps_epi32(self);
            }
+
        }

        // eq
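
The SSE2 overload is the case the referenced Stack Overflow answer targets most directly. Because the final add is the only step that can round, the vector result should match the compiler's own scalar uint32 -> float conversion for every input under the default round-to-nearest mode, including UINT32_MAX, which rounds up to 4294967296.0f. A standalone sketch, not part of the commit, that spot-checks this with raw SSE2 intrinsics:

#include <emmintrin.h> // SSE2
#include <cstdint>
#include <cstdio>

// Same arithmetic as the new sse2 fast_cast overload, written as a free function.
static __m128 uint32_to_float_sse2(__m128i v)
{
    __m128i v_lo = _mm_and_si128(v, _mm_set1_epi32(0xFFFF)); // low 16 bits
    __m128i v_hi = _mm_srli_epi32(v, 16);                    // high 16 bits
    return _mm_add_ps(_mm_mul_ps(_mm_set1_ps(65536.0f), _mm_cvtepi32_ps(v_hi)),
                      _mm_cvtepi32_ps(v_lo));
}

int main()
{
    uint32_t in[4] = { 0u, 1u, 2147483648u, 4294967295u };
    __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i*>(in));
    float out[4];
    _mm_storeu_ps(out, uint32_to_float_sse2(v));
    for (int i = 0; i < 4; ++i)
        std::printf("%u -> %.1f (scalar: %.1f)\n",
                    static_cast<unsigned>(in[i]), out[i], static_cast<float>(in[i]));
    return 0;
}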
