Commit dd9a85a

Use parallel_for_each_reduce_over_dim_list_output_index for {Map,}ReduceOverDimListPlan ops (#9197)
Another straightforward rollout.
1 parent fbed0b2 commit dd9a85a
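
The change follows the same shape in every kernel touched here: the serial loop over out.numel() output indices becomes the body of a lambda over a half-open [begin, end) range, which is handed to parallel_for_each_reduce_over_dim_list_output_index together with the input, the dim list, and the output tensor; the helper reports success as a bool, which each kernel now surfaces through ET_KERNEL_CHECK_MSG. A minimal sketch of the shared pattern follows (reduce_fn is a placeholder for the per-op reduction, and the description of how the helper splits the range is an assumption; the helper's implementation is not part of this diff):

// Sketch only. parallel_for_each_reduce_over_dim_list_output_index is assumed
// to partition [0, out.numel()) into [begin, end) chunks, run the lambda on
// each chunk (possibly on several threads), and return false if dispatch fails.
ReduceOverDimListPlan plan(in, dim_list);
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
const bool success = parallel_for_each_reduce_over_dim_list_output_index(
    in, dim_list, out, [&](const auto begin, const auto end) {
      for (const auto out_ix : c10::irange(begin, end)) {
        // reduce_fn stands in for the per-op reduction lambda (max, min, sum, ...).
        out_data[out_ix] = plan.execute<CTYPE>(reduce_fn, out_ix);
      }
    });
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");

Beyond the loop rewrite, op_mean.cpp and op_sum.cpp hoist the switch-macro op name into a static op_name array, and op_var.cpp threads the KernelRuntimeContext through compute_variance so the new check can report failure.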

6 files changed: +100 -73 lines

kernels/portable/cpu/op_amax.cpp

+11 -7

@@ -46,13 +46,17 @@ Tensor& amax_out(
   ReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
-    for (const auto out_ix : c10::irange(out.numel())) {
-      out_data[out_ix] = plan.execute<CTYPE>(
-          [](CTYPE v, CTYPE max_v) {
-            return std::isnan(v) || v > max_v ? v : max_v;
-          },
-          out_ix);
-    }
+    const bool success = parallel_for_each_reduce_over_dim_list_output_index(
+        in, dim_list, out, [&](const auto begin, const auto end) {
+          for (const auto out_ix : c10::irange(begin, end)) {
+            out_data[out_ix] = plan.execute<CTYPE>(
+                [](CTYPE v, CTYPE max_v) {
+                  return std::isnan(v) || v > max_v ? v : max_v;
+                },
+                out_ix);
+          }
+        });
+    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
   });
 
   return out;

kernels/portable/cpu/op_amin.cpp

+11 -7

@@ -45,13 +45,17 @@ Tensor& amin_out(
   ReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
-    for (const auto out_ix : c10::irange(out.numel())) {
-      out_data[out_ix] = plan.execute<CTYPE>(
-          [](CTYPE v, CTYPE min_v) {
-            return std::isnan(v) || v < min_v ? v : min_v;
-          },
-          out_ix);
-    }
+    const bool success = parallel_for_each_reduce_over_dim_list_output_index(
+        in, dim_list, out, [&](const auto begin, const auto end) {
+          for (const auto out_ix : c10::irange(begin, end)) {
+            out_data[out_ix] = plan.execute<CTYPE>(
+                [](CTYPE v, CTYPE min_v) {
+                  return std::isnan(v) || v < min_v ? v : min_v;
+                },
+                out_ix);
+          }
+        });
+    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
   });
 
   return out;

kernels/portable/cpu/op_any.cpp

+15 -10

@@ -96,16 +96,21 @@ Tensor& any_dims_out(
               static_cast<CTYPE_OUT>(static_cast<bool>(in_data[out_ix]));
         }
       } else {
-        for (const auto out_ix : c10::irange(out.numel())) {
-          bool any = false;
-          if (in_not_empty) {
-            any = plan->execute<CTYPE_IN, bool>(
-                [](CTYPE_IN v) { return static_cast<bool>(v); },
-                [](bool outv, bool acc) { return acc || outv; },
-                out_ix);
-          }
-          out_data[out_ix] = static_cast<CTYPE_OUT>(any);
-        }
+        const bool success =
+            parallel_for_each_reduce_over_dim_list_output_index(
+                in, dim_list, out, [&](const auto begin, const auto end) {
+                  for (const auto out_ix : c10::irange(begin, end)) {
+                    bool any = false;
+                    if (in_not_empty) {
+                      any = plan->execute<CTYPE_IN, bool>(
+                          [](CTYPE_IN v) { return static_cast<bool>(v); },
+                          [](bool outv, bool acc) { return acc || outv; },
+                          out_ix);
+                    }
+                    out_data[out_ix] = static_cast<CTYPE_OUT>(any);
+                  }
+                });
+        ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
       }
     });
   });

kernels/portable/cpu/op_mean.cpp

+20 -15

@@ -46,22 +46,27 @@ Tensor& mean_dim_out(
       out);
 
   MapReduceOverDimListPlan plan(in, dim_list);
-  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
-    ET_SWITCH_FLOATHBF16_TYPES(
-        out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
-          CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-          const size_t num = get_reduced_dim_product(in, dim_list);
-          for (const auto out_ix : c10::irange(out.numel())) {
-            CTYPE_OUT sum = 0;
-            if (in.numel() > 0) {
-              sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
-                  [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-                  [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-                  out_ix);
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "add.out";
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
+    ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
+      CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+      const size_t num = get_reduced_dim_product(in, dim_list);
+      const bool success = parallel_for_each_reduce_over_dim_list_output_index(
+          in, dim_list, out, [&](const auto begin, const auto end) {
+            for (const auto out_ix : c10::irange(begin, end)) {
+              CTYPE_OUT sum = 0;
+              if (in.numel() > 0) {
+                sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
+                    [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
+                    [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+                    out_ix);
+              }
+              out_data[out_ix] = sum / static_cast<float>(num);
             }
-            out_data[out_ix] = sum / static_cast<float>(num);
-          }
-        });
+          });
+      ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
+    });
   });
 
   return out;

kernels/portable/cpu/op_sum.cpp

+20 -16

@@ -50,23 +50,27 @@ Tensor& sum_dim_out(
   if (in.numel() > 0) {
     plan.emplace(in, dim_list);
   }
-  ET_SWITCH_REALHBBF16_TYPES(
-      in.scalar_type(), ctx, "sum.IntList_out", CTYPE_IN, [&] {
-        ET_SWITCH_REALHBBF16_TYPES(
-            out.scalar_type(), ctx, "sum.IntList_out", CTYPE_OUT, [&] {
-              CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-              for (const auto out_ix : c10::irange(out.numel())) {
-                CTYPE_OUT sum = 0;
-                if (plan.has_value()) {
-                  sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
-                      [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-                      [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-                      out_ix);
-                }
-                out_data[out_ix] = sum;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "sum.IntList_out";
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
+    ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
+      CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+      const bool success = parallel_for_each_reduce_over_dim_list_output_index(
+          in, dim_list, out, [&](const auto begin, const auto end) {
+            for (const auto out_ix : c10::irange(begin, end)) {
+              CTYPE_OUT sum = 0;
+              if (plan.has_value()) {
+                sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
+                    [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
+                    [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+                    out_ix);
               }
-            });
-      });
+              out_data[out_ix] = sum;
+            }
+          });
+      ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
+    });
+  });
 
   return out;
 }

kernels/portable/cpu/op_var.cpp

+23 -18

@@ -21,6 +21,7 @@ namespace {
 
 template <typename CTYPE_IN, typename CTYPE_OUT>
 void compute_variance(
+    KernelRuntimeContext& ctx,
     const Tensor& in,
     Tensor& out,
     optional<ArrayRef<int64_t>> dim_list,
@@ -33,22 +34,26 @@ void compute_variance(
     }
   } else {
     MapReduceOverDimListPlan plan(in, dim_list);
-    for (const auto out_ix : c10::irange(out.numel())) {
-      CTYPE_OUT sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
-          [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-          [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-          out_ix);
-      CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
-      CTYPE_OUT sum2 = plan.execute<CTYPE_IN, CTYPE_OUT>(
-          [mean](CTYPE_IN v) {
-            return (
-                (static_cast<CTYPE_OUT>(v) - mean) *
-                (static_cast<CTYPE_OUT>(v) - mean));
-          },
-          [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-          out_ix);
-      out_data[out_ix] = sum2 / denominator;
-    }
+    const bool success = parallel_for_each_reduce_over_dim_list_output_index(
+        in, dim_list, out, [&](const auto begin, const auto end) {
+          for (const auto out_ix : c10::irange(begin, end)) {
+            CTYPE_OUT sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
+                [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
+                [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+                out_ix);
+            CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
+            CTYPE_OUT sum2 = plan.execute<CTYPE_IN, CTYPE_OUT>(
+                [mean](CTYPE_IN v) {
+                  return (
+                      (static_cast<CTYPE_OUT>(v) - mean) *
+                      (static_cast<CTYPE_OUT>(v) - mean));
+                },
+                [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+                out_ix);
+            out_data[out_ix] = sum2 / denominator;
+          }
+        });
+    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
   }
 }
 
@@ -90,7 +95,7 @@ Tensor& var_out(
 
   ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
     ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] {
-      compute_variance<CTYPE_IN, CTYPE_OUT>(in, out, dim_list, num, denom);
+      compute_variance<CTYPE_IN, CTYPE_OUT>(ctx, in, out, dim_list, num, denom);
     });
   });
 
@@ -135,7 +140,7 @@ Tensor& var_correction_out(
 
   ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
     ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] {
-      compute_variance<CTYPE_IN, CTYPE_OUT>(in, out, dim_list, num, denom);
+      compute_variance<CTYPE_IN, CTYPE_OUT>(ctx, in, out, dim_list, num, denom);
     });
   });
