#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/thread_parallel_interface.h>

+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
#include <array>
#include <utility>

@@ -51,6 +55,22 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
}

namespace internal {
+template <typename Ignore, typename T>
+using ignore_first_yield_second = T;
+
+#ifdef ET_USE_PYTORCH_HEADERS
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMMON>?
+//
+// See [NOTE: Generic lambdas] below for requirements on Op.
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+constexpr bool can_use_vectorized() {
+  return std::is_invocable_v<
+      Op,
+      ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>;
+}
+#endif // ET_USE_PYTORCH_HEADERS
+
template <
    typename CTYPE_COMMON,
    typename CTYPE_OUT,
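
Aside (not part of the diff): a minimal standalone sketch of the detection idea behind the can_use_vectorized() helper added above. Vec is a stand-in for at::vec::Vectorized<CTYPE_COMMON>, and the int, int argument pack stands in for the tensor/dtype pairs; only the invocability check matters here.

// Standalone sketch of the detection idea in can_use_vectorized(), with a
// stand-in Vec type so it builds without ATen headers (hypothetical example,
// not part of the patch).
#include <type_traits>

struct Vec {
  Vec operator+(Vec) const { return {}; }
};

template <typename Ignore, typename T>
using ignore_first_yield_second = T;

// True iff Op can be called with one Vec argument per element of Args...;
// the Args types themselves are only used for their count.
template <typename Op, typename... Args>
constexpr bool can_use_vectorized() {
  return std::is_invocable_v<Op, ignore_first_yield_second<Args, Vec>...>;
}

int main() {
  // A generic lambda that only uses operator+ compiles for both float and Vec.
  auto generic_add = [](auto a, auto b) { return a + b; };
  static_assert(can_use_vectorized<decltype(generic_add), int, int>());

  // A lambda with explicit float parameters is not invocable with Vec, so the
  // vectorized path would be skipped in favor of the scalar one.
  auto float_add = [](float a, float b) { return a + b; };
  static_assert(!can_use_vectorized<decltype(float_add), int, int>());
  return 0;
}
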
@@ -61,14 +81,72 @@ inline void dtype_specialized_elementwise_fn_impl(
    KernelRuntimeContext& ctx,
    const Tensor& out,
    Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
  constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
+  // All inputs must be of type CTYPE_COMMON.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMMON>::value) &&
+       ...));

  std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
      inputs.first->template const_data_ptr<CTYPE_COMMON>()...};

  CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

+#ifdef ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>()) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMMON>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+          });
+      return;
+    }
+  }
+#endif
+
  ::executorch::extension::parallel_for(
      0,
      out.numel(),
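
Aside (not part of the diff): the prologue/vector/epilogue split in the vectorized branch above can be checked with concrete numbers. The values below, a chunk [begin, end) = [3, 30) and an 8-lane vector width, are made up for illustration.

// Worked example of the index math used inside the parallel_for lambda above,
// for a hypothetical chunk [3, 30) and a hypothetical vector width of 8.
#include <cstdio>

int main() {
  constexpr int kVecSize = 8;  // stand-in for Vec::size()
  constexpr int begin = 3;
  constexpr int end = 30;

  // First index aligned up to a multiple of the vector width.
  constexpr int vectorized_begin =
      begin + (kVecSize - begin % kVecSize) % kVecSize;   // 3 + 5 = 8
  // `end` rounded down to a multiple of the vector width.
  constexpr int vectorized_end = end - (end % kVecSize);  // 30 - 6 = 24

  std::printf("scalar prologue: [%d, %d)\n", begin, vectorized_begin);  // [3, 8)
  std::printf("vector loop:     [%d, %d) step %d\n",
              vectorized_begin, vectorized_end, kVecSize);              // [8, 24)
  std::printf("scalar epilogue: [%d, %d)\n", vectorized_end, end);      // [24, 30)
  return 0;
}
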
@@ -240,6 +318,19 @@ inline void apply_unitensor_elementwise_fn(
      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
}

+/**
+ * Useful for unary elementwise operators. For each element of the
+ * input, call Op and write to the corresponding element of the
+ * output. Tensor broadcasting is applied wherever it is required.
+ *
+ * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
+ * parameters; normal lambdas are fine), it must fulfill one of the
+ * following conditions. Either:
+ * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMMON>, or
+ * 2) It must be actively SFINAE-friendly, as per the C++17 examples in
+ *    https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable .
+ */
template <
    typename CTYPE_COMMON,
    const char* op_name,
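
Aside (not part of the diff): a standalone sketch of the two conditions in [NOTE: Generic lambdas] above. MiniVec stands in for at::vec::Vectorized<CTYPE_COMMON>; its operators and the std::exp example are made up to keep the snippet self-contained.

// Condition 1: the lambda body simply compiles for the vector type.
// Condition 2: the lambda is SFINAE-friendly, so invocability detection
// reports false instead of hard-erroring, and the scalar path is used.
#include <cmath>
#include <type_traits>

struct MiniVec {
  MiniVec operator*(MiniVec) const { return {}; }
  MiniVec operator+(MiniVec) const { return {}; }
};

int main() {
  // Condition 1: only uses operations both float and MiniVec provide.
  auto fma_like = [](auto x, auto y) { return x * y + y; };
  static_assert(std::is_invocable_v<decltype(fma_like), float, float>);
  static_assert(std::is_invocable_v<decltype(fma_like), MiniVec, MiniVec>);

  // Condition 2: std::exp has no MiniVec overload, but the trailing return
  // type moves that failure into the immediate context, so the trait yields
  // false rather than a compile error.
  auto exp_like = [](auto x) -> decltype(std::exp(x)) { return std::exp(x); };
  static_assert(std::is_invocable_v<decltype(exp_like), float>);
  static_assert(!std::is_invocable_v<decltype(exp_like), MiniVec>);
  return 0;
}
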
@@ -281,6 +372,8 @@ inline void apply_bitensor_elementwise_fn(
 * Useful for bi-tensor elementwise operators. For each element of the inputs,
 * perform a computation and write to the corresponding element of the output.
 * Tensor broadcasting is applied wherever it is required.
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
 */
template <
    typename CTYPE_COMMON,
@@ -347,6 +440,9 @@ inline void apply_tritensor_elementwise_fn(
 *
 * static constexpr const char op_name[] = "my_op";
 * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
+ *
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
 */
template <
    typename CTYPE_COMMON,