Skip to content

Commit e2d2b3f

Browse files
committed
Add missing kernels to arithmetic accumulators
The bool->bool and int8->int8 type-pair overloads were missing from cumulative_sum and cumulative_prod.
1 parent 5e1c87d commit e2d2b3f

File tree

3 files changed

+37
-9
lines changed

3 files changed

+37
-9
lines changed

dpctl/tensor/libtensor/include/kernels/accumulators.hpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,17 @@ template <typename srcTy, typename dstTy> struct CastTransformer
8484
}
8585
};
8686

87+
template <typename ScanOpT, typename T> struct needs_workaround
88+
{
89+
static constexpr bool value =
90+
std::is_same_v<ScanOpT, sycl::logical_or<T>> ||
91+
std::is_same_v<ScanOpT, sycl::logical_and<T>>;
92+
};
93+
8794
template <typename BinOpT, typename T> struct can_use_inclusive_scan_over_group
8895
{
89-
static constexpr bool value = sycl::has_known_identity<BinOpT, T>::value;
96+
static constexpr bool value = sycl::has_known_identity<BinOpT, T>::value &&
97+
!needs_workaround<BinOpT, T>::value;
9098
};
9199

92100
namespace detail

dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp

+14-4
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ template <typename argTy, typename outTy>
7070
struct TypePairSupportDataForProdAccumulation
7171
{
7272
static constexpr bool is_defined = std::disjunction<
73+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
7374
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int32_t>,
7475
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
7576

7677
// input int8_t
78+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,
7779
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
7880
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
7981

@@ -138,7 +140,9 @@ struct CumProd1DContigFactory
138140
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
139141
dstTy>::is_defined)
140142
{
141-
using ScanOpT = sycl::multiplies<dstTy>;
143+
using ScanOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
144+
sycl::logical_and<dstTy>,
145+
sycl::multiplies<dstTy>>;
142146
constexpr bool include_initial = false;
143147
if constexpr (std::is_same_v<srcTy, dstTy>) {
144148
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -171,7 +175,9 @@ struct CumProd1DIncludeInitialContigFactory
171175
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
172176
dstTy>::is_defined)
173177
{
174-
using ScanOpT = sycl::multiplies<dstTy>;
178+
using ScanOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
179+
sycl::logical_and<dstTy>,
180+
sycl::multiplies<dstTy>>;
175181
constexpr bool include_initial = true;
176182
if constexpr (std::is_same_v<srcTy, dstTy>) {
177183
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -204,7 +210,9 @@ struct CumProdStridedFactory
204210
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
205211
dstTy>::is_defined)
206212
{
207-
using ScanOpT = sycl::multiplies<dstTy>;
213+
using ScanOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
214+
sycl::logical_and<dstTy>,
215+
sycl::multiplies<dstTy>>;
208216
constexpr bool include_initial = false;
209217
if constexpr (std::is_same_v<srcTy, dstTy>) {
210218
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -237,7 +245,9 @@ struct CumProdIncludeInitialStridedFactory
237245
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
238246
dstTy>::is_defined)
239247
{
240-
using ScanOpT = sycl::multiplies<dstTy>;
248+
using ScanOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
249+
sycl::logical_and<dstTy>,
250+
sycl::multiplies<dstTy>>;
241251
constexpr bool include_initial = true;
242252
if constexpr (std::is_same_v<srcTy, dstTy>) {
243253
using dpctl::tensor::kernels::accumulators::NoOpTransformer;

dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp

+14-4
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ template <typename argTy, typename outTy>
7070
struct TypePairSupportDataForSumAccumulation
7171
{
7272
static constexpr bool is_defined = std::disjunction<
73+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
7374
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int32_t>,
7475
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
7576

7677
// input int8_t
78+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,
7779
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
7880
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
7981

@@ -138,7 +140,9 @@ struct CumSum1DContigFactory
138140
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
139141
dstTy>::is_defined)
140142
{
141-
using ScanOpT = sycl::plus<dstTy>;
143+
using ScanOpT =
144+
std::conditional_t<std::is_same_v<dstTy, bool>,
145+
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
142146
constexpr bool include_initial = false;
143147
if constexpr (std::is_same_v<srcTy, dstTy>) {
144148
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -171,7 +175,9 @@ struct CumSum1DIncludeInitialContigFactory
171175
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
172176
dstTy>::is_defined)
173177
{
174-
using ScanOpT = sycl::plus<dstTy>;
178+
using ScanOpT =
179+
std::conditional_t<std::is_same_v<dstTy, bool>,
180+
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
175181
constexpr bool include_initial = true;
176182
if constexpr (std::is_same_v<srcTy, dstTy>) {
177183
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -204,7 +210,9 @@ struct CumSumStridedFactory
204210
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
205211
dstTy>::is_defined)
206212
{
207-
using ScanOpT = sycl::plus<dstTy>;
213+
using ScanOpT =
214+
std::conditional_t<std::is_same_v<dstTy, bool>,
215+
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
208216
constexpr bool include_initial = false;
209217
if constexpr (std::is_same_v<srcTy, dstTy>) {
210218
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -237,7 +245,9 @@ struct CumSumIncludeInitialStridedFactory
237245
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
238246
dstTy>::is_defined)
239247
{
240-
using ScanOpT = sycl::plus<dstTy>;
248+
using ScanOpT =
249+
std::conditional_t<std::is_same_v<dstTy, bool>,
250+
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
241251
constexpr bool include_initial = true;
242252
if constexpr (std::is_same_v<srcTy, dstTy>) {
243253
using dpctl::tensor::kernels::accumulators::NoOpTransformer;

0 commit comments

Comments
 (0)