Skip to content

Commit a951d99

Browse files
Revert "Move reduce to template parameter in vectorized_reduction (pytorch#138672)"
This reverts commit 9b2c99d. Reverted pytorch#138672 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#138672 (comment)))
1 parent 9bbe4a6 commit a951d99

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

aten/src/ATen/native/cpu/Reduce.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ inline bool is_outer_reduction(const int64_t* strides) {
3333
strides[3] == sizeof(typename traits::arg2_t);
3434
}
3535

36-
template <typename func_t, typename vec_func_t, bool reduce>
36+
template <typename func_t, typename vec_func_t>
3737
inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
38-
func_t op [[maybe_unused]], vec_func_t vop) {
38+
func_t op, vec_func_t vop, bool reduce) {
3939
VEC_LOOP_HEADER(func_t, data)
4040
const char* in1_ptr = data[1];
4141
Vec acc[4];
@@ -49,7 +49,7 @@ inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
4949
acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size() * sizeof(scalar_t))));
5050
acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size() * sizeof(scalar_t))));
5151
}
52-
if constexpr (reduce) {
52+
if (reduce) {
5353
scalar_t buffer[Vec::size()];
5454
acc[0] = vop(vop(acc[0], acc[1]), vop(acc[2], acc[3]));
5555
acc[0].store(buffer);
@@ -83,7 +83,7 @@ inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_fu
8383
constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
8484
int64_t count = n / (4 * Vec::size());
8585
if (count > 0) {
86-
vectorized_reduction<func_t, vec_func_t, true>(data, count, vector_stride, op, vop);
86+
vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true);
8787
}
8888
char* ptrs[3] = { data[0], data[0], data[1] };
8989
int64_t strides[] = { 0, 0, sizeof(scalar_t) };
@@ -99,7 +99,7 @@ inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_
9999
constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
100100
int64_t outer_stride[2] = { vector_stride, vector_stride };
101101
UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
102-
vectorized_reduction<func_t, vec_func_t, false>(data, size0, inner_stride, op, vop);
102+
vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false);
103103
});
104104

105105
// reduce down the remaining columns

0 commit comments

Comments
 (0)