Commit 4a30c0e

Add vectorization in elementwise_util (not working yet)
This works with op_mul, which is vectorization-friendly, but doesn't work when we roll it out to pattern.h, because those ops will not work with Vectorized yet. See TODO in elementwise_util.h.

ghstack-source-id: 033b63ce3bee8a0136efdab3e03905cafb79b915
ghstack-comment-id: 2738665976
Pull Request resolved: #9432
1 parent: 15cfb59

15 files changed: +282 -21
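
For context, a minimal sketch (not part of the diff, assuming ATen's at::vec API) of what "vectorization-friendly" means here: a generic lambda qualifies when it also compiles with at::vec::Vectorized<T> arguments and returns a Vectorized<T>.

    #include <cmath>
    #include <ATen/cpu/vec/vec.h>

    // Vectorization-friendly, as in op_mul: `auto` can bind to float or to
    // at::vec::Vectorized<float>, since operator* is defined for both.
    auto mul = [](const auto val_a, const auto val_b) { return val_a * val_b; };

    // Not vectorization-friendly: std::atan2 has no Vectorized overload, so
    // this lambda only instantiates for scalar types. Ops like this route
    // through executorch::math wrappers instead (see op_atan2.cpp below).
    auto atan2_fn = [](const auto val_a, const auto val_b) {
      return std::atan2(val_a, val_b);
    };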

.lintrunner.toml

+4 -0

@@ -264,6 +264,10 @@ exclude_patterns = [
     'examples/**',
     'exir/verification/bindings.cpp',
     'extension/**',
+    # Uses properly-gated (ET_USE_PYTORCH_HEADERS) ATen include.
+    'kernels/portable/cpu/util/elementwise_util.h',
+    'kernels/portable/cpu/util/math_util.h',
+    'kernels/portable/cpu/util/vectorized_math.h',
     'kernels/optimized/**',
     'runtime/core/exec_aten/**',
     # Want to be able to keep c10 in sync with PyTorch core.

kernels/portable/cpu/op_atan2.cpp

+1 -1

@@ -60,7 +60,7 @@ Tensor& atan2_out(
         op_name,
         utils::SupportedTensorDtypes::FLOATHBF16>(
         [](const auto val_a, const auto val_b) {
-          return std::atan2(val_a, val_b);
+          return executorch::math::atan2(val_a, val_b);
         },
         ctx,
         a,
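
The executorch::math wrappers come from the new vectorized_math.h, whose diff is not shown on this page. A plausible shape for such a wrapper, as a sketch only (assuming Vectorized's atan2 member; the real header may differ):

    #include <cmath>
    #ifdef ET_USE_PYTORCH_HEADERS
    #include <ATen/cpu/vec/vec.h>
    #endif

    namespace executorch {
    namespace math {
    // Scalar fallback: defers to the standard library.
    template <typename T>
    T atan2(T a, T b) {
      return std::atan2(a, b);
    }

    #ifdef ET_USE_PYTORCH_HEADERS
    // SIMD overload: defers to Vectorized's elementwise atan2.
    template <typename T>
    at::vec::Vectorized<T> atan2(
        at::vec::Vectorized<T> a,
        at::vec::Vectorized<T> b) {
      return a.atan2(b);
    }
    #endif
    } // namespace math
    } // namespace executorch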

kernels/portable/cpu/op_elu.cpp

+1 -2

@@ -48,8 +48,7 @@ Tensor& elu_out(
         CTYPE,
         op_name,
         utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [negcoef, math_scale, math_input_scale](const auto x) {
-          // TODO: rewrite this to be vectorization-capable.
+        [negcoef, math_scale, math_input_scale](const CTYPE x) {
           return MathT(x) <= MathT(0)
               ? std::expm1(MathT(x) * math_input_scale) * negcoef
               : MathT(x) * math_scale;

kernels/portable/cpu/op_fmod.cpp

+3 -5

@@ -61,7 +61,7 @@ Tensor& fmod_Tensor_out(
         utils::SupportedTensorDtypes::REALHBF16>(
         [&div_by_zero_error](
             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          // TODO: rewrite this to be vectorization-capable.
+          // TODO: rewrite this to be vectorization-capable?
           CTYPE_COMPUTE value = 0;
           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
             if (val_b == 0) {
@@ -138,10 +138,8 @@ Tensor& fmod_Scalar_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBF16>(
-        [val_b](const CTYPE_COMPUTE val_a) {
-          // TODO: rewrite this to be vectorization-capable.
-          CTYPE_COMPUTE value = std::fmod(val_a, val_b);
-          return value;
+        [val_b](const auto val_a) {
+          return executorch::math::fmod(val_a, (decltype(val_a))val_b);
         },
         ctx,
         a,
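
Note the (decltype(val_a))val_b cast in fmod_Scalar_out: in the scalar instantiation, decltype(val_a) is CTYPE_COMPUTE and the cast is a no-op; in the vectorized instantiation it invokes Vectorized's broadcasting constructor. A sketch, assuming the at::vec API:

    #include <ATen/cpu/vec/vec.h>

    void example() {
      using Vec = at::vec::Vectorized<float>;
      float val_b = 2.0f;
      Vec val_a(3.0f); // stand-in for a lane-load from the input tensor
      // decltype(val_a) is Vec here, so the cast splats val_b across all
      // lanes via Vectorized's broadcasting constructor.
      Vec b_vec = (decltype(val_a))val_b;
      (void)b_vec;
    }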

kernels/portable/cpu/op_maximum.cpp

+1 -1

@@ -49,7 +49,7 @@ Tensor& maximum_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBBF16>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+        [](const auto val_a, const auto val_b) {
           return utils::max_override(val_a, val_b);
         },
         ctx,

kernels/portable/cpu/op_minimum.cpp

+1 -2

@@ -49,8 +49,7 @@ Tensor& minimum_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBBF16>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          // TODO: rewrite this to be vectorization-capable.
+        [](const auto val_a, const auto val_b) {
           return utils::min_override(val_a, val_b);
         },
         ctx,

kernels/portable/cpu/op_mul.cpp

+1 -3

@@ -56,9 +56,7 @@ Tensor& mul_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBBF16>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          return val_a * val_b;
-        },
+        [](const auto val_a, const auto val_b) { return val_a * val_b; },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,

kernels/portable/cpu/op_pow.cpp

+2 -2

@@ -57,9 +57,9 @@ Tensor& pow_Tensor_Tensor_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBF16>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+        [](const auto val_a, const auto val_b) {
           // TODO: rewrite this to be vectorization-capable.
-          return std::pow(val_a, val_b);
+          return executorch::math::pow(val_a, val_b);
         },
         ctx,
         a,

kernels/portable/cpu/op_sigmoid.cpp

+3 -2

@@ -49,8 +49,9 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::FLOATHBF16>(
-        [](const auto val_in) -> CTYPE_COMPUTE {
-          // TODO: rewrite this to be vectorization-capable
+        [](const CTYPE_COMPUTE val_in) {
+          // TODO: rewrite this to be vectorization-capable; need
+          // unary - overload for Vectorized.
           CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) /
               (static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in));
           return out_val;

kernels/portable/cpu/op_where.cpp

+1 -1

@@ -47,7 +47,7 @@ Tensor& where_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [](const auto val_a, const auto val_b, const auto val_c) {
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b, const CTYPE_COMPUTE val_c) {
           return val_c ? val_a : val_b;
         },
         ctx,
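
op_where moves in the opposite direction (auto back to CTYPE_COMPUTE): the ternary val_c ? val_a : val_b selects a whole value, not individual lanes, so it cannot apply to Vectorized arguments — and a Vectorized condition would only "work" through the bogus bool conversion described in elementwise_util.h below. If where is vectorized later, a lane-wise select would be needed; a sketch, assuming at::vec::Vectorized::blendv with an all-ones/all-zeros lane mask:

    #include <ATen/cpu/vec/vec.h>

    using Vec = at::vec::Vectorized<float>;
    // Returns val_a where the corresponding mask lane is set, else val_b.
    Vec where_vec(Vec mask, Vec val_a, Vec val_b) {
      return Vec::blendv(/*a=*/val_b, /*b=*/val_a, /*mask=*/mask);
    }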

kernels/portable/cpu/util/elementwise_util.h

+115 -1

@@ -12,9 +12,14 @@
 #include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
+#include <executorch/kernels/portable/cpu/util/vectorized_math.h> // Make vectorization support easy for clients.
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 #include <executorch/runtime/kernel/thread_parallel_interface.h>
 
+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
 #include <array>
 #include <utility>
 
@@ -51,6 +56,34 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
 }
 
 namespace internal {
+template <typename Ignore, typename T>
+using ignore_first_yield_second = T;
+
+#ifdef ET_USE_PYTORCH_HEADERS
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMPUTE>?
+//
+// See [NOTE: Generic lambdas] below for requirements on Op.
+template <typename CTYPE_COMPUTE, typename Op, typename... Args>
+constexpr bool can_use_vectorized() {
+  using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
+  if constexpr (std::is_invocable_v<
+                    Op,
+                    ignore_first_yield_second<Args, Vec>...>) {
+    // For bool, we will get a false positive if we rely on only the
+    // is_invocable_v check above because at::vec::Vectorized is
+    // implicitly convertible to a pointer, which makes it implicitly
+    // convertible to bool (which was 15 minutes of fun to debug). Also
+    // it just seems like good hygiene to make sure we get the Vectorized
+    // we're expecting.
+    return std::is_same_v<
+        std::invoke_result_t<Op, ignore_first_yield_second<Args, Vec>...>,
+        Vec>;
+  }
+  return false;
+}
+#endif // ET_USE_PYTORCH_HEADERS
+
 template <
     typename CTYPE_COMPUTE,
     typename CTYPE_OUT,
@@ -61,8 +94,71 @@ inline void dtype_specialized_elementwise_fn_impl(
     KernelRuntimeContext& ctx,
     const Tensor& out,
     Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
   constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));
+  // All inputs must be of type CTYPE_COMPUTE.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMPUTE>::value) &&
+       ...));
+
+#ifdef ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMPUTE, Op, Args...>()) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
+                inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
+
+            CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
+
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+              std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+          });
+      return;
+    }
+  }
+#endif
 
   ::executorch::extension::parallel_for(
       0,
@@ -240,6 +336,19 @@ inline void apply_unitensor_elementwise_fn(
       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
 }
 
+/**
+ * Useful for unary elementwise operators. For each element of the
+ * input, call Op and write to the corresponding element of the
+ * output. Tensor broadcasting is applied wherever it is required.
+ *
+ * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
+ * parameters; normal lambdas are fine), it must fulfill one of the
+ * following conditions. Either:
+ * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMPUTE>, or
+ * 2) It must be actively SFINAE-friendly, as per the C++17 examples in
+ * https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable
+ * .
+ */
 template <
     typename CTYPE_COMPUTE,
     const char* op_name,
@@ -281,6 +390,8 @@ inline void apply_bitensor_elementwise_fn(
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
  */
 template <
     typename CTYPE_COMPUTE,
@@ -347,6 +458,9 @@ inline void apply_tritensor_elementwise_fn(
  *
  * static constexpr const char op_name[] = "my_op";
  * apply_ternary_elementwise_fn<CTYPE_COMPUTE, op_name>.
+ *
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
 */
 template <
     typename CTYPE_COMPUTE,
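
A sketch of option (2) from [NOTE: Generic lambdas]: an actively SFINAE-friendly generic lambda uses a trailing return type so that can_use_vectorized() cleanly evaluates to false for Vectorized arguments instead of failing to compile (names here are illustrative, not from the diff):

    #include <cmath>

    // The trailing decltype removes this lambda from consideration for
    // argument types std::atan2 cannot take, so is_invocable_v is false
    // for at::vec::Vectorized and the scalar fallback path is chosen.
    auto scalar_atan2 = [](const auto val_a, const auto val_b)
        -> decltype(std::atan2(val_a, val_b)) {
      return std::atan2(val_a, val_b);
    };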

kernels/portable/cpu/util/math_util.h

+19 -0

@@ -8,6 +8,10 @@
 
 #pragma once
 
+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -138,6 +142,21 @@ T max_override(T a, T b) {
   return b;
 }
 
+#ifdef ET_USE_PYTORCH_HEADERS
+template <typename T>
+at::vec::Vectorized<T> min_override(
+    at::vec::Vectorized<T> a,
+    at::vec::Vectorized<T> b) {
+  return at::vec::minimum(a, b);
+}
+
+template <typename T>
+at::vec::Vectorized<T> max_override(
+    at::vec::Vectorized<T> a,
+    at::vec::Vectorized<T> b) {
+  return at::vec::maximum(a, b);
+}
+#endif
 /**
  * There is a slight difference in how std::fmod works compared to how ATen
  * determines remainders:
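
With these overloads in place, the same utils::min_override / utils::max_override call in op_minimum and op_maximum resolves to the scalar template in the prologue/epilogue loops and to the Vectorized overload in the main vectorized loop. A sketch, assuming ET_USE_PYTORCH_HEADERS is defined and with namespaces abbreviated to the utils:: spelling the ops use:

    #include <ATen/cpu/vec/vec.h>

    void example() {
      float a = 1.0f, b = 2.0f;
      auto scalar_min = utils::min_override(a, b); // scalar template

      using Vec = at::vec::Vectorized<float>;
      Vec va(1.0f), vb(2.0f);
      auto vec_min = utils::min_override(va, vb); // new Vectorized overload
      (void)scalar_min;
      (void)vec_min;
    }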

kernels/portable/cpu/util/targets.bzl

+15 -0

@@ -32,6 +32,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/util:slice_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:upsample_util",
+            "//executorch/kernels/portable/cpu/util:vectorized_math",
             "//executorch/runtime/kernel:thread_parallel_interface",
         ],
         visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
@@ -110,6 +111,8 @@ def define_common_targets():
             ":broadcast_indexes_range",
             ":broadcast_util",
             ":dtype_util",
+            ":vectorized_math",
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
             "//executorch/runtime/kernel:kernel_runtime_context",
             "//executorch/runtime/kernel:thread_parallel_interface",
         ],
@@ -260,6 +263,9 @@ def define_common_targets():
         srcs = [],
         exported_headers = ["math_util.h"],
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/quantized/..."],
+        exported_deps = [
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
+        ],
     )
@@ -307,6 +313,15 @@ def define_common_targets():
         ],
     )
 
+    runtime.cxx_library(
+        name = "vectorized_math",
+        exported_headers = ["vectorized_math.h"],
+        visibility = ["//executorch/..."],
+        exported_deps = [
+            "//executorch/runtime/core/exec_aten/util:scalar_type_util",
+        ],
+    )
+
     # Utility functions that can be used by operators that perform reduction
     for aten_mode in get_aten_mode_options():
         suffix = "_aten" if aten_mode else ""
