Skip to content

Commit 323ed6d

Browse files
committed
Adds sycl::vec overloads to abs, cos, expm1, log, log1p, and sqrt
1 parent d109770 commit 323ed6d

File tree

6 files changed

+122
-6
lines changed

6 files changed

+122
-6
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

+37-1
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
5252

5353
using dpctl::tensor::type_utils::is_complex;
54+
using dpctl::tensor::type_utils::vec_cast;
5455

5556
template <typename argT, typename resT> struct AbsFunctor
5657
{
5758

5859
using is_constant = typename std::false_type;
5960
// constexpr resT constant_value = resT{};
60-
using supports_vec = typename std::false_type;
61+
using supports_vec = typename std::negation<
62+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6163
using supports_sg_loadstore = typename std::negation<
6264
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6365

@@ -75,6 +77,40 @@ template <typename argT, typename resT> struct AbsFunctor
7577
return std::abs(x);
7678
}
7779
}
80+
81+
template <int vec_sz>
82+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
83+
{
84+
if constexpr (std::is_integral<argT>::value) {
85+
if constexpr (std::is_same_v<argT, bool> ||
86+
std::is_unsigned<argT>::value) {
87+
return in;
88+
}
89+
else {
90+
auto const &res_vec = sycl::abs(in);
91+
using deducedT = typename std::remove_cv_t<
92+
std::remove_reference_t<decltype(res_vec)>>::element_type;
93+
if constexpr (std::is_same_v<resT, deducedT>) {
94+
return res_vec;
95+
}
96+
else {
97+
98+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
99+
}
100+
}
101+
}
102+
else {
103+
auto const &res_vec = sycl::fabs(in);
104+
using deducedT = typename std::remove_cv_t<
105+
std::remove_reference_t<decltype(res_vec)>>::element_type;
106+
if constexpr (std::is_same_v<resT, deducedT>) {
107+
return res_vec;
108+
}
109+
else {
110+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
111+
}
112+
}
113+
}
78114
};
79115

80116
template <typename argT,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp

+17-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ namespace py = pybind11;
4949
namespace td_ns = dpctl::tensor::type_dispatch;
5050

5151
using dpctl::tensor::type_utils::is_complex;
52+
using dpctl::tensor::type_utils::vec_cast;
5253

5354
template <typename argT, typename resT> struct CosFunctor
5455
{
@@ -58,7 +59,8 @@ template <typename argT, typename resT> struct CosFunctor
5859
// constant value, if constant
5960
// constexpr resT constant_value = resT{};
6061
// is function defined for sycl::vec
61-
using supports_vec = typename std::false_type;
62+
using supports_vec = typename std::negation<
63+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6264
// do both argTy and resTy support sugroup store/load operation
6365
using supports_sg_loadstore = typename std::negation<
6466
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -67,6 +69,20 @@ template <typename argT, typename resT> struct CosFunctor
6769
{
6870
return std::cos(in);
6971
}
72+
73+
template <int vec_sz>
74+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
75+
{
76+
auto const &res_vec = sycl::cos(in);
77+
using deducedT = typename std::remove_cv_t<
78+
std::remove_reference_t<decltype(res_vec)>>::element_type;
79+
if constexpr (std::is_same_v<resT, deducedT>) {
80+
return res_vec;
81+
}
82+
else {
83+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
84+
}
85+
}
7086
};
7187

7288
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp

+17-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
5151

5252
using dpctl::tensor::type_utils::is_complex;
53+
using dpctl::tensor::type_utils::vec_cast;
5354

5455
template <typename argT, typename resT> struct Expm1Functor
5556
{
@@ -59,7 +60,8 @@ template <typename argT, typename resT> struct Expm1Functor
5960
// constant value, if constant
6061
// constexpr resT constant_value = resT{};
6162
// is function defined for sycl::vec
62-
using supports_vec = typename std::false_type;
63+
using supports_vec = typename std::negation<
64+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6365
// do both argTy and resTy support sugroup store/load operation
6466
using supports_sg_loadstore = typename std::negation<
6567
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -88,6 +90,20 @@ template <typename argT, typename resT> struct Expm1Functor
8890
return std::expm1(in);
8991
}
9092
}
93+
94+
template <int vec_sz>
95+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
96+
{
97+
auto const &res_vec = sycl::expm1(in);
98+
using deducedT = typename std::remove_cv_t<
99+
std::remove_reference_t<decltype(res_vec)>>::element_type;
100+
if constexpr (std::is_same_v<resT, deducedT>) {
101+
return res_vec;
102+
}
103+
else {
104+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
105+
}
106+
}
91107
};
92108

93109
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp

+17-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ namespace py = pybind11;
4949
namespace td_ns = dpctl::tensor::type_dispatch;
5050

5151
using dpctl::tensor::type_utils::is_complex;
52+
using dpctl::tensor::type_utils::vec_cast;
5253

5354
template <typename argT, typename resT> struct LogFunctor
5455
{
@@ -58,7 +59,8 @@ template <typename argT, typename resT> struct LogFunctor
5859
// constant value, if constant
5960
// constexpr resT constant_value = resT{};
6061
// is function defined for sycl::vec
61-
using supports_vec = typename std::false_type;
62+
using supports_vec = typename std::negation<
63+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6264
// do both argTy and resTy support sugroup store/load operation
6365
using supports_sg_loadstore = typename std::negation<
6466
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -67,6 +69,20 @@ template <typename argT, typename resT> struct LogFunctor
6769
{
6870
return std::log(in);
6971
}
72+
73+
template <int vec_sz>
74+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
75+
{
76+
auto const &res_vec = sycl::log(in);
77+
using deducedT = typename std::remove_cv_t<
78+
std::remove_reference_t<decltype(res_vec)>>::element_type;
79+
if constexpr (std::is_same_v<resT, deducedT>) {
80+
return res_vec;
81+
}
82+
else {
83+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
84+
}
85+
}
7086
};
7187

7288
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp

+17-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
5151

5252
using dpctl::tensor::type_utils::is_complex;
53+
using dpctl::tensor::type_utils::vec_cast;
5354

5455
// TODO: evaluate precision against alternatives
5556
template <typename argT, typename resT> struct Log1pFunctor
@@ -60,7 +61,8 @@ template <typename argT, typename resT> struct Log1pFunctor
6061
// constant value, if constant
6162
// constexpr resT constant_value = resT{};
6263
// is function defined for sycl::vec
63-
using supports_vec = typename std::false_type;
64+
using supports_vec = typename std::negation<
65+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6466
// do both argTy and resTy support sugroup store/load operation
6567
using supports_sg_loadstore = typename std::negation<
6668
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -99,6 +101,20 @@ template <typename argT, typename resT> struct Log1pFunctor
99101
return std::log1p(in);
100102
}
101103
}
104+
105+
template <int vec_sz>
106+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
107+
{
108+
auto const &res_vec = sycl::log1p(in);
109+
using deducedT = typename std::remove_cv_t<
110+
std::remove_reference_t<decltype(res_vec)>>::element_type;
111+
if constexpr (std::is_same_v<resT, deducedT>) {
112+
return res_vec;
113+
}
114+
else {
115+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
116+
}
117+
}
102118
};
103119

104120
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp

+17-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
5151

5252
using dpctl::tensor::type_utils::is_complex;
53+
using dpctl::tensor::type_utils::vec_cast;
5354

5455
template <typename argT, typename resT> struct SqrtFunctor
5556
{
@@ -59,7 +60,8 @@ template <typename argT, typename resT> struct SqrtFunctor
5960
// constant value, if constant
6061
// constexpr resT constant_value = resT{};
6162
// is function defined for sycl::vec
62-
using supports_vec = typename std::false_type;
63+
using supports_vec = typename std::negation<
64+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6365
// do both argTy and resTy support sugroup store/load operation
6466
using supports_sg_loadstore = typename std::negation<
6567
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -68,6 +70,20 @@ template <typename argT, typename resT> struct SqrtFunctor
6870
{
6971
return std::sqrt(in);
7072
}
73+
74+
template <int vec_sz>
75+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
76+
{
77+
auto const &res_vec = sycl::sqrt(in);
78+
using deducedT = typename std::remove_cv_t<
79+
std::remove_reference_t<decltype(res_vec)>>::element_type;
80+
if constexpr (std::is_same_v<resT, deducedT>) {
81+
return res_vec;
82+
}
83+
else {
84+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
85+
}
86+
}
7187
};
7288

7389
template <typename argTy,

0 commit comments

Comments
 (0)