@@ -60,8 +60,43 @@ using op_call_result =
60
60
61
61
template <
62
62
typename CTYPE_COMMON,
63
+ typename CTYPE_OUT,
63
64
typename Op,
64
- typename ... Args>
65
+ typename ... Args>
66
+ inline void dtype_specialized_elementwise_fn_impl (
67
+ const Op& compute_fun,
68
+ KernelRuntimeContext& ctx,
69
+ const Tensor& out,
70
+ Args... inputs) {
71
+ constexpr auto kNumInputs = sizeof ...(inputs);
72
+ ET_DCHECK (((inputs.first ->element_size () == sizeof (CTYPE_COMMON)) && ...));
73
+
74
+ std::array<const CTYPE_COMMON*, kNumInputs > inputs_data_ptrs = {
75
+ inputs.first ->template const_data_ptr <CTYPE_COMMON>()...};
76
+
77
+ CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
78
+
79
+ ::executorch::extension::parallel_for (
80
+ 0 ,
81
+ out.numel(),
82
+ ::executorch::extension::internal::GRAIN_SIZE,
83
+ [&](const auto begin, const auto end) {
84
+ const auto range =
85
+ BroadcastIndexesRange<kNumInputs >(out, (*inputs.first )...);
86
+ auto begin_it = range.begin ();
87
+ begin_it += begin;
88
+ for (; (*begin_it)[0 ] < end; ++begin_it) {
89
+ const auto & indexes = *begin_it;
90
+ std::array<CTYPE_COMMON, kNumInputs > loaded_inputs;
91
+ for (const auto idx : c10::irange (kNumInputs )) {
92
+ loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1 ]];
93
+ }
94
+ data_out[indexes[0 ]] = std::apply (compute_fun, loaded_inputs);
95
+ }
96
+ });
97
+ }
98
+
99
+ template <typename CTYPE_COMMON, typename Op, typename ... Args>
65
100
inline bool validate_elementwise_fn_inputs (
66
101
const Op& compute_fun,
67
102
KernelRuntimeContext& ctx,
@@ -80,7 +115,8 @@ inline bool validate_elementwise_fn_inputs(
80
115
ctx,
81
116
(check_input_dtype (inputs, compute_type) && ...) &&
82
117
internal::check_tensor_dtype (out, out_dtypes, compute_type),
83
- InvalidArgument, false );
118
+ InvalidArgument,
119
+ false );
84
120
85
121
return true ;
86
122
}
@@ -90,22 +126,12 @@ template <
90
126
const char * op_name,
91
127
typename Op,
92
128
typename ... Args>
93
- inline void apply_elementwise_fn (
129
+ inline void apply_elementwise_fn_generic_impl (
94
130
const Op& compute_fun,
95
131
KernelRuntimeContext& ctx,
96
132
const Tensor& out,
97
133
SupportedTensorDtypes out_dtypes,
98
134
Args... inputs) {
99
- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
100
- compute_fun,
101
- ctx,
102
- out,
103
- out_dtypes,
104
- inputs...);
105
- if (!inputs_valid) {
106
- return ;
107
- }
108
-
109
135
constexpr auto kNumInputs = sizeof ...(inputs);
110
136
111
137
struct InputInfo {
@@ -157,6 +183,63 @@ inline void apply_elementwise_fn(
157
183
}
158
184
});
159
185
}
186
+
187
+ template <
188
+ typename CTYPE_COMMON,
189
+ const char * op_name,
190
+ typename Op,
191
+ typename ... Args>
192
+ inline void apply_elementwise_fn_runtime_out_dtypes (
193
+ const Op& compute_fun,
194
+ KernelRuntimeContext& ctx,
195
+ const Tensor& out,
196
+ SupportedTensorDtypes out_dtypes,
197
+ Args... inputs) {
198
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
199
+ compute_fun, ctx, out, out_dtypes, inputs...);
200
+ if (!inputs_valid) {
201
+ return ;
202
+ }
203
+
204
+ apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
205
+ compute_fun, ctx, out, out_dtypes, inputs...);
206
+ }
207
+
208
+ template <
209
+ typename CTYPE_COMMON,
210
+ const char * op_name,
211
+ SupportedTensorDtypes out_dtypes,
212
+ typename Op,
213
+ typename ... Args>
214
+ inline void apply_elementwise_fn (
215
+ const Op& compute_fun,
216
+ KernelRuntimeContext& ctx,
217
+ const Tensor& out,
218
+ Args... inputs) {
219
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
220
+ compute_fun, ctx, out, out_dtypes, inputs...);
221
+ if (!inputs_valid) {
222
+ return ;
223
+ }
224
+
225
+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
226
+ const bool all_inputs_compute_dtype =
227
+ ((inputs.first ->scalar_type () == compute_type) && ...);
228
+
229
+ constexpr ScalarType out_specialized_scalar_type =
230
+ specialized_output_scalar_type<CTYPE_COMMON>(out_dtypes);
231
+ if (all_inputs_compute_dtype &&
232
+ out.scalar_type () == out_specialized_scalar_type) {
233
+ using CTYPE_OUT =
234
+ typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
235
+ dtype_specialized_elementwise_fn_impl<CTYPE_COMMON, CTYPE_OUT>(
236
+ compute_fun, ctx, out, inputs...);
237
+ return ;
238
+ }
239
+
240
+ apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
241
+ compute_fun, ctx, out, out_dtypes, inputs...);
242
+ }
160
243
} // namespace internal
161
244
162
245
// / DEPRECATED: prefer the variant with out_dtypes in the template argument.
@@ -168,19 +251,23 @@ inline void apply_unitensor_elementwise_fn(
168
251
SupportedTensorDtypes a_dtypes,
169
252
const Tensor& out,
170
253
SupportedTensorDtypes out_dtypes) {
171
- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
254
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
172
255
compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
173
256
}
174
257
175
- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
258
+ template <
259
+ typename CTYPE_COMMON,
260
+ const char * op_name,
261
+ SupportedTensorDtypes out_dtypes,
262
+ typename Op>
176
263
inline void apply_unitensor_elementwise_fn (
177
264
const Op& compute_fun,
178
265
KernelRuntimeContext& ctx,
179
266
const Tensor& a,
180
267
SupportedTensorDtypes a_dtypes,
181
268
const Tensor& out) {
182
- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
183
- compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
269
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
270
+ compute_fun, ctx, out, std::make_pair (&a, a_dtypes));
184
271
}
185
272
186
273
/* *
@@ -196,7 +283,7 @@ inline void apply_bitensor_elementwise_fn(
196
283
SupportedTensorDtypes b_dtypes,
197
284
const Tensor& out,
198
285
SupportedTensorDtypes out_dtypes) {
199
- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
286
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
200
287
compute_fun,
201
288
ctx,
202
289
out,
@@ -210,7 +297,11 @@ inline void apply_bitensor_elementwise_fn(
210
297
* perform a computation and write to the corresponding element of the output.
211
298
* Tensor broadcasting is applied wherever it is required.
212
299
*/
213
- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
300
+ template <
301
+ typename CTYPE_COMMON,
302
+ const char * op_name,
303
+ SupportedTensorDtypes out_dtypes,
304
+ typename Op>
214
305
inline void apply_bitensor_elementwise_fn (
215
306
const Op& compute_fun,
216
307
KernelRuntimeContext& ctx,
@@ -219,11 +310,10 @@ inline void apply_bitensor_elementwise_fn(
219
310
const Tensor& b,
220
311
SupportedTensorDtypes b_dtypes,
221
312
const Tensor& out) {
222
- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
313
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
223
314
compute_fun,
224
315
ctx,
225
316
out,
226
- out_dtypes,
227
317
std::make_pair (&a, a_dtypes),
228
318
std::make_pair (&b, b_dtypes));
229
319
}
@@ -243,7 +333,7 @@ inline void apply_tritensor_elementwise_fn(
243
333
SupportedTensorDtypes c_dtypes,
244
334
const Tensor& out,
245
335
SupportedTensorDtypes out_dtypes) {
246
- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
336
+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
247
337
compute_fun,
248
338
ctx,
249
339
out,
@@ -273,7 +363,11 @@ inline void apply_tritensor_elementwise_fn(
273
363
* static constexpr const char op_name[] = "my_op";
274
364
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
275
365
*/
276
- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
366
+ template <
367
+ typename CTYPE_COMMON,
368
+ const char * op_name,
369
+ SupportedTensorDtypes out_dtypes,
370
+ typename Op>
277
371
inline void apply_tritensor_elementwise_fn (
278
372
const Op& compute_fun,
279
373
KernelRuntimeContext& ctx,
@@ -284,11 +378,10 @@ inline void apply_tritensor_elementwise_fn(
284
378
const Tensor& c,
285
379
SupportedTensorDtypes c_dtypes,
286
380
const Tensor& out) {
287
- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
381
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
288
382
compute_fun,
289
383
ctx,
290
384
out,
291
- out_dtypes,
292
385
std::make_pair (&a, a_dtypes),
293
386
std::make_pair (&b, b_dtypes),
294
387
std::make_pair (&c, c_dtypes));
0 commit comments