#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
+ #include <executorch/kernels/portable/cpu/util/vectorized_math.h> // Make vectorization support easy for clients.
#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/thread_parallel_interface.h>

+ #ifdef ET_USE_PYTORCH_HEADERS
+ #include <ATen/cpu/vec/vec.h>
+ #endif // ET_USE_PYTORCH_HEADERS
+
#include <array>
#include <utility>
@@ -51,6 +56,34 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
}

namespace internal {
+ template <typename Ignore, typename T>
+ using ignore_first_yield_second = T;
+
+ #ifdef ET_USE_PYTORCH_HEADERS
+ // Can I call a function of type Op with sizeof...(Args) arguments of type
+ // at::vec::Vectorized<CTYPE_COMPUTE>?
+ //
+ // See [NOTE: Generic lambdas] below for requirements on Op.
+ template <typename CTYPE_COMPUTE, typename Op, typename... Args>
+ constexpr bool can_use_vectorized() {
+   using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
+   if constexpr (std::is_invocable_v<
+                     Op,
+                     ignore_first_yield_second<Args, Vec>...>) {
+     // For bool, we will get a false positive if we rely on only the
+     // is_invocable_v check above because at::vec::Vectorized is
+     // implicitly convertible to a pointer, which makes it implicitly
+     // convertible to bool (which was 15 minutes of fun to debug). Also
+     // just seems like good hygiene to make sure we get the Vectorized
+     // we're expecting.
+     return std::is_same_v<
+         std::invoke_result_t<Op, ignore_first_yield_second<Args, Vec>...>,
+         Vec>;
+   }
+   return false;
+ }
+ #endif // ET_USE_PYTORCH_HEADERS
+
template <
    typename CTYPE_COMPUTE,
    typename CTYPE_OUT,
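As an aside (not part of the diff), here is a minimal sketch of how the new check is expected to behave, assuming ET_USE_PYTORCH_HEADERS is defined and the code sits where can_use_vectorized is visible. Only sizeof...(Args) matters for the trailing argument types, so int placeholders are used:

// Sketch: expected behavior of can_use_vectorized for two compute functions.
// The trailing Args (int, int) are placeholders; only their count matters.
auto generic_add = [](auto a, auto b) { return a + b; };    // also compiles for Vectorized<float>
auto scalar_only = [](float a, float b) { return a - b; };  // not callable with Vectorized<float>

static_assert(
    can_use_vectorized<float, decltype(generic_add), int, int>(),
    "generic lambda returning Vectorized<float> takes the vector path");
static_assert(
    !can_use_vectorized<float, decltype(scalar_only), int, int>(),
    "scalar-only lambda falls back to the non-vectorized path");
// For CTYPE_COMPUTE = bool, the extra is_same_v check above is what guards
// against the Vectorized -> pointer -> bool conversion false positive.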
@@ -61,8 +94,71 @@ inline void dtype_specialized_elementwise_fn_impl(
    KernelRuntimeContext& ctx,
    const Tensor& out,
    Args... inputs) {
+   static_assert(
+       (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+        ...));
  constexpr auto kNumInputs = sizeof...(inputs);
-   ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));
+   // All inputs must be of type CTYPE_COMPUTE.
+   ET_DCHECK(
+       ((inputs.first->scalar_type() ==
+         CppTypeToScalarType<CTYPE_COMPUTE>::value) &&
+        ...));
+
+ #ifdef ET_USE_PYTORCH_HEADERS
+   if constexpr (can_use_vectorized<CTYPE_COMPUTE, Op, Args...>()) {
+     const bool any_is_broadcasted =
+         !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+               inputs.first->sizes(), out.sizes()) &&
+           ...);
+     if (!any_is_broadcasted) {
+       using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
+       ::executorch::extension::parallel_for(
+           0,
+           out.numel(),
+           ::executorch::extension::internal::GRAIN_SIZE,
+           [&](const auto begin, const auto end) {
+             std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
+                 inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
+
+             CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
+
+             const auto vectorized_begin =
+                 begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+             const auto vectorized_end = end - (end % Vec::size());
+             // Scalar prologue.
+             for (const auto idx : c10::irange(begin, vectorized_begin)) {
+               std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+               for (const auto input_idx : c10::irange(kNumInputs)) {
+                 loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+               }
+               data_out[idx] = std::apply(compute_fun, loaded_inputs);
+             }
+
+             // Main vectorized loop.
+             for (auto idx = vectorized_begin; idx < vectorized_end;
+                  idx += Vec::size()) {
+               std::array<Vec, kNumInputs> loaded_vec_inputs;
+               for (const auto input_idx : c10::irange(kNumInputs)) {
+                 loaded_vec_inputs[input_idx] =
+                     Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+               }
+               auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+               result_vec.store(&data_out[idx]);
+             }
+
+             // Scalar epilogue.
+             for (const auto idx : c10::irange(vectorized_end, end)) {
+               std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+               for (const auto input_idx : c10::irange(kNumInputs)) {
+                 loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+               }
+               data_out[idx] = std::apply(compute_fun, loaded_inputs);
+             }
+           });
+       return;
+     }
+   }
+ #endif

  ::executorch::extension::parallel_for(
      0,
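For readers following the prologue/epilogue split above, here is a worked example of the alignment arithmetic (illustrative only; it assumes a lane width of 8, which is what at::vec::Vectorized<float> gives with AVX2):

// Not part of the diff: worked example of the chunk-splitting arithmetic.
constexpr int64_t kVecSize = 8;         // stand-in for Vec::size()
constexpr int64_t begin = 3, end = 30;  // one parallel_for chunk
// Round begin up to the next multiple of the lane width, round end down.
constexpr int64_t vectorized_begin =
    begin + (kVecSize - begin % kVecSize) % kVecSize;       // 3 + 5 = 8
constexpr int64_t vectorized_end = end - (end % kVecSize);  // 30 - 6 = 24
static_assert(vectorized_begin == 8 && vectorized_end == 24);
// The scalar prologue handles [3, 8), the vectorized loop handles [8, 24)
// in steps of 8, and the scalar epilogue handles [24, 30).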
@@ -240,6 +336,19 @@ inline void apply_unitensor_elementwise_fn(
      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
}

+ /**
+  * Useful for unary elementwise operators. For each element of the
+  * input, call Op and write to the corresponding element of the
+  * output. Tensor broadcasting is applied wherever it is required.
+  *
+  * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
+  * parameters; normal lambdas are fine), it must fulfill one of the
+  * following conditions. Either:
+  * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMPUTE>, or
+  * 2) It must be actively SFINAE-friendly, as per the C++17 examples in
+  *    https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable
+  *    .
+  */
template <
    typename CTYPE_COMPUTE,
    const char* op_name,
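To make the note above concrete, a hedged sketch of compute functions a kernel author might pass (names are hypothetical, <cmath> is assumed to be included, and the surrounding apply_* call is elided since only the lambda shape matters here):

// Condition (1): a generic lambda whose body also compiles for
// at::vec::Vectorized<CTYPE_COMPUTE> (Vectorized defines operator*).
auto square = [](const auto& x) { return x * x; };

// Condition (2): a generic lambda made SFINAE-friendly by its trailing
// return type, so the vectorization check cleanly rejects it and the
// scalar path is used instead.
auto fmod_by_two = [](const auto& x) -> decltype(std::fmod(x, 2.0f)) {
  return std::fmod(x, 2.0f);
};

// Landmine the note warns about: the same body without the trailing
// return type reports itself invocable with Vectorized, then fails to
// compile when the vector path is instantiated.
// auto bad = [](const auto& x) { return std::fmod(x, 2.0f); };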
@@ -281,6 +390,8 @@ inline void apply_bitensor_elementwise_fn(
 * Useful for bi-tensor elementwise operators. For each element of the inputs,
 * perform a computation and write to the corresponding element of the output.
 * Tensor broadcasting is applied wherever it is required.
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
 */
template <
    typename CTYPE_COMPUTE,
@@ -347,6 +458,9 @@ inline void apply_tritensor_elementwise_fn(
 *
 * static constexpr const char op_name[] = "my_op";
 * apply_ternary_elementwise_fn<CTYPE_COMPUTE, op_name>.
+ *
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
 */
template <
    typename CTYPE_COMPUTE,
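The same guidance applies to the binary and ternary entry points. A minimal sketch of a compute_fun shape that works on both the scalar and the Vectorized path (hypothetical name, shown only for the lambda shape):

// Compiles for CTYPE_COMPUTE and for at::vec::Vectorized<CTYPE_COMPUTE>,
// since Vectorized defines the usual arithmetic operators.
auto mul_fn = [](const auto& a, const auto& b) { return a * b; };

// Written against a fixed scalar type it still works, but only ever
// takes the scalar path.
auto mul_fn_scalar = [](float a, float b) { return a * b; };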