|
13 | 13 | #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
|
14 | 14 | #include <executorch/kernels/portable/cpu/util/dtype_util.h>
|
15 | 15 | #include <executorch/runtime/kernel/kernel_runtime_context.h>
|
| 16 | +#include <executorch/runtime/kernel/thread_parallel_interface.h> |
16 | 17 |
|
17 | 18 | #include <array>
|
18 | 19 | #include <utility>
|
@@ -94,17 +95,28 @@ inline void apply_elementwise_fn(
|
94 | 95 | char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
|
95 | 96 | const auto out_element_size = out.element_size();
|
96 | 97 |
|
97 |
| - for (const auto& indexes : |
98 |
| - BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) { |
99 |
| - std::array<CTYPE_COMMON, kNumInputs> loaded_inputs; |
100 |
| - for (const auto idx : c10::irange(kNumInputs)) { |
101 |
| - const auto& input_info = inputs_info[idx]; |
102 |
| - loaded_inputs[idx] = input_info.load_to_common( |
103 |
| - &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]); |
104 |
| - } |
105 |
| - auto result = std::apply(compute_fun, loaded_inputs); |
106 |
| - store_common_to_out(result, &data_out[indexes[0] * out_element_size]); |
107 |
| - } |
| 98 | + ::executorch::extension::parallel_for( |
| 99 | + 0, |
| 100 | + out.numel(), |
| 101 | + ::executorch::extension::internal::GRAIN_SIZE, |
| 102 | + [&](const auto begin, const auto end) { |
| 103 | + const auto range = |
| 104 | + BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...); |
| 105 | + auto begin_it = range.begin(); |
| 106 | + begin_it += begin; |
| 107 | + for (; (*begin_it)[0] < end; ++begin_it) { |
| 108 | + const auto& indexes = *begin_it; |
| 109 | + std::array<CTYPE_COMMON, kNumInputs> loaded_inputs; |
| 110 | + for (const auto idx : c10::irange(kNumInputs)) { |
| 111 | + const auto& input_info = inputs_info[idx]; |
| 112 | + loaded_inputs[idx] = input_info.load_to_common( |
| 113 | + &input_info |
| 114 | + .data_ptr[indexes[idx + 1] * input_info.element_size]); |
| 115 | + } |
| 116 | + auto result = std::apply(compute_fun, loaded_inputs); |
| 117 | + store_common_to_out(result, &data_out[indexes[0] * out_element_size]); |
| 118 | + } |
| 119 | + }); |
108 | 120 | }
|
109 | 121 | } // namespace internal
|
110 | 122 |
|
|
0 commit comments