Skip to content

Commit 0d0ff97

Browse files
authored
Merge pull request #2018 from IntelPython/resolve-gh-2017
Fix incorrect results from `cumulative_sum` and `cumulative_prod` with `dtype=bool`
2 parents c6b00c7 + fa525a7 commit 0d0ff97

File tree

4 files changed

+45
-11
lines changed

4 files changed

+45
-11
lines changed

Diff for: dpctl/tensor/libtensor/include/kernels/accumulators.hpp

+12-3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ namespace kernels
4747
namespace accumulators
4848
{
4949

50+
namespace su_ns = dpctl::tensor::sycl_utils;
51+
5052
using dpctl::tensor::ssize_t;
5153
using namespace dpctl::tensor::offset_utils;
5254

@@ -84,9 +86,18 @@ template <typename srcTy, typename dstTy> struct CastTransformer
8486
}
8587
};
8688

89+
template <typename ScanOpT, typename T> struct needs_workaround
90+
{
91+
// workaround needed due to crash in JITing on CPU
92+
// remove when CMPLRLLVM-65813 is resolved
93+
static constexpr bool value = su_ns::IsSyclLogicalAnd<T, ScanOpT>::value ||
94+
su_ns::IsSyclLogicalOr<T, ScanOpT>::value;
95+
};
96+
8797
template <typename BinOpT, typename T> struct can_use_inclusive_scan_over_group
8898
{
89-
static constexpr bool value = sycl::has_known_identity<BinOpT, T>::value;
99+
static constexpr bool value = sycl::has_known_identity<BinOpT, T>::value &&
100+
!needs_workaround<BinOpT, T>::value;
90101
};
91102

92103
namespace detail
@@ -144,8 +155,6 @@ template <typename T> class stack_strided_t
144155

145156
// Iterative cumulative summation
146157

147-
namespace su_ns = dpctl::tensor::sycl_utils;
148-
149158
using nwiT = std::uint32_t;
150159

151160
template <typename inputT,

Diff for: dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp

+11-4
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ template <typename argTy, typename outTy>
7070
struct TypePairSupportDataForProdAccumulation
7171
{
7272
static constexpr bool is_defined = std::disjunction<
73+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
7374
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int32_t>,
7475
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
7576

7677
// input int8_t
78+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,
7779
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
7880
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
7981

@@ -130,6 +132,11 @@ struct TypePairSupportDataForProdAccumulation
130132
td_ns::NotDefinedEntry>::is_defined;
131133
};
132134

135+
template <typename T>
136+
using CumProdScanOpT = std::conditional_t<std::is_same_v<T, bool>,
137+
sycl::logical_and<T>,
138+
sycl::multiplies<T>>;
139+
133140
template <typename fnT, typename srcTy, typename dstTy>
134141
struct CumProd1DContigFactory
135142
{
@@ -138,7 +145,7 @@ struct CumProd1DContigFactory
138145
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
139146
dstTy>::is_defined)
140147
{
141-
using ScanOpT = sycl::multiplies<dstTy>;
148+
using ScanOpT = CumProdScanOpT<dstTy>;
142149
constexpr bool include_initial = false;
143150
if constexpr (std::is_same_v<srcTy, dstTy>) {
144151
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -171,7 +178,7 @@ struct CumProd1DIncludeInitialContigFactory
171178
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
172179
dstTy>::is_defined)
173180
{
174-
using ScanOpT = sycl::multiplies<dstTy>;
181+
using ScanOpT = CumProdScanOpT<dstTy>;
175182
constexpr bool include_initial = true;
176183
if constexpr (std::is_same_v<srcTy, dstTy>) {
177184
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -204,7 +211,7 @@ struct CumProdStridedFactory
204211
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
205212
dstTy>::is_defined)
206213
{
207-
using ScanOpT = sycl::multiplies<dstTy>;
214+
using ScanOpT = CumProdScanOpT<dstTy>;
208215
constexpr bool include_initial = false;
209216
if constexpr (std::is_same_v<srcTy, dstTy>) {
210217
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -237,7 +244,7 @@ struct CumProdIncludeInitialStridedFactory
237244
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
238245
dstTy>::is_defined)
239246
{
240-
using ScanOpT = sycl::multiplies<dstTy>;
247+
using ScanOpT = CumProdScanOpT<dstTy>;
241248
constexpr bool include_initial = true;
242249
if constexpr (std::is_same_v<srcTy, dstTy>) {
243250
using dpctl::tensor::kernels::accumulators::NoOpTransformer;

Diff for: dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ template <typename argTy, typename outTy>
7070
struct TypePairSupportDataForSumAccumulation
7171
{
7272
static constexpr bool is_defined = std::disjunction<
73+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
7374
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int32_t>,
7475
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
7576

7677
// input int8_t
78+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,
7779
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
7880
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
7981

@@ -130,6 +132,10 @@ struct TypePairSupportDataForSumAccumulation
130132
td_ns::NotDefinedEntry>::is_defined;
131133
};
132134

135+
template <typename T>
136+
using CumSumScanOpT = std::
137+
conditional_t<std::is_same_v<T, bool>, sycl::logical_or<T>, sycl::plus<T>>;
138+
133139
template <typename fnT, typename srcTy, typename dstTy>
134140
struct CumSum1DContigFactory
135141
{
@@ -138,7 +144,7 @@ struct CumSum1DContigFactory
138144
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
139145
dstTy>::is_defined)
140146
{
141-
using ScanOpT = sycl::plus<dstTy>;
147+
using ScanOpT = CumSumScanOpT<dstTy>;
142148
constexpr bool include_initial = false;
143149
if constexpr (std::is_same_v<srcTy, dstTy>) {
144150
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -171,7 +177,7 @@ struct CumSum1DIncludeInitialContigFactory
171177
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
172178
dstTy>::is_defined)
173179
{
174-
using ScanOpT = sycl::plus<dstTy>;
180+
using ScanOpT = CumSumScanOpT<dstTy>;
175181
constexpr bool include_initial = true;
176182
if constexpr (std::is_same_v<srcTy, dstTy>) {
177183
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -204,7 +210,7 @@ struct CumSumStridedFactory
204210
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
205211
dstTy>::is_defined)
206212
{
207-
using ScanOpT = sycl::plus<dstTy>;
213+
using ScanOpT = CumSumScanOpT<dstTy>;
208214
constexpr bool include_initial = false;
209215
if constexpr (std::is_same_v<srcTy, dstTy>) {
210216
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -237,7 +243,7 @@ struct CumSumIncludeInitialStridedFactory
237243
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
238244
dstTy>::is_defined)
239245
{
240-
using ScanOpT = sycl::plus<dstTy>;
246+
using ScanOpT = CumSumScanOpT<dstTy>;
241247
constexpr bool include_initial = true;
242248
if constexpr (std::is_same_v<srcTy, dstTy>) {
243249
using dpctl::tensor::kernels::accumulators::NoOpTransformer;

Diff for: dpctl/tests/test_tensor_accumulation.py

+12
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,15 @@ def test_cumulative_sum_gh_1901(p):
421421
inp = dpt.ones(n, dtype=dt)
422422
r = dpt.cumulative_sum(inp, dtype=dt)
423423
assert dpt.all(r == dpt.arange(1, n + 1, dtype=dt))
424+
425+
426+
@pytest.mark.parametrize(
427+
"dt", ["i1", "i2", "i4", "i8", "f2", "f4", "f8", "c8", "c16"]
428+
)
429+
def test_gh_2017(dt):
430+
"See https://github.com/IntelPython/dpctl/issues/2017"
431+
q = get_queue_or_skip()
432+
skip_if_dtype_not_supported(dt, q)
433+
x = dpt.asarray([-1, 1], dtype=dpt.dtype(dt), sycl_queue=q)
434+
r = dpt.cumulative_sum(x, dtype="?")
435+
assert dpt.all(r)

0 commit comments

Comments
 (0)