Skip to content

Commit a28e9be

Browse files
authored
Use parallel_for in elementwise_util (#9243)
More straightforward rollout. (Parallelizing over BroadcastIndexesRange is ugly, but so far we have exactly two instances of it. If it keeps cropping up, we can add a utility to make it nicer.)
1 parent d8c069a commit a28e9be

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

kernels/portable/cpu/util/elementwise_util.h

+23 -11
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1414
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
1515
#include <executorch/runtime/kernel/kernel_runtime_context.h>
16+
#include <executorch/runtime/kernel/thread_parallel_interface.h>
1617

1718
#include <array>
1819
#include <utility>
@@ -94,17 +95,28 @@ inline void apply_elementwise_fn(
9495
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
9596
const auto out_element_size = out.element_size();
9697

97-
for (const auto& indexes :
98-
BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
99-
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
100-
for (const auto idx : c10::irange(kNumInputs)) {
101-
const auto& input_info = inputs_info[idx];
102-
loaded_inputs[idx] = input_info.load_to_common(
103-
&input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
104-
}
105-
auto result = std::apply(compute_fun, loaded_inputs);
106-
store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
107-
}
98+
::executorch::extension::parallel_for(
99+
0,
100+
out.numel(),
101+
::executorch::extension::internal::GRAIN_SIZE,
102+
[&](const auto begin, const auto end) {
103+
const auto range =
104+
BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
105+
auto begin_it = range.begin();
106+
begin_it += begin;
107+
for (; (*begin_it)[0] < end; ++begin_it) {
108+
const auto& indexes = *begin_it;
109+
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
110+
for (const auto idx : c10::irange(kNumInputs)) {
111+
const auto& input_info = inputs_info[idx];
112+
loaded_inputs[idx] = input_info.load_to_common(
113+
&input_info
114+
.data_ptr[indexes[idx + 1] * input_info.element_size]);
115+
}
116+
auto result = std::apply(compute_fun, loaded_inputs);
117+
store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
118+
}
119+
});
108120
}
109121
} // namespace internal
110122

kernels/portable/cpu/util/targets.bzl

+1
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,7 @@ def define_common_targets():
111111
":broadcast_util",
112112
":dtype_util",
113113
"//executorch/runtime/kernel:kernel_runtime_context",
114+
"//executorch/runtime/kernel:thread_parallel_interface",
114115
],
115116
deps = [
116117
"//executorch/kernels/portable/cpu:scalar_utils",

0 commit comments

Comments
 (0)