@@ -53,8 +53,43 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
53
53
namespace internal {
54
54
template <
55
55
typename CTYPE_COMMON,
56
+ typename CTYPE_OUT,
56
57
typename Op,
57
- typename ... Args>
58
+ typename ... Args>
59
+ inline void dtype_specialized_elementwise_fn_impl (
60
+ const Op& compute_fun,
61
+ KernelRuntimeContext& ctx,
62
+ const Tensor& out,
63
+ Args... inputs) {
64
+ constexpr auto kNumInputs = sizeof ...(inputs);
65
+ ET_DCHECK (((inputs.first ->element_size () == sizeof (CTYPE_COMMON)) && ...));
66
+
67
+ std::array<const CTYPE_COMMON*, kNumInputs > inputs_data_ptrs = {
68
+ inputs.first ->template const_data_ptr <CTYPE_COMMON>()...};
69
+
70
+ CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
71
+
72
+ ::executorch::extension::parallel_for (
73
+ 0 ,
74
+ out.numel(),
75
+ ::executorch::extension::internal::GRAIN_SIZE,
76
+ [&](const auto begin, const auto end) {
77
+ const auto range =
78
+ BroadcastIndexesRange<kNumInputs >(out, (*inputs.first )...);
79
+ auto begin_it = range.begin ();
80
+ begin_it += begin;
81
+ for (; (*begin_it)[0 ] < end; ++begin_it) {
82
+ const auto & indexes = *begin_it;
83
+ std::array<CTYPE_COMMON, kNumInputs > loaded_inputs;
84
+ for (const auto idx : c10::irange (kNumInputs )) {
85
+ loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1 ]];
86
+ }
87
+ data_out[indexes[0 ]] = std::apply (compute_fun, loaded_inputs);
88
+ }
89
+ });
90
+ }
91
+
92
+ template <typename CTYPE_COMMON, typename Op, typename ... Args>
58
93
inline bool validate_elementwise_fn_inputs (
59
94
const Op& compute_fun,
60
95
KernelRuntimeContext& ctx,
@@ -73,7 +108,8 @@ inline bool validate_elementwise_fn_inputs(
73
108
ctx,
74
109
(check_input_dtype (inputs, compute_type) && ...) &&
75
110
internal::check_tensor_dtype (out, out_dtypes, compute_type),
76
- InvalidArgument, false );
111
+ InvalidArgument,
112
+ false );
77
113
78
114
return true ;
79
115
}
@@ -83,22 +119,12 @@ template <
83
119
const char * op_name,
84
120
typename Op,
85
121
typename ... Args>
86
- inline void apply_elementwise_fn (
122
+ inline void apply_elementwise_fn_generic_impl (
87
123
const Op& compute_fun,
88
124
KernelRuntimeContext& ctx,
89
125
const Tensor& out,
90
126
SupportedTensorDtypes out_dtypes,
91
127
Args... inputs) {
92
- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
93
- compute_fun,
94
- ctx,
95
- out,
96
- out_dtypes,
97
- inputs...);
98
- if (!inputs_valid) {
99
- return ;
100
- }
101
-
102
128
constexpr auto kNumInputs = sizeof ...(inputs);
103
129
104
130
struct InputInfo {
@@ -142,6 +168,63 @@ inline void apply_elementwise_fn(
142
168
}
143
169
});
144
170
}
171
+
172
+ template <
173
+ typename CTYPE_COMMON,
174
+ const char * op_name,
175
+ typename Op,
176
+ typename ... Args>
177
+ inline void apply_elementwise_fn_runtime_out_dtypes (
178
+ const Op& compute_fun,
179
+ KernelRuntimeContext& ctx,
180
+ const Tensor& out,
181
+ SupportedTensorDtypes out_dtypes,
182
+ Args... inputs) {
183
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
184
+ compute_fun, ctx, out, out_dtypes, inputs...);
185
+ if (!inputs_valid) {
186
+ return ;
187
+ }
188
+
189
+ apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
190
+ compute_fun, ctx, out, out_dtypes, inputs...);
191
+ }
192
+
193
+ template <
194
+ typename CTYPE_COMMON,
195
+ const char * op_name,
196
+ SupportedTensorDtypes out_dtypes,
197
+ typename Op,
198
+ typename ... Args>
199
+ inline void apply_elementwise_fn (
200
+ const Op& compute_fun,
201
+ KernelRuntimeContext& ctx,
202
+ const Tensor& out,
203
+ Args... inputs) {
204
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
205
+ compute_fun, ctx, out, out_dtypes, inputs...);
206
+ if (!inputs_valid) {
207
+ return ;
208
+ }
209
+
210
+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
211
+ const bool all_inputs_compute_dtype =
212
+ ((inputs.first ->scalar_type () == compute_type) && ...);
213
+
214
+ constexpr ScalarType out_specialized_scalar_type =
215
+ specialized_output_scalar_type<CTYPE_COMMON>(out_dtypes);
216
+ if (all_inputs_compute_dtype &&
217
+ out.scalar_type () == out_specialized_scalar_type) {
218
+ using CTYPE_OUT =
219
+ typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
220
+ dtype_specialized_elementwise_fn_impl<CTYPE_COMMON, CTYPE_OUT>(
221
+ compute_fun, ctx, out, inputs...);
222
+ return ;
223
+ }
224
+
225
+ apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
226
+ compute_fun, ctx, out, out_dtypes, inputs...);
227
+ }
145
228
} // namespace internal
146
229
147
230
// / DEPRECATED: prefer the variant with out_dtypes in the template argument.
@@ -153,19 +236,23 @@ inline void apply_unitensor_elementwise_fn(
153
236
SupportedTensorDtypes a_dtypes,
154
237
const Tensor& out,
155
238
SupportedTensorDtypes out_dtypes) {
156
- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
239
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
157
240
compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
158
241
}
159
242
160
- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
243
+ template <
244
+ typename CTYPE_COMMON,
245
+ const char * op_name,
246
+ SupportedTensorDtypes out_dtypes,
247
+ typename Op>
161
248
inline void apply_unitensor_elementwise_fn (
162
249
const Op& compute_fun,
163
250
KernelRuntimeContext& ctx,
164
251
const Tensor& a,
165
252
SupportedTensorDtypes a_dtypes,
166
253
const Tensor& out) {
167
- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
168
- compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
254
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
255
+ compute_fun, ctx, out, std::make_pair (&a, a_dtypes));
169
256
}
170
257
171
258
/* *
@@ -181,7 +268,7 @@ inline void apply_bitensor_elementwise_fn(
181
268
SupportedTensorDtypes b_dtypes,
182
269
const Tensor& out,
183
270
SupportedTensorDtypes out_dtypes) {
184
- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
271
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
185
272
compute_fun,
186
273
ctx,
187
274
out,
@@ -195,7 +282,11 @@ inline void apply_bitensor_elementwise_fn(
195
282
* perform a computation and write to the corresponding element of the output.
196
283
* Tensor broadcasting is applied wherever it is required.
197
284
*/
198
- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
285
+ template <
286
+ typename CTYPE_COMMON,
287
+ const char * op_name,
288
+ SupportedTensorDtypes out_dtypes,
289
+ typename Op>
199
290
inline void apply_bitensor_elementwise_fn (
200
291
const Op& compute_fun,
201
292
KernelRuntimeContext& ctx,
@@ -204,11 +295,10 @@ inline void apply_bitensor_elementwise_fn(
204
295
const Tensor& b,
205
296
SupportedTensorDtypes b_dtypes,
206
297
const Tensor& out) {
207
- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
298
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
208
299
compute_fun,
209
300
ctx,
210
301
out,
211
- out_dtypes,
212
302
std::make_pair (&a, a_dtypes),
213
303
std::make_pair (&b, b_dtypes));
214
304
}
@@ -228,7 +318,7 @@ inline void apply_tritensor_elementwise_fn(
228
318
SupportedTensorDtypes c_dtypes,
229
319
const Tensor& out,
230
320
SupportedTensorDtypes out_dtypes) {
231
- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
321
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
232
322
compute_fun,
233
323
ctx,
234
324
out,
@@ -258,7 +348,11 @@ inline void apply_tritensor_elementwise_fn(
258
348
* static constexpr const char op_name[] = "my_op";
259
349
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
260
350
*/
261
- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
351
+ template <
352
+ typename CTYPE_COMMON,
353
+ const char * op_name,
354
+ SupportedTensorDtypes out_dtypes,
355
+ typename Op>
262
356
inline void apply_tritensor_elementwise_fn (
263
357
const Op& compute_fun,
264
358
KernelRuntimeContext& ctx,
@@ -269,11 +363,10 @@ inline void apply_tritensor_elementwise_fn(
269
363
const Tensor& c,
270
364
SupportedTensorDtypes c_dtypes,
271
365
const Tensor& out) {
272
- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
366
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
273
367
compute_fun,
274
368
ctx,
275
369
out,
276
- out_dtypes,
277
370
std::make_pair (&a, a_dtypes),
278
371
std::make_pair (&b, b_dtypes),
279
372
std::make_pair (&c, c_dtypes));
0 commit comments