@@ -21,6 +21,7 @@ namespace {
21
21
22
22
template <typename CTYPE_IN, typename CTYPE_OUT>
23
23
void compute_variance (
24
+ KernelRuntimeContext& ctx,
24
25
const Tensor& in,
25
26
Tensor& out,
26
27
optional<ArrayRef<int64_t >> dim_list,
@@ -33,22 +34,26 @@ void compute_variance(
33
34
}
34
35
} else {
35
36
MapReduceOverDimListPlan plan (in, dim_list);
36
- for (const auto out_ix : c10::irange (out.numel ())) {
37
- CTYPE_OUT sum = plan.execute <CTYPE_IN, CTYPE_OUT>(
38
- [](CTYPE_IN v) { return static_cast <CTYPE_OUT>(v); },
39
- [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
40
- out_ix);
41
- CTYPE_OUT mean = sum / static_cast <CTYPE_OUT>(num);
42
- CTYPE_OUT sum2 = plan.execute <CTYPE_IN, CTYPE_OUT>(
43
- [mean](CTYPE_IN v) {
44
- return (
45
- (static_cast <CTYPE_OUT>(v) - mean) *
46
- (static_cast <CTYPE_OUT>(v) - mean));
47
- },
48
- [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
49
- out_ix);
50
- out_data[out_ix] = sum2 / denominator;
51
- }
37
+ const bool success = parallel_for_each_reduce_over_dim_list_output_index (
38
+ in, dim_list, out, [&](const auto begin, const auto end) {
39
+ for (const auto out_ix : c10::irange (begin, end)) {
40
+ CTYPE_OUT sum = plan.execute <CTYPE_IN, CTYPE_OUT>(
41
+ [](CTYPE_IN v) { return static_cast <CTYPE_OUT>(v); },
42
+ [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
43
+ out_ix);
44
+ CTYPE_OUT mean = sum / static_cast <CTYPE_OUT>(num);
45
+ CTYPE_OUT sum2 = plan.execute <CTYPE_IN, CTYPE_OUT>(
46
+ [mean](CTYPE_IN v) {
47
+ return (
48
+ (static_cast <CTYPE_OUT>(v) - mean) *
49
+ (static_cast <CTYPE_OUT>(v) - mean));
50
+ },
51
+ [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
52
+ out_ix);
53
+ out_data[out_ix] = sum2 / denominator;
54
+ }
55
+ });
56
+ ET_KERNEL_CHECK_MSG (ctx, success, Internal, , " parallel_for failed" );
52
57
}
53
58
}
54
59
@@ -90,7 +95,7 @@ Tensor& var_out(
90
95
91
96
ET_SWITCH_FLOATHBF16_TYPES (in.scalar_type (), ctx, name, CTYPE_IN, [&] {
92
97
ET_SWITCH_FLOATHBF16_TYPES (out.scalar_type (), ctx, name, CTYPE_OUT, [&] {
93
- compute_variance<CTYPE_IN, CTYPE_OUT>(in, out, dim_list, num, denom);
98
+ compute_variance<CTYPE_IN, CTYPE_OUT>(ctx, in, out, dim_list, num, denom);
94
99
});
95
100
});
96
101
@@ -135,7 +140,7 @@ Tensor& var_correction_out(
135
140
136
141
ET_SWITCH_FLOATHBF16_TYPES (in.scalar_type (), ctx, name, CTYPE_IN, [&] {
137
142
ET_SWITCH_FLOATHBF16_TYPES (out.scalar_type (), ctx, name, CTYPE_OUT, [&] {
138
- compute_variance<CTYPE_IN, CTYPE_OUT>(in, out, dim_list, num, denom);
143
+ compute_variance<CTYPE_IN, CTYPE_OUT>(ctx, in, out, dim_list, num, denom);
139
144
});
140
145
});
141
146
0 commit comments