Skip to content

Commit 031943e

Browse files
committed
RFC: Specialize for non-mixed-dtype in elementwise_util
Mixed dtype should be uncommon. Here is how we can specialize for the common case. Test Plan: automated tests on this PR verify we didn't break the now-deprecated runtime_out_dtypes mode; tests on the next PR will verify that everything works after migration. Also included migration for exactly one operator, op_mul, to verify that the new code compiles. ghstack-source-id: d5471415ee88c4b35a5ac23131774d2585f4bbdf ghstack-comment-id: 2735017566 Pull Request resolved: #9388
1 parent 979e8e9 commit 031943e

File tree

3 files changed

+142
-28
lines changed

3 files changed

+142
-28
lines changed

kernels/portable/cpu/op_mul.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ Tensor& mul_out(
5252
out);
5353

5454
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
55-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
55+
utils::apply_bitensor_elementwise_fn<
56+
CTYPE_COMPUTE,
57+
op_name,
58+
utils::SupportedTensorDtypes::REALHBBF16>(
5659
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
5760
return val_a * val_b;
5861
},
@@ -61,8 +64,7 @@ Tensor& mul_out(
6164
utils::SupportedTensorDtypes::REALHBBF16,
6265
b,
6366
utils::SupportedTensorDtypes::REALHBBF16,
64-
out,
65-
utils::SupportedTensorDtypes::REALHBBF16);
67+
out);
6668
});
6769

6870
return out;

kernels/portable/cpu/util/dtype_util.h

+19
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,25 @@ bool check_tensor_dtype(
285285
SupportedTensorDtypes dtypes,
286286
const ScalarType compute_type);
287287

288+
/// Return the one output type we are willing to emit specialized code
289+
/// to handle, given a compute type of CTYPE_COMMON and supported
290+
/// output types of out_dtypes.
291+
template <typename CTYPE_COMMON>
292+
inline constexpr ScalarType specialized_output_scalar_type(
293+
SupportedTensorDtypes out_dtypes) {
294+
switch (out_dtypes) {
295+
case SupportedTensorDtypes::BOOL_OR_BYTE:
296+
return ScalarType::Bool;
297+
case SupportedTensorDtypes::REALHBBF16:
298+
case SupportedTensorDtypes::REALHBF16:
299+
case SupportedTensorDtypes::FLOATHBF16:
300+
case SupportedTensorDtypes::INTB:
301+
case SupportedTensorDtypes::SAME_AS_COMPUTE:
302+
case SupportedTensorDtypes::SAME_AS_COMMON:
303+
return CppTypeToScalarType<CTYPE_COMMON>::value;
304+
}
305+
}
306+
288307
} // namespace internal
289308
} // namespace utils
290309
} // namespace native

kernels/portable/cpu/util/elementwise_util.h

+118-25
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,43 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
5353
namespace internal {
5454
template <
5555
typename CTYPE_COMMON,
56+
typename CTYPE_OUT,
5657
typename Op,
57-
typename... Args>
58+
typename... Args>
59+
inline void dtype_specialized_elementwise_fn_impl(
60+
const Op& compute_fun,
61+
KernelRuntimeContext& ctx,
62+
const Tensor& out,
63+
Args... inputs) {
64+
constexpr auto kNumInputs = sizeof...(inputs);
65+
ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
66+
67+
std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
68+
inputs.first->template const_data_ptr<CTYPE_COMMON>()...};
69+
70+
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
71+
72+
::executorch::extension::parallel_for(
73+
0,
74+
out.numel(),
75+
::executorch::extension::internal::GRAIN_SIZE,
76+
[&](const auto begin, const auto end) {
77+
const auto range =
78+
BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
79+
auto begin_it = range.begin();
80+
begin_it += begin;
81+
for (; (*begin_it)[0] < end; ++begin_it) {
82+
const auto& indexes = *begin_it;
83+
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
84+
for (const auto idx : c10::irange(kNumInputs)) {
85+
loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
86+
}
87+
data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);
88+
}
89+
});
90+
}
91+
92+
template <typename CTYPE_COMMON, typename Op, typename... Args>
5893
inline bool validate_elementwise_fn_inputs(
5994
const Op& compute_fun,
6095
KernelRuntimeContext& ctx,
@@ -73,7 +108,8 @@ inline bool validate_elementwise_fn_inputs(
73108
ctx,
74109
(check_input_dtype(inputs, compute_type) && ...) &&
75110
internal::check_tensor_dtype(out, out_dtypes, compute_type),
76-
InvalidArgument, false);
111+
InvalidArgument,
112+
false);
77113

78114
return true;
79115
}
@@ -83,22 +119,12 @@ template <
83119
const char* op_name,
84120
typename Op,
85121
typename... Args>
86-
inline void apply_elementwise_fn(
122+
inline void apply_elementwise_fn_generic_impl(
87123
const Op& compute_fun,
88124
KernelRuntimeContext& ctx,
89125
const Tensor& out,
90126
SupportedTensorDtypes out_dtypes,
91127
Args... inputs) {
92-
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
93-
compute_fun,
94-
ctx,
95-
out,
96-
out_dtypes,
97-
inputs...);
98-
if (!inputs_valid) {
99-
return;
100-
}
101-
102128
constexpr auto kNumInputs = sizeof...(inputs);
103129

104130
struct InputInfo {
@@ -142,6 +168,63 @@ inline void apply_elementwise_fn(
142168
}
143169
});
144170
}
171+
172+
template <
173+
typename CTYPE_COMMON,
174+
const char* op_name,
175+
typename Op,
176+
typename... Args>
177+
inline void apply_elementwise_fn_runtime_out_dtypes(
178+
const Op& compute_fun,
179+
KernelRuntimeContext& ctx,
180+
const Tensor& out,
181+
SupportedTensorDtypes out_dtypes,
182+
Args... inputs) {
183+
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
184+
compute_fun, ctx, out, out_dtypes, inputs...);
185+
if (!inputs_valid) {
186+
return;
187+
}
188+
189+
apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
190+
compute_fun, ctx, out, out_dtypes, inputs...);
191+
}
192+
193+
template <
194+
typename CTYPE_COMMON,
195+
const char* op_name,
196+
SupportedTensorDtypes out_dtypes,
197+
typename Op,
198+
typename... Args>
199+
inline void apply_elementwise_fn(
200+
const Op& compute_fun,
201+
KernelRuntimeContext& ctx,
202+
const Tensor& out,
203+
Args... inputs) {
204+
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
205+
compute_fun, ctx, out, out_dtypes, inputs...);
206+
if (!inputs_valid) {
207+
return;
208+
}
209+
210+
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
211+
const bool all_inputs_compute_dtype =
212+
((inputs.first->scalar_type() == compute_type) && ...);
213+
214+
constexpr ScalarType out_specialized_scalar_type =
215+
specialized_output_scalar_type<CTYPE_COMMON>(out_dtypes);
216+
if (all_inputs_compute_dtype &&
217+
out.scalar_type() == out_specialized_scalar_type) {
218+
using CTYPE_OUT =
219+
typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
220+
dtype_specialized_elementwise_fn_impl<CTYPE_COMMON, CTYPE_OUT>(
221+
compute_fun, ctx, out, inputs...);
222+
return;
223+
}
224+
225+
apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
226+
compute_fun, ctx, out, out_dtypes, inputs...);
227+
}
145228
} // namespace internal
146229

147230
/// DEPRECATED: prefer the variant with out_dtypes in the template argument.
@@ -153,19 +236,23 @@ inline void apply_unitensor_elementwise_fn(
153236
SupportedTensorDtypes a_dtypes,
154237
const Tensor& out,
155238
SupportedTensorDtypes out_dtypes) {
156-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
239+
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMMON, op_name>(
157240
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
158241
}
159242

160-
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
243+
template <
244+
typename CTYPE_COMMON,
245+
const char* op_name,
246+
SupportedTensorDtypes out_dtypes,
247+
typename Op>
161248
inline void apply_unitensor_elementwise_fn(
162249
const Op& compute_fun,
163250
KernelRuntimeContext& ctx,
164251
const Tensor& a,
165252
SupportedTensorDtypes a_dtypes,
166253
const Tensor& out) {
167-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
168-
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
254+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>(
255+
compute_fun, ctx, out, std::make_pair(&a, a_dtypes));
169256
}
170257

171258
/**
@@ -181,7 +268,7 @@ inline void apply_bitensor_elementwise_fn(
181268
SupportedTensorDtypes b_dtypes,
182269
const Tensor& out,
183270
SupportedTensorDtypes out_dtypes) {
184-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
271+
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMMON, op_name>(
185272
compute_fun,
186273
ctx,
187274
out,
@@ -195,7 +282,11 @@ inline void apply_bitensor_elementwise_fn(
195282
* perform a computation and write to the corresponding element of the output.
196283
* Tensor broadcasting is applied wherever it is required.
197284
*/
198-
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
285+
template <
286+
typename CTYPE_COMMON,
287+
const char* op_name,
288+
SupportedTensorDtypes out_dtypes,
289+
typename Op>
199290
inline void apply_bitensor_elementwise_fn(
200291
const Op& compute_fun,
201292
KernelRuntimeContext& ctx,
@@ -204,11 +295,10 @@ inline void apply_bitensor_elementwise_fn(
204295
const Tensor& b,
205296
SupportedTensorDtypes b_dtypes,
206297
const Tensor& out) {
207-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
298+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>(
208299
compute_fun,
209300
ctx,
210301
out,
211-
out_dtypes,
212302
std::make_pair(&a, a_dtypes),
213303
std::make_pair(&b, b_dtypes));
214304
}
@@ -228,7 +318,7 @@ inline void apply_tritensor_elementwise_fn(
228318
SupportedTensorDtypes c_dtypes,
229319
const Tensor& out,
230320
SupportedTensorDtypes out_dtypes) {
231-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
321+
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMMON, op_name>(
232322
compute_fun,
233323
ctx,
234324
out,
@@ -258,7 +348,11 @@ inline void apply_tritensor_elementwise_fn(
258348
* static constexpr const char op_name[] = "my_op";
259349
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
260350
*/
261-
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
351+
template <
352+
typename CTYPE_COMMON,
353+
const char* op_name,
354+
SupportedTensorDtypes out_dtypes,
355+
typename Op>
262356
inline void apply_tritensor_elementwise_fn(
263357
const Op& compute_fun,
264358
KernelRuntimeContext& ctx,
@@ -269,11 +363,10 @@ inline void apply_tritensor_elementwise_fn(
269363
const Tensor& c,
270364
SupportedTensorDtypes c_dtypes,
271365
const Tensor& out) {
272-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
366+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>(
273367
compute_fun,
274368
ctx,
275369
out,
276-
out_dtypes,
277370
std::make_pair(&a, a_dtypes),
278371
std::make_pair(&b, b_dtypes),
279372
std::make_pair(&c, c_dtypes));

0 commit comments

Comments
 (0)