File tree 1 file changed +4
-5
lines changed
dpctl/tensor/libtensor/include/kernels
1 file changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -47,6 +47,8 @@ namespace kernels
47
47
namespace accumulators
48
48
{
49
49
50
+ namespace su_ns = dpctl::tensor::sycl_utils;
51
+
50
52
using dpctl::tensor::ssize_t ;
51
53
using namespace dpctl ::tensor::offset_utils;
52
54
@@ -87,9 +89,8 @@ template <typename srcTy, typename dstTy> struct CastTransformer
87
89
template <typename ScanOpT, typename T> struct needs_workaround
88
90
{
89
91
// work-around needed due to crash in JITing on CPU
90
- static constexpr bool value =
91
- std::is_same_v<ScanOpT, sycl::logical_or<T>> ||
92
- std::is_same_v<ScanOpT, sycl::logical_and<T>>;
92
+ static constexpr bool value = su_ns::IsSyclLogicalAnd<T, ScanOpT>::value ||
93
+ su_ns::IsSyclLogicalOr<T, ScanOpT>::value;
93
94
};
94
95
95
96
template <typename BinOpT, typename T> struct can_use_inclusive_scan_over_group
@@ -153,8 +154,6 @@ template <typename T> class stack_strided_t
153
154
154
155
// Iterative cumulative summation
155
156
156
- namespace su_ns = dpctl::tensor::sycl_utils;
157
-
158
157
using nwiT = std::uint32_t ;
159
158
160
159
template <typename inputT,
You can’t perform that action at this time.
0 commit comments