Skip to content

Commit c58d775

Browse files
committed
Reuse SYCL utils in workaround for logical operators in accumulators
1 parent 48b0e66 commit c58d775

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

dpctl/tensor/libtensor/include/kernels/accumulators.hpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ namespace kernels
4747
namespace accumulators
4848
{
4949

50+
namespace su_ns = dpctl::tensor::sycl_utils;
51+
5052
using dpctl::tensor::ssize_t;
5153
using namespace dpctl::tensor::offset_utils;
5254

@@ -87,9 +89,8 @@ template <typename srcTy, typename dstTy> struct CastTransformer
8789
template <typename ScanOpT, typename T> struct needs_workaround
8890
{
8991
// work-around needed due to crash in JITing on CPU
90-
static constexpr bool value =
91-
std::is_same_v<ScanOpT, sycl::logical_or<T>> ||
92-
std::is_same_v<ScanOpT, sycl::logical_and<T>>;
92+
static constexpr bool value = su_ns::IsSyclLogicalAnd<T, ScanOpT>::value ||
93+
su_ns::IsSyclLogicalOr<T, ScanOpT>::value;
9394
};
9495

9596
template <typename BinOpT, typename T> struct can_use_inclusive_scan_over_group
@@ -153,8 +154,6 @@ template <typename T> class stack_strided_t
153154

154155
// Iterative cumulative summation
155156

156-
namespace su_ns = dpctl::tensor::sycl_utils;
157-
158157
using nwiT = std::uint32_t;
159158

160159
template <typename inputT,

0 commit comments

Comments
 (0)