@@ -70,10 +70,12 @@ template <typename argTy, typename outTy>
70
70
struct TypePairSupportDataForProdAccumulation
71
71
{
72
72
static constexpr bool is_defined = std::disjunction<
73
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, bool >,
73
74
td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int32_t >,
74
75
td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int64_t >,
75
76
76
77
// input int8_t
78
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int8_t >,
77
79
td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int32_t >,
78
80
td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int64_t >,
79
81
@@ -138,7 +140,9 @@ struct CumProd1DContigFactory
138
140
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
139
141
dstTy>::is_defined)
140
142
{
141
- using ScanOpT = sycl::multiplies<dstTy>;
143
+ using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
144
+ sycl::logical_and<dstTy>,
145
+ sycl::multiplies<dstTy>>;
142
146
constexpr bool include_initial = false ;
143
147
if constexpr (std::is_same_v<srcTy, dstTy>) {
144
148
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -171,7 +175,9 @@ struct CumProd1DIncludeInitialContigFactory
171
175
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
172
176
dstTy>::is_defined)
173
177
{
174
- using ScanOpT = sycl::multiplies<dstTy>;
178
+ using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
179
+ sycl::logical_and<dstTy>,
180
+ sycl::multiplies<dstTy>>;
175
181
constexpr bool include_initial = true ;
176
182
if constexpr (std::is_same_v<srcTy, dstTy>) {
177
183
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -204,7 +210,9 @@ struct CumProdStridedFactory
204
210
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
205
211
dstTy>::is_defined)
206
212
{
207
- using ScanOpT = sycl::multiplies<dstTy>;
213
+ using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
214
+ sycl::logical_and<dstTy>,
215
+ sycl::multiplies<dstTy>>;
208
216
constexpr bool include_initial = false ;
209
217
if constexpr (std::is_same_v<srcTy, dstTy>) {
210
218
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -237,7 +245,9 @@ struct CumProdIncludeInitialStridedFactory
237
245
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
238
246
dstTy>::is_defined)
239
247
{
240
- using ScanOpT = sycl::multiplies<dstTy>;
248
+ using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
249
+ sycl::logical_and<dstTy>,
250
+ sycl::multiplies<dstTy>>;
241
251
constexpr bool include_initial = true ;
242
252
if constexpr (std::is_same_v<srcTy, dstTy>) {
243
253
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
0 commit comments