Skip to content

Commit a64668d

Browse files
DiamonDinoiaserge-sans-paille
authored andcommitted
initial-commit
1 parent d92c6d4 commit a64668d

File tree

8 files changed

+99
-0
lines changed

8 files changed

+99
-0
lines changed

include/xsimd/arch/common/xsimd_common_arithmetic.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,18 @@ namespace xsimd
127127
return { res_r, res_i };
128128
}
129129

130+
// fmas
131+
template <class A, class T>
132+
XSIMD_INLINE batch<T, A> fmas(batch<T, A> const& x, batch<T, A> const& y, batch<T, A> const& z, requires_arch<common>) noexcept
133+
{
134+
struct even_lane
135+
{
136+
static constexpr bool get(unsigned const i, unsigned) noexcept { return (i & 1u) == 0; }
137+
};
138+
const auto mask = make_batch_bool_constant<T, even_lane, A>();
139+
return fma(x, y, select(mask, neg(z), z));
140+
}
141+
130142
// hadd
131143
template <class A, class T, class /*=typename std::enable_if<std::is_integral<T>::value, void>::type*/>
132144
XSIMD_INLINE T hadd(batch<T, A> const& self, requires_arch<common>) noexcept

include/xsimd/arch/common/xsimd_common_details.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ namespace xsimd
4747
template <class T, class A>
4848
XSIMD_INLINE batch<T, A> fms(batch<T, A> const& x, batch<T, A> const& y, batch<T, A> const& z) noexcept;
4949
template <class T, class A>
50+
XSIMD_INLINE batch<T, A> fmas(batch<T, A> const& x, batch<T, A> const& y, batch<T, A> const& z) noexcept;
51+
template <class T, class A>
5052
XSIMD_INLINE batch<T, A> frexp(const batch<T, A>& x, const batch<as_integer_t<T>, A>& e) noexcept;
5153
template <class T, class A, uint64_t... Coefs>
5254
XSIMD_INLINE batch<T, A> horner(const batch<T, A>& self) noexcept;

include/xsimd/arch/xsimd_avx512f.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,18 @@ namespace xsimd
902902
{
903903
return _mm512_fmsub_pd(x, y, z);
904904
}
905+
// fmas
906+
template <class A>
907+
XSIMD_INLINE batch<float, A> fmas(batch<float, A> const& x, batch<float, A> const& y, batch<float, A> const& z, requires_arch<avx512f>) noexcept
908+
{
909+
return _mm512_fmaddsub_ps(x, y, z);
910+
}
911+
912+
template <class A>
913+
XSIMD_INLINE batch<double, A> fmas(batch<double, A> const& x, batch<double, A> const& y, batch<double, A> const& z, requires_arch<avx512f>) noexcept
914+
{
915+
return _mm512_fmaddsub_pd(x, y, z);
916+
}
905917

906918
// from bool
907919
template <class A, class T>

include/xsimd/arch/xsimd_fma3_avx.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ namespace xsimd
7373
return _mm256_fmsub_pd(x, y, z);
7474
}
7575

76+
// fmas
77+
template <class A>
78+
XSIMD_INLINE batch<float, A> fmas(batch<float, A> const& x, batch<float, A> const& y, batch<float, A> const& z, requires_arch<fma3<avx>>) noexcept
79+
{
80+
return _mm256_fmaddsub_ps(x, y, z);
81+
}
82+
83+
template <class A>
84+
XSIMD_INLINE batch<double, A> fmas(batch<double, A> const& x, batch<double, A> const& y, batch<double, A> const& z, requires_arch<fma3<avx>>) noexcept
85+
{
86+
return _mm256_fmaddsub_pd(x, y, z);
87+
}
88+
7689
}
7790

7891
}

include/xsimd/arch/xsimd_fma3_sse.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,24 @@ namespace xsimd
7171
{
7272
return _mm_fmsub_pd(x, y, z);
7373
}
74+
// fms
75+
template <class A>
76+
XSIMD_INLINE batch<float, A> fmas(batch<float, A> const& x,
77+
batch<float, A> const& y,
78+
batch<float, A> const& z,
79+
requires_arch<fma3<sse4_2>>) noexcept
80+
{
81+
return _mm_fmaddsub_ps(x, y, z);
82+
}
83+
84+
template <class A>
85+
XSIMD_INLINE batch<double, A> fmas(batch<double, A> const& x,
86+
batch<double, A> const& y,
87+
batch<double, A> const& z,
88+
requires_arch<fma3<sse4_2>>) noexcept
89+
{
90+
return _mm_fmaddsub_pd(x, y, z);
91+
}
7492

7593
}
7694

include/xsimd/arch/xsimd_fma4.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,19 @@ namespace xsimd
7272
{
7373
return _mm_msub_pd(x, y, z);
7474
}
75+
76+
// fmas
77+
template <class A>
78+
XSIMD_INLINE batch<float, A> fmas(batch<float, A> const& x, batch<float, A> const& y, batch<float, A> const& z, requires_arch<fma4>) noexcept
79+
{
80+
return _mm_maddsub_ps(x, y, z);
81+
}
82+
83+
template <class A>
84+
XSIMD_INLINE batch<double, A> fmas(batch<double, A> const& x, batch<double, A> const& y, batch<double, A> const& z, requires_arch<fma4>) noexcept
85+
{
86+
return _mm_maddsub_pd(x, y, z);
87+
}
7588
}
7689

7790
}

include/xsimd/types/xsimd_api.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,21 @@ namespace xsimd
991991
return kernel::fnms<A>(x, y, z, A {});
992992
}
993993

994+
/**
995+
* @ingroup batch_arithmetic
996+
*
997+
* Computes <tt>-(x*y) - z</tt> in a single instruction when possible.
998+
* @param x a batch of integer or floating point values.
999+
* @param y a batch of integer or floating point values.
1000+
* @param z a batch of integer or floating point values.
1001+
* @return a batch where each even-indexed element is computed as <tt>x * y - z</tt> and each odd-indexed element as <tt>x * y + z</tt>
1002+
*/
1003+
template <class T, class A>
1004+
XSIMD_INLINE batch<T, A> fmas(batch<T, A> const& x, batch<T, A> const& y, batch<T, A> const& z) noexcept
1005+
{
1006+
detail::static_check_supported_config<T, A>();
1007+
return kernel::fmas<A>(x, y, z, A {});
1008+
}
9941009
/**
9951010
* @ingroup batch_fp
9961011
*

test/test_batch.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,20 @@ struct batch_test
711711
INFO("fnms");
712712
CHECK_BATCH_EQ(res, expected);
713713
}
714+
// fmas
715+
{
716+
array_type expected;
717+
for (std::size_t i = 0; i < expected.size(); ++i)
718+
{
719+
// even lanes: x*y - z, odd lanes: x*y + z
720+
expected[i] = (i & 1u) == 0
721+
? lhs[i] * rhs[i] - rhs[i]
722+
: lhs[i] * rhs[i] + rhs[i];
723+
}
724+
batch_type res = fmas(batch_lhs(), batch_rhs(), batch_rhs());
725+
INFO("fmas");
726+
CHECK_BATCH_EQ(res, expected);
727+
}
714728
}
715729

716730
void test_abs() const

0 commit comments

Comments
 (0)