
Commit 846f827

Add basic parallel_for support to reduce_util
Initial parallel_for integration in a portable op. Needed for #8932. Feel free to hold review until the rest of the stack is ready and we observe successful parallelization.

ghstack-source-id: 3d510f0abf35069c3c3939605ff9c5639f8f845d
ghstack-comment-id: 2702502530
Pull Request resolved: #8986
1 parent b9b44b9 commit 846f827

File tree

3 files changed: +32 / -8 lines changed


kernels/portable/cpu/util/reduce_util.h (+14, -6)

@@ -8,8 +8,10 @@
 
 #pragma once
 
+#include <c10/util/irange.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <cstring>
 #include <tuple>
 
@@ -24,9 +26,12 @@ void apply_on_flat_ix_with_stride_and_base(
     const size_t base,
     const size_t start,
     const size_t end) {
-  for (size_t i = start; i <= end; i++) {
-    fn(base + i * stride);
-  }
+  executorch::extension::parallel_for(
+      start, end + 1, [&](auto start_, auto end_) {
+        for (const auto i : c10::irange(start_, end_)) {
+          fn(base + i * stride);
+        }
+      });
 }
 
 template <typename Fn>
@@ -36,9 +41,12 @@ void apply_on_flat_and_dim_ix_with_stride_and_base(
     const size_t base,
     const size_t start,
     const size_t end) {
-  for (size_t i = start; i <= end; i++) {
-    fn(base + i * stride, i);
-  }
+  executorch::extension::parallel_for(
+      start, end + 1, [&](auto start_, auto end_) {
+        for (const auto i : c10::irange(start_, end_)) {
+          fn(base + i * stride, i);
+        }
+      });
 }
 
 template <typename Fn>
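
Note that parallel_for takes a half-open range, so the old inclusive loop over [start, end] becomes a call over [start, end + 1), and the lambda receives the sub-range it should process. Below is a minimal standalone sketch of the same pattern, assuming the ExecuTorch headers are on the include path; the buffer, element type, and scale_strided name are illustrative and not part of this change.

#include <c10/util/irange.h>
#include <executorch/runtime/kernel/thread_parallel_interface.h>

#include <vector>

// Doubles every strided element, split across worker threads when a
// threadpool is available. Each chunk [begin_, end_) is disjoint, so the
// writes do not race as long as stride is nonzero.
void scale_strided(
    std::vector<float>& data,
    const size_t stride,
    const size_t base,
    const size_t start,
    const size_t end) {
  executorch::extension::parallel_for(
      start, end + 1, [&](const auto begin_, const auto end_) {
        for (const auto i : c10::irange(begin_, end_)) {
          data[base + i * stride] *= 2.0f;
        }
      });
}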

kernels/portable/cpu/util/targets.bzl (+5, -1)

@@ -299,8 +299,12 @@ def define_common_targets():
         srcs = ["reduce_util.cpp"],
         exported_headers = ["reduce_util.h"],
         deps = [
-            "//executorch/runtime/kernel:kernel_includes{}".format(suffix),
             "//executorch/runtime/core/exec_aten/util:tensor_util{}".format(suffix),
+            "//executorch/runtime/kernel:kernel_includes{}".format(suffix),
+        ],
+        exported_deps = [
+            "//executorch/runtime/kernel:thread_parallel_interface",
+            "//executorch/runtime/core/portable_type/c10/c10:c10",
         ],
         exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
         visibility = [
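
Because reduce_util.h now includes thread_parallel_interface.h and c10/util/irange.h directly, those targets are listed under exported_deps rather than deps, so anything that depends on the header picks them up transitively. The shuffling of kernel_includes and tensor_util appears to be just an alphabetical re-sort of the existing deps list.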

runtime/kernel/thread_parallel_interface.h (+13, -1)

@@ -33,6 +33,10 @@ inline bool parallel_for_no_threadpool(
   return true;
 }
 
+// Match GRAIN_SIZE from PyTorch core.
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/TensorIterator.h#L78
+constexpr int64_t GRAIN_SIZE = 32768;
+
 } // namespace internal
 
 #ifdef ET_USE_THREADPOOL
@@ -74,10 +78,18 @@ inline int64_t get_thread_num() {
   return 0;
 }
 
-void set_thread_num(int64_t thread_num) {
+inline void set_thread_num(int64_t thread_num) {
   ET_DCHECK_MSG(false, "cannot set_thread_num without threading support!");
 }
 #endif // ET_USE_THREADPOOL
+
+/**
+ * Convenience version of parallel_for that sets the grain size to internal::GRAIN_SIZE.
+ */
+template <typename Func>
+bool parallel_for(const int64_t begin, const int64_t end, const Func& func) {
+  return parallel_for(begin, end, internal::GRAIN_SIZE, func);
+}
 } // namespace extension
 } // namespace executorch
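
A short usage sketch of the new convenience overload follows; only parallel_for, GRAIN_SIZE, and the header come from this change, while count_even and its inputs are illustrative. With ET_USE_THREADPOOL the range is split into chunks of roughly the grain size; without it, the callback is simply invoked sequentially over the whole range, so the code behaves the same either way.

#include <executorch/runtime/kernel/thread_parallel_interface.h>

#include <atomic>
#include <cstdint>

// Counts even values with per-chunk local sums merged through an atomic,
// so concurrent chunks never write to shared state directly.
int64_t count_even(const int64_t* values, const int64_t n) {
  std::atomic<int64_t> total{0};
  executorch::extension::parallel_for(
      0, n, [&](const int64_t begin, const int64_t end) {
        int64_t local = 0;
        for (int64_t i = begin; i < end; ++i) {
          local += (values[i] % 2 == 0) ? 1 : 0;
        }
        total.fetch_add(local, std::memory_order_relaxed);
      });
  return total.load();
}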
