Skip to content

Commit a10b779

Browse files
committed
RFC: Specialize for non-mixed-dtype in elementwise_util
Mixed dtype should be uncommon, so this change shows how we can specialize for the common (non-mixed-dtype) case. Test Plan: automated tests on this PR verify that we didn't break the now-deprecated runtime_out_dtypes mode; tests on the next PR will verify that everything works after migration. This PR also migrates exactly one operator, op_mul, to verify that the new code compiles. ghstack-source-id: 079c8b97c745f0e4004303ead6ca21de596020cc ghstack-comment-id: 2735017566 Pull Request resolved: #9388
1 parent 90a22ba commit a10b779

File tree

3 files changed

+146
-27
lines changed

3 files changed

+146
-27
lines changed

kernels/portable/cpu/op_mul.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ Tensor& mul_out(
5252
out);
5353

5454
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
55-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
55+
utils::apply_bitensor_elementwise_fn<
56+
CTYPE_COMPUTE,
57+
op_name,
58+
utils::SupportedTensorDtypes::REALHBBF16>(
5659
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
5760
return val_a * val_b;
5861
},
@@ -61,8 +64,7 @@ Tensor& mul_out(
6164
utils::SupportedTensorDtypes::REALHBBF16,
6265
b,
6366
utils::SupportedTensorDtypes::REALHBBF16,
64-
out,
65-
utils::SupportedTensorDtypes::REALHBBF16);
67+
out);
6668
});
6769

6870
return out;

kernels/portable/cpu/util/dtype_util.h

+22
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,28 @@ bool check_tensor_dtype(
324324
SupportedTensorDtypes dtypes,
325325
const ScalarType compute_type);
326326

327+
/// Return the one output type we are willing to emit specialized code
328+
/// to handle, given a compute type of CTYPE_COMMON and supported
329+
/// output types of out_dtypes.
///
/// Mapping implemented by the switch below:
///  - BOOL and BOOL_OR_BYTE both yield ScalarType::Bool (for BOOL_OR_BYTE,
///    Bool is the only output type singled out).
///  - Every other enumerator listed yields
///    CppTypeToScalarType<CTYPE_COMMON>::value, i.e. the ScalarType that
///    corresponds to the compute type.
330+
template <typename CTYPE_COMMON>
331+
inline constexpr ScalarType specialized_output_scalar_type(
332+
SupportedTensorDtypes out_dtypes) {
// NOTE(review): there is no default case and no return after the switch, so
// this relies on every SupportedTensorDtypes enumerator being listed here --
// confirm whenever a new enumerator is added.
333+
switch (out_dtypes) {
334+
case SupportedTensorDtypes::BOOL:
335+
return ScalarType::Bool;
336+
case SupportedTensorDtypes::BOOL_OR_BYTE:
337+
return ScalarType::Bool;
// All remaining cases fall through to the shared return below.
338+
case SupportedTensorDtypes::REALHBBF16:
339+
case SupportedTensorDtypes::REALHBF16:
340+
case SupportedTensorDtypes::REALH:
341+
case SupportedTensorDtypes::FLOATHBF16:
342+
case SupportedTensorDtypes::INTB:
343+
case SupportedTensorDtypes::SAME_AS_COMPUTE:
344+
case SupportedTensorDtypes::SAME_AS_COMMON:
345+
return CppTypeToScalarType<CTYPE_COMMON>::value;
346+
}
347+
}
348+
327349
} // namespace internal
328350
} // namespace utils
329351
} // namespace native

kernels/portable/cpu/util/elementwise_util.h

+119-24
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,43 @@ using op_call_result =
6060

/// Elementwise loop for the case where no per-element dtype conversion is
/// performed: every input tensor is read directly as CTYPE_COMMON and the
/// result of compute_fun is stored directly into `out` as CTYPE_OUT.
///
/// @param compute_fun callable invoked once per output element with one
///     CTYPE_COMMON argument per input.
/// @param ctx kernel runtime context; not referenced in this function body.
/// @param out output tensor; its data is accessed as CTYPE_OUT.
/// @param inputs one pair-like object per input tensor; only `.first` (a
///     pointer to the tensor) is used here.
6161
template <
6262
typename CTYPE_COMMON,
63+
typename CTYPE_OUT,
6364
typename Op,
64-
typename... Args>
65+
typename... Args>
66+
inline void dtype_specialized_elementwise_fn_impl(
67+
const Op& compute_fun,
68+
KernelRuntimeContext& ctx,
69+
const Tensor& out,
70+
Args... inputs) {
71+
constexpr auto kNumInputs = sizeof...(inputs);
// Every input's element size must equal sizeof(CTYPE_COMMON), because the
// data pointers below are typed as CTYPE_COMMON.
72+
ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
73+
74+
std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
75+
inputs.first->template const_data_ptr<CTYPE_COMMON>()...};
76+
77+
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
78+
// Split [0, out.numel()) into chunks; each call of the lambda processes the
// output indexes in [begin, end).
79+
::executorch::extension::parallel_for(
80+
0,
81+
out.numel(),
82+
::executorch::extension::internal::GRAIN_SIZE,
83+
[&](const auto begin, const auto end) {
84+
const auto range =
85+
BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
86+
auto begin_it = range.begin();
// Skip ahead to the first element of this chunk.
87+
begin_it += begin;
88+
for (; (*begin_it)[0] < end; ++begin_it) {
// indexes[0] is used as the offset into out; indexes[i + 1] is used as the
// offset into input i.
89+
const auto& indexes = *begin_it;
90+
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
91+
for (const auto idx : c10::irange(kNumInputs)) {
92+
loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
93+
}
// std::apply expands the array into one argument per input.
94+
data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);
95+
}
96+
});
97+
}
98+
99+
template <typename CTYPE_COMMON, typename Op, typename... Args>
65100
inline bool validate_elementwise_fn_inputs(
66101
const Op& compute_fun,
67102
KernelRuntimeContext& ctx,
@@ -80,7 +115,8 @@ inline bool validate_elementwise_fn_inputs(
80115
ctx,
81116
(check_input_dtype(inputs, compute_type) && ...) &&
82117
internal::check_tensor_dtype(out, out_dtypes, compute_type),
83-
InvalidArgument, false);
118+
InvalidArgument,
119+
false);
84120

85121
return true;
86122
}
@@ -90,22 +126,12 @@ template <
90126
const char* op_name,
91127
typename Op,
92128
typename... Args>
93-
inline void apply_elementwise_fn(
129+
inline void apply_elementwise_fn_generic_impl(
94130
const Op& compute_fun,
95131
KernelRuntimeContext& ctx,
96132
const Tensor& out,
97133
SupportedTensorDtypes out_dtypes,
98134
Args... inputs) {
99-
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
100-
compute_fun,
101-
ctx,
102-
out,
103-
out_dtypes,
104-
inputs...);
105-
if (!inputs_valid) {
106-
return;
107-
}
108-
109135
constexpr auto kNumInputs = sizeof...(inputs);
110136

111137
struct InputInfo {
@@ -157,6 +183,65 @@ inline void apply_elementwise_fn(
157183
}
158184
});
159185
}
186+
/// Variant of apply_elementwise_fn that receives out_dtypes as a run-time
/// function argument instead of a template argument. It validates the
/// arguments and, if they are valid, always runs the generic implementation
/// (it never takes the dtype-specialized path).
///
/// @param compute_fun elementwise callable forwarded to the generic impl.
/// @param ctx kernel runtime context forwarded to validation and the impl.
/// @param out output tensor.
/// @param out_dtypes set of dtypes accepted for `out`.
/// @param inputs the input arguments, forwarded unchanged.
187+
template <
188+
typename CTYPE_COMMON,
189+
const char* op_name,
190+
typename Op,
191+
typename... Args>
192+
inline void apply_elementwise_fn_runtime_out_dtypes(
193+
const Op& compute_fun,
194+
KernelRuntimeContext& ctx,
195+
const Tensor& out,
196+
SupportedTensorDtypes out_dtypes,
197+
Args... inputs) {
198+
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
199+
compute_fun, ctx, out, out_dtypes, inputs...);
// Nothing is written to out when validation fails.
200+
if (!inputs_valid) {
201+
return;
202+
}
203+
204+
apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
205+
compute_fun, ctx, out, out_dtypes, inputs...);
206+
}
207+
/// Validate the arguments, then run compute_fun elementwise over the inputs
/// and write the results to `out`.
///
/// out_dtypes is a template argument, so the single output ScalarType that
/// gets a specialized code path is known at compile time. When every input
/// has the compute dtype and `out` has that specialized dtype, the
/// dtype-specialized implementation is used; otherwise the generic
/// implementation is used.
///
/// @param compute_fun elementwise callable.
/// @param ctx kernel runtime context forwarded to validation and the impls.
/// @param out output tensor.
/// @param inputs the input arguments; `.first` of each is dereferenced here
///     to query the tensor's scalar type.
208+
template <
209+
typename CTYPE_COMMON,
210+
const char* op_name,
211+
SupportedTensorDtypes out_dtypes,
212+
typename Op,
213+
typename... Args>
214+
inline void apply_elementwise_fn(
215+
const Op& compute_fun,
216+
KernelRuntimeContext& ctx,
217+
const Tensor& out,
218+
Args... inputs) {
219+
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
220+
compute_fun, ctx, out, out_dtypes, inputs...);
221+
if (!inputs_valid) {
222+
return;
223+
}
224+
// NOTE(review): kNumInputs is not referenced in the rest of this function.
225+
constexpr auto kNumInputs = sizeof...(inputs);
226+
227+
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
// True only if every input tensor's dtype equals the compute dtype.
228+
const bool all_inputs_compute_dtype =
229+
((inputs.first->scalar_type() == compute_type) && ...);
230+
231+
constexpr ScalarType out_specialized_scalar_type =
232+
specialized_output_scalar_type<CTYPE_COMMON>(out_dtypes);
// Specialized path: inputs need no conversion and out has the one output
// dtype chosen for specialization.
233+
if (all_inputs_compute_dtype &&
234+
out.scalar_type() == out_specialized_scalar_type) {
235+
using CTYPE_OUT =
236+
typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
237+
dtype_specialized_elementwise_fn_impl<CTYPE_COMMON, CTYPE_OUT>(
238+
compute_fun, ctx, out, inputs...);
239+
return;
240+
}
241+
// Otherwise fall back to the generic implementation.
242+
apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
243+
compute_fun, ctx, out, out_dtypes, inputs...);
244+
}
160245
} // namespace internal
161246

162247
/// DEPRECATED: prefer the variant with out_dtypes in the template argument.
@@ -168,18 +253,22 @@ inline void apply_unitensor_elementwise_fn(
168253
SupportedTensorDtypes a_dtypes,
169254
const Tensor& out,
170255
SupportedTensorDtypes out_dtypes) {
171-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
256+
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMMON, op_name>(
172257
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
173258
}
174259

/// Single-input elementwise entry point that takes out_dtypes as a template
/// argument. It pairs the address of `a` with a_dtypes and forwards to
/// internal::apply_elementwise_fn.
175-
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
260+
template <
261+
typename CTYPE_COMMON,
262+
const char* op_name,
263+
SupportedTensorDtypes out_dtypes,
264+
typename Op>
176265
inline void apply_unitensor_elementwise_fn(
177266
const Op& compute_fun,
178267
KernelRuntimeContext& ctx,
179268
const Tensor& a,
180269
SupportedTensorDtypes a_dtypes,
181270
const Tensor& out) {
182-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
271+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>(
// NOTE(review): out_dtypes is already supplied as a template argument on the
// line above, yet it is still passed as a function argument here; the
// bitensor/tritensor overloads in this same change drop it. With the new
// internal::apply_elementwise_fn(compute_fun, ctx, out, inputs...) signature
// it would be deduced as one of `inputs` -- confirm whether it should be
// removed.
183272
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
184273
}
185274

@@ -196,7 +285,7 @@ inline void apply_bitensor_elementwise_fn(
196285
SupportedTensorDtypes b_dtypes,
197286
const Tensor& out,
198287
SupportedTensorDtypes out_dtypes) {
199-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
288+
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMMON, op_name>(
200289
compute_fun,
201290
ctx,
202291
out,
@@ -210,7 +299,11 @@ inline void apply_bitensor_elementwise_fn(
210299
* perform a computation and write to the corresponding element of the output.
211300
* Tensor broadcasting is applied wherever it is required.
212301
*/
213-
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
302+
template <
303+
typename CTYPE_COMMON,
304+
const char* op_name,
305+
SupportedTensorDtypes out_dtypes,
306+
typename Op>
214307
inline void apply_bitensor_elementwise_fn(
215308
const Op& compute_fun,
216309
KernelRuntimeContext& ctx,
@@ -219,11 +312,10 @@ inline void apply_bitensor_elementwise_fn(
219312
const Tensor& b,
220313
SupportedTensorDtypes b_dtypes,
221314
const Tensor& out) {
222-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
315+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>(
223316
compute_fun,
224317
ctx,
225318
out,
226-
out_dtypes,
227319
std::make_pair(&a, a_dtypes),
228320
std::make_pair(&b, b_dtypes));
229321
}
@@ -243,7 +335,7 @@ inline void apply_tritensor_elementwise_fn(
243335
SupportedTensorDtypes c_dtypes,
244336
const Tensor& out,
245337
SupportedTensorDtypes out_dtypes) {
246-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
338+
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMMON, op_name>(
247339
compute_fun,
248340
ctx,
249341
out,
@@ -273,7 +365,11 @@ inline void apply_tritensor_elementwise_fn(
273365
* static constexpr const char op_name[] = "my_op";
274366
* apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>.
275367
*/
276-
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
368+
template <
369+
typename CTYPE_COMMON,
370+
const char* op_name,
371+
SupportedTensorDtypes out_dtypes,
372+
typename Op>
277373
inline void apply_tritensor_elementwise_fn(
278374
const Op& compute_fun,
279375
KernelRuntimeContext& ctx,
@@ -284,11 +380,10 @@ inline void apply_tritensor_elementwise_fn(
284380
const Tensor& c,
285381
SupportedTensorDtypes c_dtypes,
286382
const Tensor& out) {
287-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
383+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>(
288384
compute_fun,
289385
ctx,
290386
out,
291-
out_dtypes,
292387
std::make_pair(&a, a_dtypes),
293388
std::make_pair(&b, b_dtypes),
294389
std::make_pair(&c, c_dtypes));

0 commit comments

Comments
 (0)