From e109ce36dd93649f0b4299e788497d034366010b Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 1 Jun 2023 11:22:20 -0700 Subject: [PATCH 1/3] Adds sycl::vec overloads to abs, cos, expm1, log, log1p, and sqrt --- .../kernels/elementwise_functions/abs.hpp | 38 ++++++++++++++++++- .../kernels/elementwise_functions/cos.hpp | 18 ++++++++- .../kernels/elementwise_functions/expm1.hpp | 18 ++++++++- .../kernels/elementwise_functions/log.hpp | 18 ++++++++- .../kernels/elementwise_functions/log1p.hpp | 18 ++++++++- .../kernels/elementwise_functions/sqrt.hpp | 18 ++++++++- 6 files changed, 122 insertions(+), 6 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index 911452931e..791fa99f63 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -52,13 +52,15 @@ namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::vec_cast; template struct AbsFunctor { using is_constant = typename std::false_type; // constexpr resT constant_value = resT{}; - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -87,6 +89,40 @@ template struct AbsFunctor } } + template + sycl::vec operator()(const sycl::vec &in) + { + if constexpr (std::is_integral::value) { + if constexpr (std::is_same_v || + std::is_unsigned::value) { + return in; + } + else { + auto const &res_vec = sycl::abs(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + + return vec_cast(res_vec); + } + } + } + else { + auto const &res_vec = sycl::fabs(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } + } + private: template realT cabs(std::complex const &z) const { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index c8cd8ef18c..8c05a8a4fd 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -50,6 +50,7 @@ namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::vec_cast; template struct CosFunctor { @@ -59,7 +60,8 @@ template struct CosFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -165,6 +167,20 @@ template struct CosFunctor return std::cos(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::cos(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template struct Expm1Functor { @@ -60,7 +61,8 @@ template struct Expm1Functor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -132,6 +134,20 @@ template struct Expm1Functor return std::expm1(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::expm1(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template struct LogFunctor { @@ -60,7 +61,8 @@ template struct LogFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -79,6 +81,20 @@ template struct LogFunctor return std::log(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::log(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template struct Log1pFunctor @@ -60,7 +61,8 @@ template struct Log1pFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -99,6 +101,20 @@ template struct Log1pFunctor return std::log1p(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::log1p(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template struct SqrtFunctor { @@ -62,7 +63,8 @@ template struct SqrtFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -95,6 +97,20 @@ template struct SqrtFunctor } } + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::sqrt(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } + private: template std::complex csqrt(std::complex const &z) const { From 00a33b03a938b0ab93ec317f1ef35277ca328234 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 22 Jun 2023 13:59:07 -0700 Subject: [PATCH 2/3] sycl::vec overload for sine --- .../kernels/elementwise_functions/sin.hpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index 97768fc8e9..1ef724eebc 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -58,7 +58,8 @@ template struct SinFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -181,6 +182,20 @@ template struct SinFunctor return std::sin(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::sin(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template Date: Mon, 10 Jul 2023 10:33:54 -0500 Subject: [PATCH 3/3] Import definition of vec_cast in sin.hpp --- .../libtensor/include/kernels/elementwise_functions/sin.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index 1ef724eebc..32772d29f9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -50,6 +50,7 @@ namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::vec_cast; template struct SinFunctor {