@@ -33,9 +33,9 @@ inline bool is_outer_reduction(const int64_t* strides) {
33
33
strides[3 ] == sizeof (typename traits::arg2_t );
34
34
}
35
35
36
- template <typename func_t , typename vec_func_t , bool reduce >
36
+ template <typename func_t , typename vec_func_t >
37
37
inline void vectorized_reduction (char ** data, int64_t n, int64_t stride,
38
- func_t op [[maybe_unused]] , vec_func_t vop) {
38
+ func_t op, vec_func_t vop, bool reduce ) {
39
39
VEC_LOOP_HEADER (func_t , data)
40
40
const char * in1_ptr = data[1 ];
41
41
Vec acc[4 ];
@@ -49,7 +49,7 @@ inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
49
49
acc[2 ] = vop (acc[2 ], Vec::loadu (ptr + (2 * Vec::size () * sizeof (scalar_t ))));
50
50
acc[3 ] = vop (acc[3 ], Vec::loadu (ptr + (3 * Vec::size () * sizeof (scalar_t ))));
51
51
}
52
- if constexpr (reduce) {
52
+ if (reduce) {
53
53
scalar_t buffer[Vec::size ()];
54
54
acc[0 ] = vop (vop (acc[0 ], acc[1 ]), vop (acc[2 ], acc[3 ]));
55
55
acc[0 ].store (buffer);
@@ -83,7 +83,7 @@ inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_fu
83
83
constexpr int64_t vector_stride = 4 * Vec::size () * sizeof (scalar_t );
84
84
int64_t count = n / (4 * Vec::size ());
85
85
if (count > 0 ) {
86
- vectorized_reduction< func_t , vec_func_t , true > (data, count, vector_stride, op, vop);
86
+ vectorized_reduction (data, count, vector_stride, op, vop, /* reduce= */ true );
87
87
}
88
88
char * ptrs[3 ] = { data[0 ], data[0 ], data[1 ] };
89
89
int64_t strides[] = { 0 , 0 , sizeof (scalar_t ) };
@@ -99,7 +99,7 @@ inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_
99
99
constexpr int64_t vector_stride = 4 * Vec::size () * sizeof (scalar_t );
100
100
int64_t outer_stride[2 ] = { vector_stride, vector_stride };
101
101
UNARY_OUTER_LOOP (data, outer_stride, size1 / (4 * Vec::size ()), [&] {
102
- vectorized_reduction< func_t , vec_func_t , false > (data, size0, inner_stride, op, vop);
102
+ vectorized_reduction (data, size0, inner_stride, op, vop, /* reduce= */ false );
103
103
});
104
104
105
105
// reduce down the remaining columns
0 commit comments