diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp index 7dbfb6618c..3af8f0f4af 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace acos { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -72,9 +74,10 @@ template struct AcosFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isnan(x)) { /* acos(NaN + I*+-Inf) = NaN + I*-+Inf */ @@ -106,12 +109,10 @@ template struct AcosFunctor constexpr realT r_eps = realT(1) / std::numeric_limits::epsilon(); if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = exprm_ns::complex; - sycl_complexT log_in = - exprm_ns::log(exprm_ns::complex(in)); + sycl_complexT log_z = exprm_ns::log(z); - const realT wx = log_in.real(); - const realT wy = log_in.imag(); + const realT wx = log_z.real(); + const realT wy = log_z.imag(); const realT rx = sycl::fabs(wy); realT ry = wx + sycl::log(realT(2)); @@ -119,7 +120,7 @@ template struct AcosFunctor } /* ordinary cases */ - return exprm_ns::acos(exprm_ns::complex(in)); // acos(in); + return exprm_ns::acos(z); // acos(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp index a81ff3da99..2bcd3dbd4e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace acosh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -77,17 +79,19 @@ template struct AcoshFunctor * where the sign is chosen so Re(acosh(in)) >= 0. * So, we first calculate acos(in) and then acosh(in). */ - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); - resT acos_in; + sycl_complexT acos_z; if (std::isnan(x)) { /* acos(NaN + I*+-Inf) = NaN + I*-+Inf */ if (std::isinf(y)) { - acos_in = resT{q_nan, -y}; + acos_z = resT{q_nan, -y}; } else { - acos_in = resT{q_nan, q_nan}; + acos_z = resT{q_nan, q_nan}; } } else if (std::isnan(y)) { @@ -95,15 +99,15 @@ template struct AcoshFunctor constexpr realT inf = std::numeric_limits::infinity(); if (std::isinf(x)) { - acos_in = resT{q_nan, -inf}; + acos_z = resT{q_nan, -inf}; } /* acos(0 + I*NaN) = Pi/2 + I*NaN with inexact */ else if (x == realT(0)) { const realT pi_half = sycl::atan(realT(1)) * 2; - acos_in = resT{pi_half, q_nan}; + acos_z = resT{pi_half, q_nan}; } else { - acos_in = resT{q_nan, q_nan}; + acos_z = resT{q_nan, q_nan}; } } @@ -113,23 +117,21 @@ template struct AcoshFunctor * For large x or y including acos(+-Inf + I*+-Inf) */ if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = typename exprm_ns::complex; - const sycl_complexT log_in = exprm_ns::log(sycl_complexT(in)); - const realT wx = log_in.real(); - const realT wy = log_in.imag(); + const sycl_complexT log_z = exprm_ns::log(z); + const realT wx = log_z.real(); + const realT wy = log_z.imag(); const realT rx = sycl::fabs(wy); realT ry = wx + sycl::log(realT(2)); - acos_in = resT{rx, (sycl::signbit(y)) ? ry : -ry}; + acos_z = resT{rx, (sycl::signbit(y)) ? ry : -ry}; } else { /* ordinary cases */ - acos_in = - exprm_ns::acos(exprm_ns::complex(in)); // acos(in); + acos_z = exprm_ns::acos(z); // acos(z); } /* Now we calculate acosh(z) */ - const realT rx = std::real(acos_in); - const realT ry = std::imag(acos_in); + const realT rx = exprm_ns::real(acos_z); + const realT ry = exprm_ns::imag(acos_z); /* acosh(NaN + I*NaN) = NaN + I*NaN */ if (std::isnan(rx) && std::isnan(ry)) { @@ -145,7 +147,7 @@ template struct AcoshFunctor return resT{ry, ry}; } /* ordinary cases */ - const realT res_im = sycl::copysign(rx, std::imag(in)); + const realT res_im = sycl::copysign(rx, exprm_ns::imag(z)); return resT{sycl::fabs(ry), res_im}; } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index 476e7b52b9..e7b7f0c0e7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -50,8 +50,10 @@ namespace add { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct AddFunctor { @@ -69,21 +71,22 @@ template struct AddFunctor using rT1 = typename argT1::value_type; using rT2 = typename argT2::value_type; - return exprm_ns::complex(in1) + exprm_ns::complex(in2); + return su_ns::sycl_complex_t(in1) + + su_ns::sycl_complex_t(in2); } else if constexpr (tu_ns::is_complex::value && !tu_ns::is_complex::value) { using rT1 = typename argT1::value_type; - return exprm_ns::complex(in1) + in2; + return su_ns::sycl_complex_t(in1) + in2; } else if constexpr (!tu_ns::is_complex::value && tu_ns::is_complex::value) { using rT2 = typename argT2::value_type; - return in1 + exprm_ns::complex(in2); + return in1 + su_ns::sycl_complex_t(in2); } else { return in1 + in2; @@ -460,7 +463,21 @@ template struct AddInplaceFunctor using supports_vec = std::negation< std::disjunction, tu_ns::is_complex>>; - void operator()(resT &res, const argT &in) { res += in; } + void operator()(resT &res, const argT &in) + { + if constexpr (tu_ns::is_complex_v && tu_ns::is_complex_v) { + using rT1 = typename resT::value_type; + using rT2 = typename argT::value_type; + + auto tmp = su_ns::sycl_complex_t(res); + tmp += su_ns::sycl_complex_t(in); + + res = resT(tmp); + } + else { + res += in; + } + } template void operator()(sycl::vec &res, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp index 726f90ba81..501a73765d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace angle { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -71,7 +73,7 @@ template struct AngleFunctor { using rT = typename argT::value_type; - return exprm_ns::arg(exprm_ns::complex(in)); // arg(in); + return exprm_ns::arg(su_ns::sycl_complex_t(in)); // arg(in); } }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp index 70b48895b4..9920bca56c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace asin { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -80,8 +82,10 @@ template struct AsinFunctor * y = imag(I * conj(in)) = real(in) * and then return {imag(w), real(w)} which is asin(in) */ - const realT x = std::imag(in); - const realT y = std::real(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::imag(z); + const realT y = exprm_ns::real(z); if (std::isnan(x)) { /* asinh(NaN + I*+-Inf) = opt(+-)Inf + I*NaN */ @@ -120,26 +124,24 @@ template struct AsinFunctor constexpr realT r_eps = realT(1) / std::numeric_limits::epsilon(); if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = exprm_ns::complex; - const sycl_complexT z{x, y}; + const sycl_complexT z1{x, y}; realT wx, wy; if (!sycl::signbit(x)) { - const auto log_z = exprm_ns::log(z); - wx = log_z.real() + sycl::log(realT(2)); - wy = log_z.imag(); + const auto log_z1 = exprm_ns::log(z1); + wx = log_z1.real() + sycl::log(realT(2)); + wy = log_z1.imag(); } else { - const auto log_mz = exprm_ns::log(-z); - wx = log_mz.real() + sycl::log(realT(2)); - wy = log_mz.imag(); + const auto log_mz1 = exprm_ns::log(-z1); + wx = log_mz1.real() + sycl::log(realT(2)); + wy = log_mz1.imag(); } const realT asinh_re = sycl::copysign(wx, x); const realT asinh_im = sycl::copysign(wy, y); return resT{asinh_im, asinh_re}; } /* ordinary cases */ - return exprm_ns::asin( - exprm_ns::complex(in)); // sycl::asin(in); + return exprm_ns::asin(z); // sycl::asin(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp index 420ba3246c..ea686fccc3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace asinh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -72,9 +74,10 @@ template struct AsinhFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isnan(x)) { /* asinh(NaN + I*+-Inf) = opt(+-)Inf + I*NaN */ @@ -109,12 +112,10 @@ template struct AsinhFunctor realT(1) / std::numeric_limits::epsilon(); if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = exprm_ns::complex; - sycl_complexT log_in = (sycl::signbit(x)) - ? exprm_ns::log(sycl_complexT(-in)) - : exprm_ns::log(sycl_complexT(in)); - realT wx = log_in.real() + sycl::log(realT(2)); - realT wy = log_in.imag(); + sycl_complexT log_in = + (sycl::signbit(x)) ? exprm_ns::log(-z) : exprm_ns::log(z); + realT wx = exprm_ns::real(log_in) + sycl::log(realT(2)); + realT wy = exprm_ns::imag(log_in); const realT res_re = sycl::copysign(wx, x); const realT res_im = sycl::copysign(wy, y); @@ -122,7 +123,7 @@ template struct AsinhFunctor } /* ordinary cases */ - return exprm_ns::asinh(exprm_ns::complex(in)); // asinh(in); + return exprm_ns::asinh(z); // asinh(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp index 29c4941d76..2728616841 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace atan { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::kernels::vec_size_utils::ContigHyperparameterSetDefault; using dpctl::tensor::kernels::vec_size_utils::UnaryContigHyperparameterSetEntry; @@ -83,8 +85,11 @@ template struct AtanFunctor * y = imag(I * conj(in)) = real(in) * and then return {imag(w), real(w)} which is atan(in) */ - const realT x = std::imag(in); - const realT y = std::real(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::imag(z); + const realT y = exprm_ns::real(z); + if (std::isnan(x)) { /* atanh(NaN + I*+-Inf) = sign(NaN)*0 + I*+-Pi/2 */ if (std::isinf(y)) { @@ -132,7 +137,7 @@ template struct AtanFunctor return resT{atanh_im, atanh_re}; } /* ordinary cases */ - return exprm_ns::atan(exprm_ns::complex(in)); // atan(in); + return exprm_ns::atan(z); // atan(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp index 39f11e0f90..eee287823d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace atanh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,8 +75,10 @@ template struct AtanhFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isnan(x)) { /* atanh(NaN + I*+-Inf) = sign(NaN)0 + I*+-PI/2 */ @@ -123,7 +127,7 @@ template struct AtanhFunctor return resT{res_re, res_im}; } /* ordinary cases */ - return exprm_ns::atanh(exprm_ns::complex(in)); // atanh(in); + return exprm_ns::atanh(z); // atanh(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp index afa83a64cb..c82e986f27 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp @@ -27,7 +27,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" namespace dpctl { @@ -38,6 +38,9 @@ namespace kernels namespace detail { +namespace su_ns = dpctl::tensor::sycl_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; + template realT cabs(std::complex const &z) { // Special values for cabs( x + y * 1j): @@ -51,8 +54,10 @@ template realT cabs(std::complex const &z) // * If x is a finite number and y is NaN, the result is NaN. // * If x is NaN and y is NaN, the result is NaN. - const realT x = std::real(z); - const realT y = std::imag(z); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT _z = su_ns::sycl_complex_t(z); + const realT x = exprm_ns::real(_z); + const realT y = exprm_ns::imag(_z); constexpr realT q_nan = std::numeric_limits::quiet_NaN(); constexpr realT p_inf = std::numeric_limits::infinity(); @@ -60,11 +65,8 @@ template realT cabs(std::complex const &z) const realT res = std::isinf(x) ? p_inf - : ((std::isinf(y) - ? p_inf - : ((std::isnan(x) - ? q_nan - : exprm_ns::abs(exprm_ns::complex(z)))))); + : ((std::isinf(y) ? p_inf + : ((std::isnan(x) ? q_nan : exprm_ns::abs(_z))))); return res; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp index 19a95df5a1..61859c9efe 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace conj { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,7 +75,7 @@ template struct ConjFunctor if constexpr (is_complex::value) { using rT = typename argT::value_type; - return exprm_ns::conj(exprm_ns::complex(in)); // conj(in); + return exprm_ns::conj(su_ns::sycl_complex_t(in)); // conj(in); } else { if constexpr (!std::is_same_v) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index 5940315c62..e1da401886 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace cos { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -72,30 +74,31 @@ template struct CosFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT z_re = exprm_ns::real(z); + const realT z_im = exprm_ns::imag(z); - realT const &in_re = std::real(in); - realT const &in_im = std::imag(in); - - const bool in_re_finite = std::isfinite(in_re); - const bool in_im_finite = std::isfinite(in_im); + const bool z_re_finite = std::isfinite(z_re); + const bool z_im_finite = std::isfinite(z_im); /* * Handle the nearly-non-exceptional cases where * real and imaginary parts of input are finite. */ - if (in_re_finite && in_im_finite) { - return exprm_ns::cos(exprm_ns::complex(in)); // cos(in); + if (z_re_finite && z_im_finite) { + return exprm_ns::cos(z); // cos(z); } /* - * since cos(in) = cosh(I * in), for special cases, - * we return cosh(I * in). + * since cos(z) = cosh(I * z), for special cases, + * we return cosh(I * z). */ - const realT x = -in_im; - const realT y = in_re; + const realT x = -z_im; + const realT y = z_re; - const bool xfinite = in_im_finite; - const bool yfinite = in_re_finite; + const bool xfinite = z_im_finite; + const bool yfinite = z_re_finite; /* * cosh(+-0 +- I Inf) = dNaN + I sign(d(+-0, dNaN))0. * The sign of 0 in the result is unspecified. Choice = normally diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp index 59468428d1..4b841c3486 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace cosh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,8 +75,10 @@ template struct CoshFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); const bool xfinite = std::isfinite(x); const bool yfinite = std::isfinite(y); @@ -84,8 +88,7 @@ template struct CoshFunctor * real and imaginary parts of input are finite. */ if (xfinite && yfinite) { - return exprm_ns::cosh( - exprm_ns::complex(in)); // cosh(in); + return exprm_ns::cosh(z); // cosh(z); } /* diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index a53f6412de..9f3f09791e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -49,6 +49,7 @@ namespace equal { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -71,8 +72,8 @@ template struct EqualFunctor using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; - return exprm_ns::complex(in1) == - exprm_ns::complex(in2); + return su_ns::sycl_complex_t(in1) == + su_ns::sycl_complex_t(in2); } else { if constexpr (std::is_integral_v && diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp index 00f8213251..a35481996f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace exp { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -72,12 +74,13 @@ template struct ExpFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isfinite(x)) { if (std::isfinite(y)) { - return exprm_ns::exp( - exprm_ns::complex(in)); // exp(in); + return exprm_ns::exp(z); // exp(z); } else { return resT{q_nan, q_nan}; @@ -86,7 +89,7 @@ template struct ExpFunctor else if (std::isnan(x)) { /* x is nan */ if (y == realT(0)) { - return resT{in}; + return resT{z}; } else { return resT{x, q_nan}; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp index 22291101ca..93f54970af 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace exp2 { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -71,15 +73,18 @@ template struct Exp2Functor if constexpr (is_complex::value) { using realT = typename argT::value_type; - const argT tmp = in * sycl::log(realT(2)); - constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(tmp); - const realT y = std::imag(tmp); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); + + const sycl_complexT tmp = z * sycl::log(realT(2)); + if (std::isfinite(x)) { if (std::isfinite(y)) { - return exprm_ns::exp(exprm_ns::complex(tmp)); + return exprm_ns::exp(tmp); } else { return resT{q_nan, q_nan}; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index d1d64f4904..973da46d63 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -31,6 +31,7 @@ #include #include +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +51,9 @@ namespace expm1 { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,8 +76,10 @@ template struct Expm1Functor using realT = typename argT::value_type; // expm1(x + I*y) = expm1(x)*cos(y) - 2*sin(y / 2)^2 + // I*exp(x)*sin(y) - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); // special cases if (std::isinf(x)) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index c5e0feea12..c1b52f1d14 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -31,6 +31,7 @@ #include #include +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +51,9 @@ namespace imag { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::is_complex_v; @@ -72,7 +75,9 @@ template struct ImagFunctor resT operator()(const argT &in) const { if constexpr (is_complex_v) { - return std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = typename su_ns::sycl_complex_t; + return exprm_ns::imag(sycl_complexT(in)); } else { static_assert(std::is_same_v); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index b0651a4d8b..83deac1a42 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -30,6 +30,7 @@ #include #include +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -46,7 +47,9 @@ namespace isfinite { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -69,8 +72,11 @@ template struct IsFiniteFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - const bool real_isfinite = std::isfinite(std::real(in)); - const bool imag_isfinite = std::isfinite(std::imag(in)); + using realT = typename argT::value_type; + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const bool real_isfinite = std::isfinite(exprm_ns::real(z)); + const bool imag_isfinite = std::isfinite(exprm_ns::imag(z)); return (real_isfinite && imag_isfinite); } else if constexpr (std::is_same::value || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index ec78746143..1c0a2875f4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -30,6 +30,7 @@ #include #include +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -48,7 +49,9 @@ namespace isinf { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -69,8 +72,11 @@ template struct IsInfFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - const bool real_isinf = std::isinf(std::real(in)); - const bool imag_isinf = std::isinf(std::imag(in)); + using realT = typename argT::value_type; + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const bool real_isinf = std::isinf(exprm_ns::real(z)); + const bool imag_isinf = std::isinf(exprm_ns::imag(z)); return (real_isinf || imag_isinf); } else if constexpr (std::is_same::value || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index fbf6ef9383..1317bdc945 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -29,6 +29,7 @@ #include #include +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -47,7 +48,9 @@ namespace isnan { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -70,8 +73,11 @@ template struct IsNanFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - const bool real_isnan = sycl::isnan(std::real(in)); - const bool imag_isnan = sycl::isnan(std::imag(in)); + using realT = typename argT::value_type; + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const bool real_isnan = sycl::isnan(exprm_ns::real(z)); + const bool imag_isnan = sycl::isnan(exprm_ns::imag(z)); return (real_isnan || imag_isnan); } else if constexpr (std::is_same::value || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp index 84471a5ef4..c33af596d8 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace log { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -71,7 +73,7 @@ template struct LogFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - return exprm_ns::log(exprm_ns::complex(in)); // log(in); + return exprm_ns::log(su_ns::sycl_complex_t(in)); // log(in); } else { return sycl::log(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp index d308c85ac9..30868237dc 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace log10 { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -75,7 +77,7 @@ template struct Log10Functor if constexpr (is_complex::value) { using realT = typename argT::value_type; // return (log(in) / log(realT{10})); - return exprm_ns::log(exprm_ns::complex(in)) / + return exprm_ns::log(su_ns::sycl_complex_t(in)) / sycl::log(realT{10}); } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp index b8d993dd94..ee29b3ad4f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp @@ -30,6 +30,7 @@ #include #include +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +50,9 @@ namespace log1p { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -78,8 +81,11 @@ template struct Log1pFunctor // = log1p(x^2 + 2x + y^2) / 2 // + I * atan2(y, x + 1) using realT = typename argT::value_type; - const realT x = std::real(in); - const realT y = std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); // imaginary part of result const realT res_im = sycl::atan2(y, x + 1); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp index 42c837cfa3..4b708b2b93 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace log2 { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -76,7 +78,7 @@ template struct Log2Functor using realT = typename argT::value_type; // log(in) / log(realT{2}); - return exprm_ns::log(exprm_ns::complex(in)) / + return exprm_ns::log(su_ns::sycl_complex_t(in)) / sycl::log(realT{2}); } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp index ca24383b44..30bd058252 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -51,8 +51,10 @@ namespace multiply { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct MultiplyFunctor { @@ -70,8 +72,8 @@ template struct MultiplyFunctor using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; - return exprm_ns::complex(in1) * - exprm_ns::complex(in2); + return su_ns::sycl_complex_t(in1) * + su_ns::sycl_complex_t(in2); } else { return in1 * in2; @@ -419,7 +421,20 @@ template struct MultiplyInplaceFunctor using supports_vec = std::negation< std::disjunction, tu_ns::is_complex>>; - void operator()(resT &res, const argT &in) { res *= in; } + void operator()(resT &res, const argT &in) + { + if constexpr (tu_ns::is_complex_v && tu_ns::is_complex_v) { + using res_rT = typename resT::value_type; + using arg_rT = typename argT::value_type; + + auto res1 = su_ns::sycl_complex_t(res); + res1 *= su_ns::sycl_complex_t(in); + res = res1; + } + else { + res *= in; + } + } template void operator()(sycl::vec &res, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp index d7b0ed909e..f6dfc41899 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -51,8 +51,10 @@ namespace pow { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct PowFunctor { @@ -92,8 +94,8 @@ template struct PowFunctor using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; - return exprm_ns::pow(exprm_ns::complex(in1), - exprm_ns::complex(in2)); + return exprm_ns::pow(su_ns::sycl_complex_t(in1), + su_ns::sycl_complex_t(in2)); } else { return sycl::pow(in1, in2); @@ -392,8 +394,8 @@ template struct PowInplaceFunctor using r_resT = typename resT::value_type; using r_argT = typename argT::value_type; - res = exprm_ns::pow(exprm_ns::complex(res), - exprm_ns::complex(in)); + res = exprm_ns::pow(su_ns::sycl_complex_t(res), + su_ns::sycl_complex_t(in)); } else { res = sycl::pow(res, in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp index df5edface1..0c43865647 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp @@ -32,6 +32,7 @@ #include #include +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +52,9 @@ namespace proj { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -70,8 +73,11 @@ template struct ProjFunctor resT operator()(const argT &in) const { using realT = typename argT::value_type; - const realT x = std::real(in); - const realT y = std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isinf(x)) { return value_at_infinity(y); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index 9ecb822a20..096f32eec5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -31,6 +31,7 @@ #include #include +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +51,9 @@ namespace real { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::is_complex_v; @@ -71,7 +74,9 @@ template struct RealFunctor resT operator()(const argT &in) const { if constexpr (is_complex_v) { - return std::real(in); + using realT = typename argT::value_type; + using sycl_complexT = typename su_ns::sycl_complex_t; + return exprm_ns::real(sycl_complexT(in)); } else { static_assert(std::is_same_v); @@ -174,7 +179,7 @@ template struct RealContigFactory template struct RealTypeMapFactory { - /*! @brief get typeid for output type of std::real(T x) */ + /*! @brief get typeid for output type of real(T x) */ std::enable_if_t::value, int> get() { using rT = typename RealOutputType::value_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp index 0e46acba39..43e0e3c640 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp @@ -32,7 +32,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -52,6 +52,7 @@ namespace reciprocal { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -74,7 +75,7 @@ template struct ReciprocalFunctor using realT = typename argT::value_type; - return realT(1) / exprm_ns::complex(in); + return realT(1) / su_ns::sycl_complex_t(in); } else { return argT(1) / in; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index 7fbb20ae32..4382f0d447 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -29,6 +29,7 @@ #include #include +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -48,7 +49,9 @@ namespace round { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -67,14 +70,15 @@ template struct RoundFunctor resT operator()(const argT &in) const { - if constexpr (std::is_integral_v) { return in; } else if constexpr (is_complex::value) { using realT = typename argT::value_type; - return resT{round_func(std::real(in)), - round_func(std::imag(in))}; + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + return resT{round_func(exprm_ns::real(z)), + round_func(exprm_ns::imag(z))}; } else { return round_func(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp index baa224942f..94943b73ab 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp @@ -37,6 +37,7 @@ #include "kernels/elementwise_functions/common.hpp" #include "utils/offset_utils.hpp" +#include "utils/sycl_complex.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -50,6 +51,7 @@ namespace sign { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -82,7 +84,7 @@ template struct SignFunctor return resT(0); } else { - auto z = exprm_ns::complex(in); + auto z = su_ns::sycl_complex_t(in); return (z / detail::cabs(in)); } } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index e075a90a88..d4bbed564b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace sin { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -72,8 +74,11 @@ template struct SinFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - realT const &in_re = std::real(in); - realT const &in_im = std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + realT const &in_re = exprm_ns::real(z); + realT const &in_im = exprm_ns::imag(z); const bool in_re_finite = std::isfinite(in_re); const bool in_im_finite = std::isfinite(in_im); @@ -82,8 +87,7 @@ template struct SinFunctor * real and imaginary parts of input are finite. */ if (in_re_finite && in_im_finite) { - resT res = - exprm_ns::sin(exprm_ns::complex(in)); // sin(in); + resT res = exprm_ns::sin(z); // sin(z); if (in_re == realT(0)) { res.real(sycl::copysign(realT(0), in_re)); } @@ -91,9 +95,9 @@ template struct SinFunctor } /* - * since sin(in) = -I * sinh(I * in), for special cases, - * we calculate real and imaginary parts of z = sinh(I * in) and - * then return { imag(z) , -real(z) } which is sin(in). + * since sin(z) = -I * sinh(I * z), for special cases, + * we calculate real and imaginary parts of z = sinh(I * z) and + * then return { imag(z) , -real(z) } which is sin(z). */ const realT x = -in_im; const realT y = in_re; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp index 23b3588a3b..6c37266781 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace sinh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -70,9 +72,10 @@ template struct SinhFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); const bool xfinite = std::isfinite(x); const bool yfinite = std::isfinite(y); @@ -82,7 +85,7 @@ template struct SinhFunctor * real and imaginary parts of input are finite. */ if (xfinite && yfinite) { - return exprm_ns::sinh(exprm_ns::complex(in)); + return exprm_ns::sinh(z); } /* * sinh(+-0 +- I Inf) = sign(d(+-0, dNaN))0 + I dNaN. diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index b83ff72495..b1014a5070 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -32,7 +32,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -52,7 +52,9 @@ namespace sqrt { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,7 +75,7 @@ template struct SqrtFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - return exprm_ns::sqrt(exprm_ns::complex(in)); + return exprm_ns::sqrt(su_ns::sycl_complex_t(in)); } else { return sycl::sqrt(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp index f9d9d848c0..b66b53d225 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,6 +50,7 @@ namespace square { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -74,7 +75,7 @@ template struct SquareFunctor if constexpr (is_complex::value) { using realT = typename argT::value_type; - auto z = exprm_ns::complex(in); + auto z = su_ns::sycl_complex_t(in); return z * z; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp index 51a3955142..bc35026481 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp @@ -29,6 +29,7 @@ #include #include +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -49,8 +50,10 @@ namespace subtract { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct SubtractFunctor { @@ -62,7 +65,17 @@ template struct SubtractFunctor resT operator()(const argT1 &in1, const argT2 &in2) const { - return in1 - in2; + if constexpr (tu_ns::is_complex_v && tu_ns::is_complex_v) + { + using realT1 = typename argT1::value_type; + using realT2 = typename argT2::value_type; + + return su_ns::sycl_complex_t(in1) - + su_ns::sycl_complex_t(in2); + } + else { + return in1 - in2; + } } template @@ -424,7 +437,17 @@ template struct SubtractInplaceFunctor void operator()(sycl::vec &res, const sycl::vec &in) { - res -= in; + if constexpr (tu_ns::is_complex_v && tu_ns::is_complex_v) { + using res_rT = typename resT::value_type; + using arg_rT = typename argT::value_type; + + auto res1 = su_ns::sycl_complex_t(res); + res1 -= su_ns::sycl_complex_t(in); + res = res1; + } + else { + res -= in; + } } }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp index 770518f918..8e0404fe02 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace tan { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -75,12 +77,14 @@ template struct TanFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); /* - * since tan(in) = -I * tanh(I * in), for special cases, - * we calculate real and imaginary parts of z = tanh(I * in) and - * return { imag(z) , -real(z) } which is tan(in). + * since tan(z) = -I * tanh(I * z), for special cases, + * we calculate real and imaginary parts of z = tanh(I * z) and + * return { imag(z) , -real(z) } which is tan(z). */ - const realT x = -std::imag(in); - const realT y = std::real(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = -exprm_ns::imag(z); + const realT y = exprm_ns::real(z); /* * tanh(NaN + i 0) = NaN + i 0 * @@ -121,7 +125,7 @@ template struct TanFunctor return resT{q_nan, q_nan}; } /* ordinary cases */ - return exprm_ns::tan(exprm_ns::complex(in)); // tan(in); + return exprm_ns::tan(z); // tan(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp index 1d06fd3c4f..9ea078f6a4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace tanh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -75,8 +77,10 @@ template struct TanhFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); /* * tanh(NaN + i 0) = NaN + i 0 * @@ -115,7 +119,7 @@ template struct TanhFunctor return resT{q_nan, q_nan}; } /* ordinary cases */ - return exprm_ns::tanh(exprm_ns::complex(in)); // tanh(in); + return exprm_ns::tanh(z); // tanh(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index de6c9a8723..a187c75230 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -50,8 +50,10 @@ namespace true_divide { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct TrueDivideFunctor @@ -70,22 +72,22 @@ struct TrueDivideFunctor using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; - return exprm_ns::complex(in1) / - exprm_ns::complex(in2); + return su_ns::sycl_complex_t(in1) / + su_ns::sycl_complex_t(in2); } else if constexpr (tu_ns::is_complex::value && !tu_ns::is_complex::value) { using realT1 = typename argT1::value_type; - return exprm_ns::complex(in1) / in2; + return su_ns::sycl_complex_t(in1) / in2; } else if constexpr (!tu_ns::is_complex::value && tu_ns::is_complex::value) { using realT2 = typename argT2::value_type; - return in1 / exprm_ns::complex(in2); + return in1 / su_ns::sycl_complex_t(in2); } else { return in1 / in2; @@ -435,14 +437,14 @@ template struct TrueDivideInplaceFunctor using res_rT = typename resT::value_type; using arg_rT = typename argT::value_type; - auto res1 = exprm_ns::complex(res); - res1 /= exprm_ns::complex(in); + auto res1 = su_ns::sycl_complex_t(res); + res1 /= su_ns::sycl_complex_t(in); res = res1; } else { using res_rT = typename resT::value_type; - auto res1 = exprm_ns::complex(res); + auto res1 = su_ns::sycl_complex_t(res); res1 /= in; res = res1; } diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp index 71e2c15b6b..c19fe6812b 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -37,6 +37,7 @@ #include "kernels/reductions.hpp" #include "utils/offset_utils.hpp" #include "utils/sycl_alloc_utils.hpp" +#include "utils/sycl_complex.hpp" #include "utils/sycl_utils.hpp" #include "utils/type_utils.hpp" @@ -49,6 +50,18 @@ namespace kernels using dpctl::tensor::ssize_t; namespace su_ns = dpctl::tensor::sycl_utils; +namespace tu_ns = dpctl::tensor::type_utils; + +namespace detail +{ + +template +using SumTempsOpT = std::conditional_t< + std::is_same_v, + sycl::logical_or, + std::conditional_t, su_ns::Plus, sycl::plus>>; + +} // namespace detail template ( - lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( - rhs_[rhs_batch_offset + rhs_reduction_offset]); + if constexpr (tu_ns::is_complex_v) { + using realT = typename outT::value_type; + using sycl_complex = su_ns::sycl_complex_t; + + auto tmp = sycl_complex(red_val); + tmp += sycl_complex(tu_ns::convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset])) * + sycl_complex(tu_ns::convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset])); + red_val = outT(tmp); + } + else { + red_val += tu_ns::convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + tu_ns::convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + } } out_[out_batch_offset] = red_val; @@ -175,10 +200,9 @@ struct DotProductFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl( + outT val = tu_ns::convert_impl( lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( + tu_ns::convert_impl( rhs_[rhs_batch_offset + rhs_reduction_offset]); local_red_val += val; @@ -273,10 +297,9 @@ struct DotProductCustomFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl( + outT val = tu_ns::convert_impl( lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( + tu_ns::convert_impl( rhs_[rhs_batch_offset + rhs_reduction_offset]); local_red_val += val; @@ -718,20 +741,32 @@ struct DotProductNoAtomicFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl( - lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( - rhs_[rhs_batch_offset + rhs_reduction_offset]); - - local_red_val += val; + if constexpr (tu_ns::is_complex_v) { + using realT = typename outT::value_type; + using sycl_complexT = su_ns::sycl_complex_t; + + sycl_complexT val = + sycl_complexT(tu_ns::convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset])) * + sycl_complexT(tu_ns::convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset])); + local_red_val = outT(sycl_complexT(local_red_val) + val); + } + else { + outT val = tu_ns::convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + tu_ns::convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + local_red_val += val; + } } auto work_group = it.get_group(); - using RedOpT = typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using RedOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; outT red_val_over_wg = sycl::reduce_over_group( work_group, local_red_val, outT(0), RedOpT()); @@ -819,13 +854,24 @@ struct DotProductNoAtomicCustomFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl( - lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( - rhs_[rhs_batch_offset + rhs_reduction_offset]); - - local_red_val += val; + if constexpr (tu_ns::is_complex_v) { + using realT = typename outT::value_type; + using sycl_complexT = su_ns::sycl_complex_t; + + sycl_complexT val = + sycl_complexT(tu_ns::convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset])) * + sycl_complexT(tu_ns::convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset])); + local_red_val = outT(sycl_complexT(local_red_val) + val); + } + else { + outT val = tu_ns::convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + tu_ns::convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + local_red_val += val; + } } auto work_group = it.get_group(); @@ -972,9 +1018,7 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q, // prevents running out of resources on CPU std::size_t max_wg = reduction_detail::get_work_group_size(d); - using ReductionOpT = typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = detail::SumTempsOpT; std::size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -1014,7 +1058,7 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q, } else { constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; // more than one work-groups is needed, requires a temporary std::size_t reduction_groups = @@ -1215,9 +1259,7 @@ dot_product_contig_tree_impl(sycl::queue &exec_q, // prevents running out of resources on CPU std::size_t max_wg = reduction_detail::get_work_group_size(d); - using ReductionOpT = typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = detail::SumTempsOpT; std::size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -1261,7 +1303,7 @@ dot_product_contig_tree_impl(sycl::queue &exec_q, } else { constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; // more than one work-groups is needed, requires a temporary std::size_t reduction_groups = diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 4ad4eb142a..49b8868b20 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -37,6 +37,7 @@ #include "kernels/reductions.hpp" #include "utils/offset_utils.hpp" #include "utils/sycl_alloc_utils.hpp" +#include "utils/sycl_complex.hpp" #include "utils/sycl_utils.hpp" #include "utils/type_utils.hpp" @@ -48,6 +49,8 @@ namespace kernels { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; +namespace tu_ns = dpctl::tensor::type_utils; namespace gemm_detail { @@ -96,7 +99,7 @@ void scale_gemm_nm_parameters(const std::size_t &local_mem_size, } } // namespace gemm_detail -using dpctl::tensor::sycl_utils::choose_workgroup_size; +using su_ns::choose_workgroup_size; template class gemm_seq_reduction_krn; @@ -1082,8 +1085,21 @@ class GemmBatchFunctorThreadNM_vecm #pragma unroll for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) { - private_C[pr_i * wi_delta_m_vecs + pr_j] += - pr_lhs[pr_i] * pr_rhs[pr_j]; + if constexpr (tu_ns::is_complex_v) { + using realT = typename resT::value_type; + using sycl_complex = su_ns::sycl_complex_t; + + auto tmp = sycl_complex( + private_C[pr_i * wi_delta_m_vecs + pr_j]); + tmp += sycl_complex(pr_lhs[pr_i]) * + sycl_complex(pr_rhs[pr_j]); + private_C[pr_i * wi_delta_m_vecs + pr_j] = + resT(tmp); + } + else { + private_C[pr_i * wi_delta_m_vecs + pr_j] += + pr_lhs[pr_i] * pr_rhs[pr_j]; + } } } } @@ -1776,6 +1792,17 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, // ========== Gemm Tree +namespace gemm_detail +{ + +template +using SumTempsOpT = std::conditional_t< + std::is_same_v, + sycl::logical_or, + std::conditional_t, su_ns::Plus, sycl::plus>>; + +} // namespace gemm_detail + template ) { + using realT = typename resT::value_type; + using sycl_complex = su_ns::sycl_complex_t; + auto tmp = sycl_complex(local_sum); + tmp += (sycl_complex(local_A_block[a_offset + a_pr_offset + + private_s]) * + sycl_complex(local_B_block[b_offset + private_s])); + local_sum = resT(tmp); + } + else { + local_sum = + local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } } const std::size_t gl_i = i + private_i; @@ -2114,12 +2153,28 @@ class GemmBatchNoAtomicFunctorThreadK accV_t private_sum(identity_); constexpr accV_t vec_identity_(identity_); for (std::size_t t = local_s; t < local_B_block.size(); t += delta_k) { - private_sum += - ((i < n) && (t + t_shift < k)) - ? (static_cast( - lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * - local_B_block[t]) - : vec_identity_; + if constexpr (tu_ns::is_complex_v) { + using realT = typename resT::value_type; + using sycl_complex = su_ns::sycl_complex_t; + + auto tmp = sycl_complex(private_sum); + tmp += ((i < n) && (t + t_shift < k)) + ? sycl_complex(static_cast( + lhs[lhs_offset + + lhs_indexer(global_s_offset + t)])) * + sycl_complex(local_B_block[t]) + : sycl_complex(vec_identity_); + private_sum = resT(tmp); + } + else { + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } } std::size_t workspace_i_shift = local_i * delta_k; @@ -2130,7 +2185,17 @@ class GemmBatchNoAtomicFunctorThreadK if (local_s == 0 && i < n) { accV_t local_sum(workspace[workspace_i_shift]); for (std::size_t t = 1; t < delta_k; ++t) { - local_sum += workspace[workspace_i_shift + t]; + if constexpr (tu_ns::is_complex_v) { + using realT = typename resT::value_type; + using sycl_complex = su_ns::sycl_complex_t; + + auto tmp = sycl_complex(local_sum); + tmp += sycl_complex(workspace[workspace_i_shift + t]); + local_sum = resT(tmp); + } + else { + local_sum += workspace[workspace_i_shift + t]; + } } const std::size_t total_offset = @@ -2311,12 +2376,9 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = @@ -2607,12 +2669,9 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; @@ -2863,8 +2922,7 @@ sycl::event gemm_batch_tree_impl(sycl::queue &exec_q, } if (max_nm < 64) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { if (m < 4) { constexpr std::uint32_t m_groups_one = 1; return gemm_batch_tree_k_impl 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { constexpr std::uint32_t m_groups_four = 4; return gemm_batch_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, @@ -2980,12 +3037,9 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = @@ -3168,12 +3222,9 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; @@ -3435,8 +3486,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, } if (max_nm < 64) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { if (m < 4) { return gemm_batch_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, @@ -3454,8 +3504,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, } } else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { return gemm_batch_contig_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); } @@ -3539,12 +3588,9 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = @@ -3693,12 +3739,9 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; @@ -3840,8 +3883,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, } if (max_nm < 64) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { if (m < 4) { return gemm_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, @@ -3866,8 +3908,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, } } else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { return gemm_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, @@ -3929,12 +3970,9 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = @@ -4068,12 +4106,9 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; @@ -4191,8 +4226,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, } if (max_nm < 64) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { if (m < 4) { return gemm_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); @@ -4208,8 +4242,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, } } else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { return gemm_contig_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index f056d246c9..042997f56b 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -1914,8 +1914,7 @@ struct SequentialSearchReduction using dpctl::tensor::math_utils::less_complex; // less_complex always returns false for NaNs, so check if (less_complex(val, red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(val.real()) || std::isnan(val.imag())) { red_val = val; idx_val = static_cast(m); @@ -1941,8 +1940,7 @@ struct SequentialSearchReduction if constexpr (is_complex::value) { using dpctl::tensor::math_utils::greater_complex; if (greater_complex(val, red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(val.real()) || std::isnan(val.imag())) { red_val = val; idx_val = static_cast(m); @@ -2230,8 +2228,8 @@ struct CustomSearchReduction // less_complex always returns false for NaNs, so // check if (less_complex(val, local_red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(val.real()) || + std::isnan(val.imag())) { local_red_val = val; if constexpr (!First) { @@ -2277,8 +2275,8 @@ struct CustomSearchReduction if constexpr (is_complex::value) { using dpctl::tensor::math_utils::greater_complex; if (greater_complex(val, local_red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(val.real()) || + std::isnan(val.imag())) { local_red_val = val; if constexpr (!First) { diff --git a/dpctl/tensor/libtensor/include/utils/math_utils.hpp b/dpctl/tensor/libtensor/include/utils/math_utils.hpp index a49b56b6ba..9097133773 100644 --- a/dpctl/tensor/libtensor/include/utils/math_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/math_utils.hpp @@ -24,9 +24,17 @@ #pragma once #include -#include #include +#ifndef SYCL_EXT_ONEAPI_COMPLEX +#define SYCL_EXT_ONEAPI_COMPLEX 1 +#endif +#if __has_include() +#include +#else +#include +#endif + namespace dpctl { namespace tensor @@ -34,13 +42,18 @@ namespace tensor namespace math_utils { +namespace exprm_ns = sycl::ext::oneapi::experimental; + template bool less_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 < imag2) @@ -50,10 +63,13 @@ template bool less_complex(const T &x1, const T &x2) template bool greater_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 > imag2) @@ -63,10 +79,13 @@ template bool greater_complex(const T &x1, const T &x2) template bool less_equal_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 <= imag2) @@ -76,10 +95,13 @@ template bool less_equal_complex(const T &x1, const T &x2) template bool greater_equal_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 >= imag2) @@ -89,10 +111,13 @@ template bool greater_equal_complex(const T &x1, const T &x2) template T max_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); bool isnan_imag1 = std::isnan(imag1); bool gt = (real1 == real2) @@ -104,10 +129,13 @@ template T max_complex(const T &x1, const T &x2) template T min_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); bool isnan_imag1 = std::isnan(imag1); bool lt = (real1 == real2) @@ -133,6 +161,20 @@ template T logaddexp(T x, T y) } } +template T plus_complex(const T &x1, const T &x2) +{ + using realT = typename T::value_type; + using sycl_complexT = exprm_ns::complex; + return T(sycl_complexT(x1) + sycl_complexT(x2)); +} + +template T multiplies_complex(const T &x1, const T &x2) +{ + using realT = typename T::value_type; + using sycl_complexT = exprm_ns::complex; + return T(sycl_complexT(x1) * sycl_complexT(x2)); +} + } // namespace math_utils } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sycl_complex.hpp b/dpctl/tensor/libtensor/include/utils/sycl_complex.hpp similarity index 81% rename from dpctl/tensor/libtensor/include/kernels/elementwise_functions/sycl_complex.hpp rename to dpctl/tensor/libtensor/include/utils/sycl_complex.hpp index 3b5a1b9e7b..535bf17241 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sycl_complex.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_complex.hpp @@ -26,11 +26,25 @@ #pragma once -#define SYCL_EXT_ONEAPI_COMPLEX +#ifndef SYCL_EXT_ONEAPI_COMPLEX +#define SYCL_EXT_ONEAPI_COMPLEX 1 +#endif #if __has_include() #include #else #include #endif -namespace exprm_ns = sycl::ext::oneapi::experimental; +namespace dpctl +{ +namespace tensor +{ +namespace sycl_utils +{ + +template +using sycl_complex_t = sycl::ext::oneapi::experimental::complex; + +} // namespace sycl_utils +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index ece8852643..7d1d9a77d9 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -31,6 +31,7 @@ #include #include "math_utils.hpp" +#include "sycl_complex.hpp" namespace dpctl { @@ -298,7 +299,11 @@ T custom_inclusive_scan_over_group(GroupT &&wg, return scan_val; } -// Reduction functors +// Define identities and operator checking structs + +template struct GetIdentity +{ +}; // Maximum @@ -324,38 +329,6 @@ template struct Maximum } }; -// Minimum - -template struct Minimum -{ - T operator()(const T &x, const T &y) const - { - if constexpr (detail::IsComplex::value) { - using dpctl::tensor::math_utils::min_complex; - return min_complex(x, y); - } - else if constexpr (std::is_floating_point_v || - std::is_same_v) - { - return (std::isnan(x) || x < y) ? x : y; - } - else if constexpr (std::is_same_v) { - return x && y; - } - else { - return (x < y) ? x : y; - } - } -}; - -// Define identities and operator checking structs - -template struct GetIdentity -{ -}; - -// Maximum - template using IsMaximum = std::bool_constant> || std::is_same_v>>; @@ -389,6 +362,28 @@ struct GetIdentity struct Minimum +{ + T operator()(const T &x, const T &y) const + { + if constexpr (detail::IsComplex::value) { + using dpctl::tensor::math_utils::min_complex; + return min_complex(x, y); + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + return (std::isnan(x) || x < y) ? x : y; + } + else if constexpr (std::is_same_v) { + return x && y; + } + else { + return (x < y) ? x : y; + } + } +}; + template using IsMinimum = std::bool_constant> || std::is_same_v>>; @@ -422,19 +417,55 @@ struct GetIdentity struct Plus +{ + T operator()(const T &x, const T &y) const + { + if constexpr (detail::IsComplex::value) { + using dpctl::tensor::math_utils::plus_complex; + return plus_complex(x, y); + } + else { + return sycl::plus(x, y); + } + } +}; + template using IsPlus = std::bool_constant> || - std::is_same_v>>; + std::is_same_v> || + std::is_same_v>>; template using IsSyclPlus = std::bool_constant>>; +template +struct GetIdentity::value>> +{ + static constexpr T value = static_cast(0); +}; + // Multiplies +template struct Multiplies +{ + T operator()(const T &x, const T &y) const + { + if constexpr (detail::IsComplex::value) { + using dpctl::tensor::math_utils::multiplies_complex; + return multiplies_complex(x, y); + } + else { + return sycl::multiplies(x, y); + } + } +}; + template using IsMultiplies = std::bool_constant> || - std::is_same_v>>; + std::is_same_v> || + std::is_same_v>>; template using IsSyclMultiplies = diff --git a/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp b/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp index 045b1b330e..992e3592b6 100644 --- a/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp +++ b/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp @@ -46,6 +46,7 @@ namespace py_internal namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; namespace impl { @@ -133,9 +134,12 @@ struct TypePairSupportDataForProdAccumulation }; template -using CumProdScanOpT = std::conditional_t, - sycl::logical_and, - sycl::multiplies>; +using CumProdScanOpT = + std::conditional_t, + sycl::logical_and, + std::conditional_t, + su_ns::Multiplies, + sycl::multiplies>>; template struct CumProd1DContigFactory diff --git a/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp b/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp index e44678e15f..22e32dfb03 100644 --- a/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp +++ b/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp @@ -34,6 +34,7 @@ #include "kernels/accumulators.hpp" #include "utils/sycl_utils.hpp" #include "utils/type_dispatch_building.hpp" +#include "utils/type_utils.hpp" namespace py = pybind11; @@ -46,6 +47,7 @@ namespace py_internal namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; namespace impl { @@ -133,8 +135,10 @@ struct TypePairSupportDataForSumAccumulation }; template -using CumSumScanOpT = std:: - conditional_t, sycl::logical_or, sycl::plus>; +using CumSumScanOpT = std::conditional_t< + std::is_same_v, + sycl::logical_or, + std::conditional_t, su_ns::Plus, sycl::plus>>; template struct CumSum1DContigFactory diff --git a/dpctl/tensor/libtensor/source/reductions/prod.cpp b/dpctl/tensor/libtensor/source/reductions/prod.cpp index 7c768ce179..40a3bf4dc5 100644 --- a/dpctl/tensor/libtensor/source/reductions/prod.cpp +++ b/dpctl/tensor/libtensor/source/reductions/prod.cpp @@ -31,7 +31,9 @@ #include #include "kernels/reductions.hpp" +#include "utils/sycl_utils.hpp" #include "utils/type_dispatch_building.hpp" +#include "utils/type_utils.hpp" #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" @@ -45,7 +47,9 @@ namespace tensor namespace py_internal { +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; namespace impl { @@ -229,6 +233,14 @@ struct TypePairSupportDataForProductReductionTemps td_ns::NotDefinedEntry>::is_defined; }; +template +using ProdTempsOpT = + std::conditional_t, + sycl::logical_and, + std::conditional_t, + su_ns::Multiplies, + sycl::multiplies>>; + template struct ProductOverAxisAtomicStridedFactory { @@ -256,9 +268,7 @@ struct ProductOverAxisTempsStridedFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t, - sycl::logical_and, - sycl::multiplies>; + using ReductionOpT = ProdTempsOpT; return dpctl::tensor::kernels:: reduction_over_group_temps_strided_impl; @@ -315,9 +325,7 @@ struct ProductOverAxis1TempsContigFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t, - sycl::logical_and, - sycl::multiplies>; + using ReductionOpT = ProdTempsOpT; return dpctl::tensor::kernels:: reduction_axis1_over_group_temps_contig_impl; @@ -336,9 +344,7 @@ struct ProductOverAxis0TempsContigFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t, - sycl::logical_and, - sycl::multiplies>; + using ReductionOpT = ProdTempsOpT; return dpctl::tensor::kernels:: reduction_axis0_over_group_temps_contig_impl; diff --git a/dpctl/tensor/libtensor/source/reductions/sum.cpp b/dpctl/tensor/libtensor/source/reductions/sum.cpp index f449a6cde3..e9476c3dfb 100644 --- a/dpctl/tensor/libtensor/source/reductions/sum.cpp +++ b/dpctl/tensor/libtensor/source/reductions/sum.cpp @@ -31,7 +31,9 @@ #include #include "kernels/reductions.hpp" +#include "utils/sycl_utils.hpp" #include "utils/type_dispatch_building.hpp" +#include "utils/type_utils.hpp" #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" @@ -45,7 +47,9 @@ namespace tensor namespace py_internal { +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; namespace impl { @@ -229,6 +233,12 @@ struct TypePairSupportDataForSumReductionTemps td_ns::NotDefinedEntry>::is_defined; }; +template +using SumTempsOpT = std::conditional_t< + std::is_same_v, + sycl::logical_or, + std::conditional_t, su_ns::Plus, sycl::plus>>; + template struct SumOverAxisAtomicStridedFactory { @@ -256,9 +266,7 @@ struct SumOverAxisTempsStridedFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = - std::conditional_t, - sycl::logical_or, sycl::plus>; + using ReductionOpT = SumTempsOpT; return dpctl::tensor::kernels:: reduction_over_group_temps_strided_impl; @@ -315,9 +323,7 @@ struct SumOverAxis1TempsContigFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = - std::conditional_t, - sycl::logical_or, sycl::plus>; + using ReductionOpT = SumTempsOpT; return dpctl::tensor::kernels:: reduction_axis1_over_group_temps_contig_impl; @@ -336,9 +342,7 @@ struct SumOverAxis0TempsContigFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = - std::conditional_t, - sycl::logical_or, sycl::plus>; + using ReductionOpT = SumTempsOpT; return dpctl::tensor::kernels:: reduction_axis0_over_group_temps_contig_impl; diff --git a/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp b/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp index 2aaa1cfafa..9f3cf9ffd8 100644 --- a/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp +++ b/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp @@ -27,6 +27,8 @@ #include "sycl/sycl.hpp" #include +#include "utils/sycl_complex.hpp" + namespace dpctl { namespace tensor @@ -36,6 +38,9 @@ namespace py_internal namespace { + +namespace su_ns = dpctl::tensor::sycl_utils; + template struct ExtendedRealFPLess { /* [R, nan] */ @@ -53,6 +58,8 @@ template struct ExtendedRealFPGreater } }; +namespace exprm_ns = sycl::ext::oneapi::experimental; + template struct ExtendedComplexFPLess { /* [(R, R), (R, nan), (nan, R), (nan, nan)] */ @@ -60,15 +67,17 @@ template struct ExtendedComplexFPLess bool operator()(const cT &v1, const cT &v2) const { using realT = typename cT::value_type; - - const realT real1 = std::real(v1); - const realT real2 = std::real(v2); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT z1 = sycl_complexT(v1); + sycl_complexT z2 = sycl_complexT(v2); + const realT real1 = exprm_ns::real(z1); + const realT real2 = exprm_ns::real(z2); const bool r1_nan = std::isnan(real1); const bool r2_nan = std::isnan(real2); - const realT imag1 = std::imag(v1); - const realT imag2 = std::imag(v2); + const realT imag1 = exprm_ns::imag(z1); + const realT imag2 = exprm_ns::imag(z2); const bool i1_nan = std::isnan(imag1); const bool i2_nan = std::isnan(imag2);