@@ -51,13 +51,15 @@ namespace py = pybind11;
51
51
namespace td_ns = dpctl::tensor::type_dispatch;
52
52
53
53
using dpctl::tensor::type_utils::is_complex;
54
+ using dpctl::tensor::type_utils::vec_cast;
54
55
55
56
template <typename argT, typename resT> struct AbsFunctor
56
57
{
57
58
58
59
using is_constant = typename std::false_type;
59
60
// constexpr resT constant_value = resT{};
60
- using supports_vec = typename std::false_type;
61
+ using supports_vec = typename std::negation<
62
+ std::disjunction<is_complex<resT>, is_complex<argT>>>;
61
63
using supports_sg_loadstore = typename std::negation<
62
64
std::disjunction<is_complex<resT>, is_complex<argT>>>;
63
65
@@ -75,6 +77,40 @@ template <typename argT, typename resT> struct AbsFunctor
75
77
return std::abs (x);
76
78
}
77
79
}
80
+
81
+ template <int vec_sz>
82
+ sycl::vec<resT, vec_sz> operator ()(const sycl::vec<argT, vec_sz> &in)
83
+ {
84
+ if constexpr (std::is_integral<argT>::value) {
85
+ if constexpr (std::is_same_v<argT, bool > ||
86
+ std::is_unsigned<argT>::value) {
87
+ return in;
88
+ }
89
+ else {
90
+ auto const &res_vec = sycl::abs (in);
91
+ using deducedT = typename std::remove_cv_t <
92
+ std::remove_reference_t <decltype (res_vec)>>::element_type;
93
+ if constexpr (std::is_same_v<resT, deducedT>) {
94
+ return res_vec;
95
+ }
96
+ else {
97
+
98
+ return vec_cast<resT, deducedT, vec_sz>(res_vec);
99
+ }
100
+ }
101
+ }
102
+ else {
103
+ auto const &res_vec = sycl::fabs (in);
104
+ using deducedT = typename std::remove_cv_t <
105
+ std::remove_reference_t <decltype (res_vec)>>::element_type;
106
+ if constexpr (std::is_same_v<resT, deducedT>) {
107
+ return res_vec;
108
+ }
109
+ else {
110
+ return vec_cast<resT, deducedT, vec_sz>(res_vec);
111
+ }
112
+ }
113
+ }
78
114
};
79
115
80
116
template <typename argT,
0 commit comments