Skip to content

Commit c363c02

Browse files
committed
sycl::vec overload for sine
1 parent 323ed6d commit c363c02

File tree

1 file changed

+16
-1
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+16
-1
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ template <typename argT, typename resT> struct SinFunctor
5757
// constant value, if constant
5858
// constexpr resT constant_value = resT{};
5959
// is function defined for sycl::vec
60-
using supports_vec = typename std::false_type;
60+
using supports_vec = typename std::negation<
61+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6162
// do both argT and resT support subgroup store/load operations
6263
using supports_sg_loadstore = typename std::negation<
6364
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -66,6 +67,20 @@ template <typename argT, typename resT> struct SinFunctor
6667
{
6768
return std::sin(in);
6869
}
70+
71+
// Vector overload: evaluates sycl::sin on an entire sycl::vec<argT, vec_sz>
// and returns the result as sycl::vec<resT, vec_sz>.
//   vec_sz - number of elements in the input and output vectors
//   in     - input vector of argT elements
template <int vec_sz>
72+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
73+
{
74+
// Result of sycl::sin applied to the whole vector; its type is deduced.
auto const &res_vec = sycl::sin(in);
75+
// Element type of the deduced result vector, obtained by stripping the
// reference and cv-qualifiers from decltype(res_vec).
using deducedT = typename std::remove_cv_t<
76+
std::remove_reference_t<decltype(res_vec)>>::element_type;
77+
// Compile-time branch: the result can be returned directly only when its
// element type is already resT.
if constexpr (std::is_same_v<resT, deducedT>) {
78+
return res_vec;
79+
}
80+
else {
81+
// NOTE(review): vec_cast is defined outside this view; presumably it
// converts each element from deducedT to resT - confirm.
return vec_cast<resT, deducedT, vec_sz>(res_vec);
82+
}
83+
}
6984
};
7085

7186
template <typename argTy,

0 commit comments

Comments
 (0)