@@ -132,6 +132,11 @@ struct TypePairSupportDataForProdAccumulation
132
132
td_ns::NotDefinedEntry>::is_defined;
133
133
};
134
134
135
+ template <typename T>
136
+ using CumProdScanOpT = std::conditional_t <std::is_same_v<T, bool >,
137
+ sycl::logical_and<T>,
138
+ sycl::multiplies<T>>;
139
+
135
140
template <typename fnT, typename srcTy, typename dstTy>
136
141
struct CumProd1DContigFactory
137
142
{
@@ -140,9 +145,7 @@ struct CumProd1DContigFactory
140
145
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
141
146
dstTy>::is_defined)
142
147
{
143
- using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
144
- sycl::logical_and<dstTy>,
145
- sycl::multiplies<dstTy>>;
148
+ using ScanOpT = CumProdScanOpT<dstTy>;
146
149
constexpr bool include_initial = false ;
147
150
if constexpr (std::is_same_v<srcTy, dstTy>) {
148
151
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -175,9 +178,7 @@ struct CumProd1DIncludeInitialContigFactory
175
178
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
176
179
dstTy>::is_defined)
177
180
{
178
- using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
179
- sycl::logical_and<dstTy>,
180
- sycl::multiplies<dstTy>>;
181
+ using ScanOpT = CumProdScanOpT<dstTy>;
181
182
constexpr bool include_initial = true ;
182
183
if constexpr (std::is_same_v<srcTy, dstTy>) {
183
184
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -210,9 +211,7 @@ struct CumProdStridedFactory
210
211
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
211
212
dstTy>::is_defined)
212
213
{
213
- using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
214
- sycl::logical_and<dstTy>,
215
- sycl::multiplies<dstTy>>;
214
+ using ScanOpT = CumProdScanOpT<dstTy>;
216
215
constexpr bool include_initial = false ;
217
216
if constexpr (std::is_same_v<srcTy, dstTy>) {
218
217
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -245,9 +244,7 @@ struct CumProdIncludeInitialStridedFactory
245
244
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
246
245
dstTy>::is_defined)
247
246
{
248
- using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
249
- sycl::logical_and<dstTy>,
250
- sycl::multiplies<dstTy>>;
247
+ using ScanOpT = CumProdScanOpT<dstTy>;
251
248
constexpr bool include_initial = true ;
252
249
if constexpr (std::is_same_v<srcTy, dstTy>) {
253
250
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
0 commit comments