Skip to content

Commit 35d46d3

Browse files
committed
Alias conditional binary op type for cumulative_sum and cumulative_prod
Reduces code repetition
1 parent c58d775 commit 35d46d3

File tree

2 files changed

+17
-24
lines changed

2 files changed

+17
-24
lines changed

Diff for: dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp

+9-12
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ struct TypePairSupportDataForProdAccumulation
132132
td_ns::NotDefinedEntry>::is_defined;
133133
};
134134

135+
template <typename T>
136+
using CumProdScanOpT = std::conditional_t<std::is_same_v<T, bool>,
137+
sycl::logical_and<T>,
138+
sycl::multiplies<T>>;
139+
135140
template <typename fnT, typename srcTy, typename dstTy>
136141
struct CumProd1DContigFactory
137142
{
@@ -140,9 +145,7 @@ struct CumProd1DContigFactory
140145
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
141146
dstTy>::is_defined)
142147
{
143-
using ScanOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
144-
sycl::logical_and<dstTy>,
145-
sycl::multiplies<dstTy>>;
148+
using ScanOpT = CumProdScanOpT<dstTy>;
146149
constexpr bool include_initial = false;
147150
if constexpr (std::is_same_v<srcTy, dstTy>) {
148151
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -175,9 +178,7 @@ struct CumProd1DIncludeInitialContigFactory
175178
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
176179
dstTy>::is_defined)
177180
{
178-
using ScanOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
179-
sycl::logical_and<dstTy>,
180-
sycl::multiplies<dstTy>>;
181+
using ScanOpT = CumProdScanOpT<dstTy>;
181182
constexpr bool include_initial = true;
182183
if constexpr (std::is_same_v<srcTy, dstTy>) {
183184
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -210,9 +211,7 @@ struct CumProdStridedFactory
210211
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
211212
dstTy>::is_defined)
212213
{
213-
using ScanOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
214-
sycl::logical_and<dstTy>,
215-
sycl::multiplies<dstTy>>;
214+
using ScanOpT = CumProdScanOpT<dstTy>;
216215
constexpr bool include_initial = false;
217216
if constexpr (std::is_same_v<srcTy, dstTy>) {
218217
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -245,9 +244,7 @@ struct CumProdIncludeInitialStridedFactory
245244
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
246245
dstTy>::is_defined)
247246
{
248-
using ScanOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
249-
sycl::logical_and<dstTy>,
250-
sycl::multiplies<dstTy>>;
247+
using ScanOpT = CumProdScanOpT<dstTy>;
251248
constexpr bool include_initial = true;
252249
if constexpr (std::is_same_v<srcTy, dstTy>) {
253250
using dpctl::tensor::kernels::accumulators::NoOpTransformer;

Diff for: dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp

+8-12
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ struct TypePairSupportDataForSumAccumulation
132132
td_ns::NotDefinedEntry>::is_defined;
133133
};
134134

135+
template <typename T>
136+
using CumSumScanOpT = std::
137+
conditional_t<std::is_same_v<T, bool>, sycl::logical_or<T>, sycl::plus<T>>;
138+
135139
template <typename fnT, typename srcTy, typename dstTy>
136140
struct CumSum1DContigFactory
137141
{
@@ -140,9 +144,7 @@ struct CumSum1DContigFactory
140144
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
141145
dstTy>::is_defined)
142146
{
143-
using ScanOpT =
144-
std::conditional_t<std::is_same_v<dstTy, bool>,
145-
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
147+
using ScanOpT = CumSumScanOpT<dstTy>;
146148
constexpr bool include_initial = false;
147149
if constexpr (std::is_same_v<srcTy, dstTy>) {
148150
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -175,9 +177,7 @@ struct CumSum1DIncludeInitialContigFactory
175177
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
176178
dstTy>::is_defined)
177179
{
178-
using ScanOpT =
179-
std::conditional_t<std::is_same_v<dstTy, bool>,
180-
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
180+
using ScanOpT = CumSumScanOpT<dstTy>;
181181
constexpr bool include_initial = true;
182182
if constexpr (std::is_same_v<srcTy, dstTy>) {
183183
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -210,9 +210,7 @@ struct CumSumStridedFactory
210210
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
211211
dstTy>::is_defined)
212212
{
213-
using ScanOpT =
214-
std::conditional_t<std::is_same_v<dstTy, bool>,
215-
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
213+
using ScanOpT = CumSumScanOpT<dstTy>;
216214
constexpr bool include_initial = false;
217215
if constexpr (std::is_same_v<srcTy, dstTy>) {
218216
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -245,9 +243,7 @@ struct CumSumIncludeInitialStridedFactory
245243
if constexpr (TypePairSupportDataForSumAccumulation<srcTy,
246244
dstTy>::is_defined)
247245
{
248-
using ScanOpT =
249-
std::conditional_t<std::is_same_v<dstTy, bool>,
250-
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
246+
using ScanOpT = CumSumScanOpT<dstTy>;
251247
constexpr bool include_initial = true;
252248
if constexpr (std::is_same_v<srcTy, dstTy>) {
253249
using dpctl::tensor::kernels::accumulators::NoOpTransformer;

0 commit comments

Comments
 (0)