
Commit 45fbf9a

Add vectorization in elementwise_util (not working yet)
This works with op_mul, which is vectorization-friendly, but doesn't work when we roll it out to pattern.h, because those ops will not work with Vectorized yet. See the TODO in elementwise_util.h.

ghstack-source-id: e4dd38e50bdd667fd3287dc90f7040cccae83993
ghstack-comment-id: 2738665976
Pull Request resolved: #9432
1 parent 4427eef commit 45fbf9a
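To illustrate the "vectorization-friendly" point in the commit message: a generic lambda whose body only uses operations that at::vec::Vectorized defines (such as operator*) can be instantiated for both scalars and vectors, while an op whose body calls a scalar-only function cannot. A minimal sketch, assuming an ATen checkout on the include path (std::atan2 is just an illustrative scalar-only callee here, not necessarily one of the pattern.h ops):

#include <ATen/cpu/vec/vec.h>
#include <cmath>

int main() {
  using Vec = at::vec::Vectorized<float>;

  // Vectorization-friendly: operator* is defined for both float and Vec.
  auto mul_fn = [](const auto a, const auto b) { return a * b; };
  float s = mul_fn(2.0f, 3.0f);         // scalar instantiation
  Vec v = mul_fn(Vec(2.0f), Vec(3.0f)); // vector instantiation

  // Not (yet) vectorization-friendly: std::atan2 only accepts scalars,
  // so this lambda hard-errors if instantiated with Vec arguments.
  auto atan2_fn = [](const auto a, const auto b) { return std::atan2(a, b); };
  float t = atan2_fn(1.0f, 2.0f);   // fine for scalars
  // atan2_fn(Vec(1.0f), Vec(2.0f)); // would fail to compile

  (void)s; (void)v; (void)t;
  return 0;
}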

File tree

5 files changed: +105 -5

.lintrunner.toml
kernels/portable/cpu/op_mul.cpp
kernels/portable/cpu/util/elementwise_util.h
kernels/portable/cpu/util/targets.bzl
runtime/core/portable_type/c10/c10/targets.bzl


.lintrunner.toml

+2

@@ -264,6 +264,8 @@ exclude_patterns = [
     'examples/**',
     'exir/verification/bindings.cpp',
     'extension/**',
+    # Uses properly-gated (ET_USE_PYTORCH_HEADERS) ATen include.
+    'kernels/portable/cpu/util/elementwise_util.h',
     'kernels/optimized/**',
     'runtime/core/exec_aten/**',
     # Want to be able to keep c10 in sync with PyTorch core.

kernels/portable/cpu/op_mul.cpp

+1 -3

@@ -56,9 +56,7 @@ Tensor& mul_out(
       CTYPE_COMPUTE,
       op_name,
       utils::SupportedTensorDtypes::REALHBBF16>(
-      [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-        return val_a * val_b;
-      },
+      [](const auto val_a, const auto val_b) { return val_a * val_b; },
       ctx,
       a,
       utils::SupportedTensorDtypes::REALHBBF16,
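This op_mul change is the enabling step: the old lambda pinned its parameters to CTYPE_COMPUTE, so it could only ever be called with scalars; the generic (auto-parameter) version additionally instantiates for at::vec::Vectorized<CTYPE_COMPUTE>, which is what lets the new fast path in elementwise_util.h detect and use it. A minimal sketch of the difference (illustrative, not repo code):

#include <ATen/cpu/vec/vec.h>
#include <type_traits>

int main() {
  using Vec = at::vec::Vectorized<float>;

  auto typed_mul = [](const float a, const float b) { return a * b; };
  auto generic_mul = [](const auto a, const auto b) { return a * b; };

  // Both work with scalars...
  static_assert(std::is_invocable_v<decltype(typed_mul), float, float>);
  static_assert(std::is_invocable_v<decltype(generic_mul), float, float>);
  // ...but only the generic lambda can also take Vectorized arguments,
  // which is what the vectorized fast path requires.
  static_assert(std::is_invocable_v<decltype(generic_mul), Vec, Vec>);
  return 0;
}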

kernels/portable/cpu/util/elementwise_util.h

+97 -1

@@ -15,6 +15,10 @@
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 #include <executorch/runtime/kernel/thread_parallel_interface.h>
 
+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
 #include <array>
 #include <utility>
 

@@ -51,6 +55,22 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
 }
 
 namespace internal {
+template <typename Ignore, typename T>
+using ignore_first_yield_second = T;
+
+#ifdef ET_USE_PYTORCH_HEADERS
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMMON>?
+//
+// See [NOTE: Generic lambdas] below for requirements on Op.
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+constexpr bool can_use_vectorized() {
+  return std::is_invocable_v<
+      Op,
+      ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>;
+}
+#endif // ET_USE_PYTORCH_HEADERS
+
 template <
     typename CTYPE_COMMON,
     typename CTYPE_OUT,
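The detection above leans on a small pack-expansion trick: ignore_first_yield_second<Args, Vec> maps each type in Args to Vec, so the is_invocable_v query asks whether Op accepts sizeof...(Args) Vectorized arguments. A self-contained sketch with a stand-in vector type (FakeVec and can_call_with_all are illustrative names, not from the repo):

#include <type_traits>

template <typename Ignore, typename T>
using ignore_first_yield_second = T;

// "Can Op be called with sizeof...(Args) arguments of type V?"
template <typename V, typename Op, typename... Args>
constexpr bool can_call_with_all() {
  return std::is_invocable_v<Op, ignore_first_yield_second<Args, V>...>;
}

struct FakeVec {}; // stand-in for at::vec::Vectorized<float>
FakeVec operator*(FakeVec, FakeVec) { return {}; }

int main() {
  auto generic_mul = [](const auto a, const auto b) { return a * b; };
  auto typed_mul = [](const float a, const float b) { return a * b; };

  // Args plays the role of the tensor-argument pack; only its arity
  // matters to the check, so int placeholders suffice here.
  static_assert(can_call_with_all<FakeVec, decltype(generic_mul), int, int>());
  static_assert(!can_call_with_all<FakeVec, decltype(typed_mul), int, int>());
  return 0;
}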
@@ -61,14 +81,72 @@ inline void dtype_specialized_elementwise_fn_impl(
     KernelRuntimeContext& ctx,
     const Tensor& out,
     Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
   constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
+  // All inputs must be of type CTYPE_COMMON.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMMON>::value) &&
+       ...));
 
   std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
       inputs.first->template const_data_ptr<CTYPE_COMMON>()...};
 
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
+#ifdef ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>()) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMMON>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+          });
+      return;
+    }
+  }
+#endif
+
   ::executorch::extension::parallel_for(
       0,
       out.numel(),
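The loop structure in this hunk splits each parallel_for chunk [begin, end) into a scalar prologue up to the next multiple of Vec::size(), a main loop over whole vectors, and a scalar epilogue for the ragged tail. A minimal single-input sketch of the same pattern (illustrative names; assumes the chunk is large enough that rounding begin up and end down don't cross, as is the case for grain-sized chunks):

#include <ATen/cpu/vec/vec.h>
#include <cstdint>

// Doubles in[begin, end) into out[begin, end) using the same
// prologue / vector loop / epilogue splitting as the diff above.
void double_range(const float* in, float* out, int64_t begin, int64_t end) {
  using Vec = at::vec::Vectorized<float>;
  // Round begin up and end down to multiples of the vector width.
  const int64_t vectorized_begin =
      begin + (Vec::size() - begin % Vec::size()) % Vec::size();
  const int64_t vectorized_end = end - (end % Vec::size());
  // Scalar prologue: unaligned head before the first full vector.
  for (int64_t idx = begin; idx < vectorized_begin; ++idx) {
    out[idx] = in[idx] * 2.0f;
  }
  // Main loop: one full Vec per iteration.
  for (int64_t idx = vectorized_begin; idx < vectorized_end;
       idx += Vec::size()) {
    (Vec::loadu(&in[idx]) * Vec(2.0f)).store(&out[idx]);
  }
  // Scalar epilogue: tail shorter than one vector.
  for (int64_t idx = vectorized_end; idx < end; ++idx) {
    out[idx] = in[idx] * 2.0f;
  }
}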
@@ -240,6 +318,19 @@ inline void apply_unitensor_elementwise_fn(
       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
 }
 
+/**
+ * Useful for unary elementwise operators. For each element of the
+ * input, call Op and write to the corresponding element of the
+ * output. Tensor broadcasting is applied wherever it is required.
+ *
+ * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
+ * parameters; normal lambdas are fine), it must fulfill one of the
+ * following conditions. Either:
+ * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMMON>, or
+ * 2) It must be actively SFINAE-friendly, as per the C++17 examples in
+ *    https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable
+ *    .
+ */
 template <
     typename CTYPE_COMMON,
     const char* op_name,
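For option (2) in the note above, the standard trick is to repeat the lambda body in a trailing decltype return type, so that invocability queries (like the can_use_vectorized check) get a clean "no" instead of a hard error when the body wouldn't compile. A hedged sketch (illustrative; not code from this commit):

#include <cmath>
#include <type_traits>

int main() {
  // The trailing return type makes substitution failure happen in the
  // immediate context, so std::is_invocable_v answers false rather than
  // failing to compile when std::atan2(a, b) is ill-formed.
  auto sfinae_atan2 = [](const auto a, const auto b)
      -> decltype(std::atan2(a, b)) { return std::atan2(a, b); };

  struct NotANumber {};
  static_assert(std::is_invocable_v<decltype(sfinae_atan2), float, float>);
  static_assert(
      !std::is_invocable_v<decltype(sfinae_atan2), NotANumber, NotANumber>);
  return 0;
}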
@@ -281,6 +372,8 @@ inline void apply_bitensor_elementwise_fn(
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
  */
 template <
     typename CTYPE_COMMON,
@@ -347,6 +440,9 @@ inline void apply_tritensor_elementwise_fn(
  *
  * static constexpr const char op_name[] = "my_op";
  * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
+ *
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
  */
 template <
     typename CTYPE_COMMON,

kernels/portable/cpu/util/targets.bzl

+1

@@ -110,6 +110,7 @@ def define_common_targets():
         ":broadcast_indexes_range",
         ":broadcast_util",
         ":dtype_util",
+        "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
         "//executorch/runtime/kernel:kernel_runtime_context",
         "//executorch/runtime/kernel:thread_parallel_interface",
     ],

runtime/core/portable_type/c10/c10/targets.bzl

+4 -1

@@ -49,7 +49,10 @@ def define_common_targets():
     runtime.cxx_library(
         name = "aten_headers_for_executorch",
         srcs = [],
-        visibility = ["//executorch/kernels/optimized/..."],
+        visibility = [
+            "//executorch/kernels/optimized/...",
+            "//executorch/kernels/portable/cpu/util/...",
+        ],
         exported_deps = select({
             "DEFAULT": [],
             "ovr_config//cpu:arm64": [
