Skip to content

Commit 3294464

Browse files
jatinchowdhury18serge-sans-paille
authored andcommitted
Add mixed-complex implementations of xsimd::pow()
1 parent 3b05fd6 commit 3294464

File tree

3 files changed

+101
-0
lines changed

3 files changed

+101
-0
lines changed

include/xsimd/arch/generic/xsimd_generic_math.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,6 +2030,26 @@ namespace xsimd
20302030
return select(absa == ze, cplx_batch(ze), cplx_batch(r * sincosTheta.second, r * sincosTheta.first));
20312031
}
20322032

2033+
template <class A, class T>
2034+
inline batch<std::complex<T>, A> pow(const batch<std::complex<T>, A>& a, const batch<T, A>& z, requires_arch<generic>) noexcept
2035+
{
2036+
using cplx_batch = batch<std::complex<T>, A>;
2037+
2038+
auto absa = abs(a);
2039+
auto arga = arg(a);
2040+
auto r = pow(absa, z);
2041+
2042+
auto theta = z * arga;
2043+
auto sincosTheta = xsimd::sincos(theta);
2044+
return select(absa == 0, cplx_batch(0), cplx_batch(r * sincosTheta.second, r * sincosTheta.first));
2045+
}
2046+
2047+
template <class A, class T>
2048+
inline batch<std::complex<T>, A> pow(const batch<T, A>& a, const batch<std::complex<T>, A>& z, requires_arch<generic>) noexcept
2049+
{
2050+
return pow(batch<std::complex<T>, A> { a, batch<T, A> {} }, z);
2051+
}
2052+
20332053
// reciprocal
20342054
template <class T, class A, class = typename std::enable_if<std::is_floating_point<T>::value, void>::type>
20352055
XSIMD_INLINE batch<T, A> reciprocal(batch<T, A> const& self,

include/xsimd/types/xsimd_api.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,38 @@ namespace xsimd
16961696
return kernel::pow<A>(x, y, A {});
16971697
}
16981698

1699+
/**
1700+
* @ingroup batch_math
1701+
*
1702+
* Computes the value of the batch \c x raised to the power
1703+
* \c y.
1704+
* @param x batch of complex floating point values.
1705+
* @param y batch of floating point values.
1706+
* @return \c x raised to the power \c y.
1707+
*/
1708+
template <class T, class A>
1709+
XSIMD_INLINE batch<std::complex<T>, A> pow(batch<std::complex<T>, A> const& x, batch<T, A> const& y) noexcept
1710+
{
1711+
detail::static_check_supported_config<T, A>();
1712+
return kernel::pow<A>(x, y, A {});
1713+
}
1714+
1715+
/**
1716+
* @ingroup batch_math
1717+
*
1718+
* Computes the value of the batch \c x raised to the power
1719+
* \c y.
1720+
* @param x batch of complex floating point values.
1721+
* @param y batch of floating point values.
1722+
* @return \c x raised to the power \c y.
1723+
*/
1724+
template <class T, class A>
1725+
XSIMD_INLINE batch<std::complex<T>, A> pow(batch<T, A> const& x, batch<std::complex<T>, A> const& y) noexcept
1726+
{
1727+
detail::static_check_supported_config<T, A>();
1728+
return kernel::pow<A>(x, y, A {});
1729+
}
1730+
16991731
/**
17001732
* @ingroup batch_math
17011733
*

test/test_complex_power.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct complex_power_test
2626
using real_vector_type = std::vector<real_value_type>;
2727

2828
size_t nb_input;
29+
real_vector_type lhs_p;
2930
vector_type lhs_nn;
3031
vector_type lhs_pn;
3132
vector_type lhs_np;
@@ -37,6 +38,7 @@ struct complex_power_test
3738
complex_power_test()
3839
{
3940
nb_input = 10000 * size;
41+
lhs_p.resize(nb_input);
4042
lhs_nn.resize(nb_input);
4143
lhs_pn.resize(nb_input);
4244
lhs_np.resize(nb_input);
@@ -46,6 +48,7 @@ struct complex_power_test
4648
{
4749
real_value_type real = (real_value_type(i) / 4 + real_value_type(1.2) * std::sqrt(real_value_type(i + 0.25))) / 100;
4850
real_value_type imag = (real_value_type(i) / 7 + real_value_type(1.7) * std::sqrt(real_value_type(i + 0.37))) / 100;
51+
lhs_p[i] = real;
4952
lhs_nn[i] = value_type(-real, -imag);
5053
lhs_pn[i] = value_type(real, -imag);
5154
lhs_np[i] = value_type(-real, imag);
@@ -110,6 +113,42 @@ struct complex_power_test
110113
CHECK_EQ(diff, 0);
111114
}
112115

116+
void test_pow_real_complex()
117+
{
118+
std::transform(lhs_p.cbegin(), lhs_p.cend(), lhs_pp.cbegin(), expected.begin(),
119+
[](const real_value_type& l, const value_type& r)
120+
{ using std::pow; return pow(l, r); });
121+
batch_type rhs_in, out;
122+
real_batch_type lhs_in;
123+
for (size_t i = 0; i < nb_input; i += size)
124+
{
125+
detail::load_batch(lhs_in, lhs_p, i);
126+
detail::load_batch(rhs_in, lhs_pp, i);
127+
out = pow(lhs_in, rhs_in);
128+
detail::store_batch(out, res, i);
129+
}
130+
size_t diff = detail::get_nb_diff_near(res, expected, std::numeric_limits<real_value_type>::epsilon());
131+
CHECK_EQ(diff, 0);
132+
}
133+
134+
void test_pow_complex_real()
135+
{
136+
std::transform(lhs_pp.cbegin(), lhs_pp.cend(), lhs_p.cbegin(), expected.begin(),
137+
[](const value_type& l, const real_value_type& r)
138+
{ using std::pow; return pow(l, r); });
139+
batch_type rhs_in, out;
140+
real_batch_type lhs_in;
141+
for (size_t i = 0; i < nb_input; i += size)
142+
{
143+
detail::load_batch(lhs_in, lhs_p, i);
144+
detail::load_batch(rhs_in, lhs_pp, i);
145+
out = pow(lhs_in, rhs_in);
146+
detail::store_batch(out, res, i);
147+
}
148+
size_t diff = detail::get_nb_diff_near(res, expected, std::numeric_limits<real_value_type>::epsilon());
149+
CHECK_EQ(diff, 0);
150+
}
151+
113152
void test_sqrt_nn()
114153
{
115154
std::transform(lhs_nn.cbegin(), lhs_nn.cend(), expected.begin(),
@@ -193,6 +232,16 @@ TEST_CASE_TEMPLATE("[complex power]", B, BATCH_COMPLEX_TYPES)
193232
Test.test_pow();
194233
}
195234

235+
SUBCASE("pow real complex")
236+
{
237+
Test.test_pow_real_complex();
238+
}
239+
240+
SUBCASE("pow complex real")
241+
{
242+
Test.test_pow_complex_real();
243+
}
244+
196245
SUBCASE("sqrt_nn")
197246
{
198247
Test.test_sqrt_nn();

0 commit comments

Comments
 (0)