@@ -60,8 +60,43 @@ using op_call_result =
60
60
61
61
template <
62
62
typename CTYPE_COMMON,
63
+ typename CTYPE_OUT,
63
64
typename Op,
64
- typename ... Args>
65
+ typename ... Args>
66
+ inline void dtype_specialized_elementwise_fn_impl (
67
+ const Op& compute_fun,
68
+ KernelRuntimeContext& ctx,
69
+ const Tensor& out,
70
+ Args... inputs) {
71
+ constexpr auto kNumInputs = sizeof ...(inputs);
72
+ ET_DCHECK (((inputs.first ->element_size () == sizeof (CTYPE_COMMON)) && ...));
73
+
74
+ std::array<const CTYPE_COMMON*, kNumInputs > inputs_data_ptrs = {
75
+ inputs.first ->template const_data_ptr <CTYPE_COMMON>()...};
76
+
77
+ CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
78
+
79
+ ::executorch::extension::parallel_for (
80
+ 0 ,
81
+ out.numel(),
82
+ ::executorch::extension::internal::GRAIN_SIZE,
83
+ [&](const auto begin, const auto end) {
84
+ const auto range =
85
+ BroadcastIndexesRange<kNumInputs >(out, (*inputs.first )...);
86
+ auto begin_it = range.begin ();
87
+ begin_it += begin;
88
+ for (; (*begin_it)[0 ] < end; ++begin_it) {
89
+ const auto & indexes = *begin_it;
90
+ std::array<CTYPE_COMMON, kNumInputs > loaded_inputs;
91
+ for (const auto idx : c10::irange (kNumInputs )) {
92
+ loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1 ]];
93
+ }
94
+ data_out[indexes[0 ]] = std::apply (compute_fun, loaded_inputs);
95
+ }
96
+ });
97
+ }
98
+
99
+ template <typename CTYPE_COMMON, typename Op, typename ... Args>
65
100
inline bool validate_elementwise_fn_inputs (
66
101
const Op& compute_fun,
67
102
KernelRuntimeContext& ctx,
@@ -80,7 +115,8 @@ inline bool validate_elementwise_fn_inputs(
80
115
ctx,
81
116
(check_input_dtype (inputs, compute_type) && ...) &&
82
117
internal::check_tensor_dtype (out, out_dtypes, compute_type),
83
- InvalidArgument, false );
118
+ InvalidArgument,
119
+ false );
84
120
85
121
return true ;
86
122
}
@@ -90,22 +126,12 @@ template <
90
126
const char * op_name,
91
127
typename Op,
92
128
typename ... Args>
93
- inline void apply_elementwise_fn (
129
+ inline void apply_elementwise_fn_generic_impl (
94
130
const Op& compute_fun,
95
131
KernelRuntimeContext& ctx,
96
132
const Tensor& out,
97
133
SupportedTensorDtypes out_dtypes,
98
134
Args... inputs) {
99
- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
100
- compute_fun,
101
- ctx,
102
- out,
103
- out_dtypes,
104
- inputs...);
105
- if (!inputs_valid) {
106
- return ;
107
- }
108
-
109
135
constexpr auto kNumInputs = sizeof ...(inputs);
110
136
111
137
struct InputInfo {
@@ -157,6 +183,65 @@ inline void apply_elementwise_fn(
157
183
}
158
184
});
159
185
}
186
+
187
+ template <
188
+ typename CTYPE_COMMON,
189
+ const char * op_name,
190
+ typename Op,
191
+ typename ... Args>
192
+ inline void apply_elementwise_fn_runtime_out_dtypes (
193
+ const Op& compute_fun,
194
+ KernelRuntimeContext& ctx,
195
+ const Tensor& out,
196
+ SupportedTensorDtypes out_dtypes,
197
+ Args... inputs) {
198
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
199
+ compute_fun, ctx, out, out_dtypes, inputs...);
200
+ if (!inputs_valid) {
201
+ return ;
202
+ }
203
+
204
+ apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
205
+ compute_fun, ctx, out, out_dtypes, inputs...);
206
+ }
207
+
208
+ template <
209
+ typename CTYPE_COMMON,
210
+ const char * op_name,
211
+ SupportedTensorDtypes out_dtypes,
212
+ typename Op,
213
+ typename ... Args>
214
+ inline void apply_elementwise_fn (
215
+ const Op& compute_fun,
216
+ KernelRuntimeContext& ctx,
217
+ const Tensor& out,
218
+ Args... inputs) {
219
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
220
+ compute_fun, ctx, out, out_dtypes, inputs...);
221
+ if (!inputs_valid) {
222
+ return ;
223
+ }
224
+
225
+ constexpr auto kNumInputs = sizeof ...(inputs);
226
+
227
+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
228
+ const bool all_inputs_compute_dtype =
229
+ ((inputs.first ->scalar_type () == compute_type) && ...);
230
+
231
+ constexpr ScalarType out_specialized_scalar_type =
232
+ specialized_output_scalar_type<CTYPE_COMMON>(out_dtypes);
233
+ if (all_inputs_compute_dtype &&
234
+ out.scalar_type () == out_specialized_scalar_type) {
235
+ using CTYPE_OUT =
236
+ typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
237
+ dtype_specialized_elementwise_fn_impl<CTYPE_COMMON, CTYPE_OUT>(
238
+ compute_fun, ctx, out, inputs...);
239
+ return ;
240
+ }
241
+
242
+ apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
243
+ compute_fun, ctx, out, out_dtypes, inputs...);
244
+ }
160
245
} // namespace internal
161
246
162
247
/// DEPRECATED: prefer the variant with out_dtypes in the template argument.
@@ -168,18 +253,22 @@ inline void apply_unitensor_elementwise_fn(
168
253
SupportedTensorDtypes a_dtypes,
169
254
const Tensor& out,
170
255
SupportedTensorDtypes out_dtypes) {
171
- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
256
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
172
257
compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
173
258
}
174
259
175
- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
260
+ template <
261
+ typename CTYPE_COMMON,
262
+ const char * op_name,
263
+ SupportedTensorDtypes out_dtypes,
264
+ typename Op>
176
265
inline void apply_unitensor_elementwise_fn (
177
266
const Op& compute_fun,
178
267
KernelRuntimeContext& ctx,
179
268
const Tensor& a,
180
269
SupportedTensorDtypes a_dtypes,
181
270
const Tensor& out) {
182
- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
271
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
183
272
compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
184
273
}
185
274
@@ -196,7 +285,7 @@ inline void apply_bitensor_elementwise_fn(
196
285
SupportedTensorDtypes b_dtypes,
197
286
const Tensor& out,
198
287
SupportedTensorDtypes out_dtypes) {
199
- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
288
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
200
289
compute_fun,
201
290
ctx,
202
291
out,
@@ -210,7 +299,11 @@ inline void apply_bitensor_elementwise_fn(
210
299
* perform a computation and write to the corresponding element of the output.
211
300
* Tensor broadcasting is applied wherever it is required.
212
301
*/
213
- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
302
+ template <
303
+ typename CTYPE_COMMON,
304
+ const char * op_name,
305
+ SupportedTensorDtypes out_dtypes,
306
+ typename Op>
214
307
inline void apply_bitensor_elementwise_fn (
215
308
const Op& compute_fun,
216
309
KernelRuntimeContext& ctx,
@@ -219,11 +312,10 @@ inline void apply_bitensor_elementwise_fn(
219
312
const Tensor& b,
220
313
SupportedTensorDtypes b_dtypes,
221
314
const Tensor& out) {
222
- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
315
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
223
316
compute_fun,
224
317
ctx,
225
318
out,
226
- out_dtypes,
227
319
std::make_pair (&a, a_dtypes),
228
320
std::make_pair (&b, b_dtypes));
229
321
}
@@ -243,7 +335,7 @@ inline void apply_tritensor_elementwise_fn(
243
335
SupportedTensorDtypes c_dtypes,
244
336
const Tensor& out,
245
337
SupportedTensorDtypes out_dtypes) {
246
- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
338
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
247
339
compute_fun,
248
340
ctx,
249
341
out,
@@ -273,7 +365,11 @@ inline void apply_tritensor_elementwise_fn(
273
365
* static constexpr const char op_name[] = "my_op";
274
366
* apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>.
275
367
*/
276
- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
368
+ template <
369
+ typename CTYPE_COMMON,
370
+ const char * op_name,
371
+ SupportedTensorDtypes out_dtypes,
372
+ typename Op>
277
373
inline void apply_tritensor_elementwise_fn (
278
374
const Op& compute_fun,
279
375
KernelRuntimeContext& ctx,
@@ -284,11 +380,10 @@ inline void apply_tritensor_elementwise_fn(
284
380
const Tensor& c,
285
381
SupportedTensorDtypes c_dtypes,
286
382
const Tensor& out) {
287
- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
383
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
288
384
compute_fun,
289
385
ctx,
290
386
out,
291
- out_dtypes,
292
387
std::make_pair (&a, a_dtypes),
293
388
std::make_pair (&b, b_dtypes),
294
389
std::make_pair (&c, c_dtypes));
0 commit comments