Skip to content

Commit 3b3979e

Browse files
committed
RFC: Specialize for non-mixed-dtype in elementwise_util
Mixed dtype should be uncommon. Here is how we can specialize for the common case. Test Plan: automated tests on this PR verify we didn't break the now-deprecated runtime_out_dtypes mode; tests on the next PR will verify that everything works after migration. Also included migration for exactly one operator, op_mul, to verify that the new code compiles. ghstack-source-id: 5a85b2aec29a6a72aa8195a05fb43bbebcff38d8 ghstack-comment-id: 2735017566 Pull Request resolved: #9388
1 parent 68cee3e commit 3b3979e

File tree

3 files changed

+145
-28
lines changed

3 files changed

+145
-28
lines changed

kernels/portable/cpu/op_mul.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ Tensor& mul_out(
5252
out);
5353

5454
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
55-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
55+
utils::apply_bitensor_elementwise_fn<
56+
CTYPE_COMPUTE,
57+
op_name,
58+
utils::SupportedTensorDtypes::REALHBBF16>(
5659
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
5760
return val_a * val_b;
5861
},
@@ -61,8 +64,7 @@ Tensor& mul_out(
6164
utils::SupportedTensorDtypes::REALHBBF16,
6265
b,
6366
utils::SupportedTensorDtypes::REALHBBF16,
64-
out,
65-
utils::SupportedTensorDtypes::REALHBBF16);
67+
out);
6668
});
6769

6870
return out;

kernels/portable/cpu/util/dtype_util.h

+22
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,28 @@ bool check_tensor_dtype(
309309
SupportedTensorDtypes dtypes,
310310
const ScalarType compute_type);
311311

312+
/// Return the one output type we are willing to emit specialized code
313+
/// to handle, given a compute type of CTYPE_COMMON and supported
314+
/// output types of out_dtypes.
315+
template <typename CTYPE_COMMON>
316+
inline constexpr ScalarType specialized_output_scalar_type(
317+
SupportedTensorDtypes out_dtypes) {
318+
switch (out_dtypes) {
319+
case SupportedTensorDtypes::BOOL:
320+
return ScalarType::Bool;
321+
case SupportedTensorDtypes::BOOL_OR_BYTE:
322+
return ScalarType::Bool;
323+
case SupportedTensorDtypes::REALHBBF16:
324+
case SupportedTensorDtypes::REALHBF16:
325+
case SupportedTensorDtypes::REALH:
326+
case SupportedTensorDtypes::FLOATHBF16:
327+
case SupportedTensorDtypes::INTB:
328+
case SupportedTensorDtypes::SAME_AS_COMPUTE:
329+
case SupportedTensorDtypes::SAME_AS_COMMON:
330+
return CppTypeToScalarType<CTYPE_COMMON>::value;
331+
}
332+
}
333+
312334
} // namespace internal
313335
} // namespace utils
314336
} // namespace native

kernels/portable/cpu/util/elementwise_util.h

+118-25
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,43 @@ using op_call_result =
6060

6161
template <
6262
typename CTYPE_COMMON,
63+
typename CTYPE_OUT,
6364
typename Op,
64-
typename... Args>
65+
typename... Args>
66+
inline void dtype_specialized_elementwise_fn_impl(
67+
const Op& compute_fun,
68+
KernelRuntimeContext& ctx,
69+
const Tensor& out,
70+
Args... inputs) {
71+
constexpr auto kNumInputs = sizeof...(inputs);
72+
ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
73+
74+
std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
75+
inputs.first->template const_data_ptr<CTYPE_COMMON>()...};
76+
77+
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
78+
79+
::executorch::extension::parallel_for(
80+
0,
81+
out.numel(),
82+
::executorch::extension::internal::GRAIN_SIZE,
83+
[&](const auto begin, const auto end) {
84+
const auto range =
85+
BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
86+
auto begin_it = range.begin();
87+
begin_it += begin;
88+
for (; (*begin_it)[0] < end; ++begin_it) {
89+
const auto& indexes = *begin_it;
90+
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
91+
for (const auto idx : c10::irange(kNumInputs)) {
92+
loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
93+
}
94+
data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);
95+
}
96+
});
97+
}
98+
99+
template <typename CTYPE_COMMON, typename Op, typename... Args>
65100
inline bool validate_elementwise_fn_inputs(
66101
const Op& compute_fun,
67102
KernelRuntimeContext& ctx,
@@ -80,7 +115,8 @@ inline bool validate_elementwise_fn_inputs(
80115
ctx,
81116
(check_input_dtype(inputs, compute_type) && ...) &&
82117
internal::check_tensor_dtype(out, out_dtypes, compute_type),
83-
InvalidArgument, false);
118+
InvalidArgument,
119+
false);
84120

85121
return true;
86122
}
@@ -90,22 +126,12 @@ template <
90126
const char* op_name,
91127
typename Op,
92128
typename... Args>
93-
inline void apply_elementwise_fn(
129+
inline void apply_elementwise_fn_generic_impl(
94130
const Op& compute_fun,
95131
KernelRuntimeContext& ctx,
96132
const Tensor& out,
97133
SupportedTensorDtypes out_dtypes,
98134
Args... inputs) {
99-
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
100-
compute_fun,
101-
ctx,
102-
out,
103-
out_dtypes,
104-
inputs...);
105-
if (!inputs_valid) {
106-
return;
107-
}
108-
109135
constexpr auto kNumInputs = sizeof...(inputs);
110136

111137
struct InputInfo {
@@ -157,6 +183,63 @@ inline void apply_elementwise_fn(
157183
}
158184
});
159185
}
186+
187+
template <
188+
typename CTYPE_COMMON,
189+
const char* op_name,
190+
typename Op,
191+
typename... Args>
192+
inline void apply_elementwise_fn_runtime_out_dtypes(
193+
const Op& compute_fun,
194+
KernelRuntimeContext& ctx,
195+
const Tensor& out,
196+
SupportedTensorDtypes out_dtypes,
197+
Args... inputs) {
198+
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
199+
compute_fun, ctx, out, out_dtypes, inputs...);
200+
if (!inputs_valid) {
201+
return;
202+
}
203+
204+
apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
205+
compute_fun, ctx, out, out_dtypes, inputs...);
206+
}
207+
208+
template <
209+
typename CTYPE_COMMON,
210+
const char* op_name,
211+
SupportedTensorDtypes out_dtypes,
212+
typename Op,
213+
typename... Args>
214+
inline void apply_elementwise_fn(
215+
const Op& compute_fun,
216+
KernelRuntimeContext& ctx,
217+
const Tensor& out,
218+
Args... inputs) {
219+
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
220+
compute_fun, ctx, out, out_dtypes, inputs...);
221+
if (!inputs_valid) {
222+
return;
223+
}
224+
225+
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
226+
const bool all_inputs_compute_dtype =
227+
((inputs.first->scalar_type() == compute_type) && ...);
228+
229+
constexpr ScalarType out_specialized_scalar_type =
230+
specialized_output_scalar_type<CTYPE_COMMON>(out_dtypes);
231+
if (all_inputs_compute_dtype &&
232+
out.scalar_type() == out_specialized_scalar_type) {
233+
using CTYPE_OUT =
234+
typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
235+
dtype_specialized_elementwise_fn_impl<CTYPE_COMMON, CTYPE_OUT>(
236+
compute_fun, ctx, out, inputs...);
237+
return;
238+
}
239+
240+
apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
241+
compute_fun, ctx, out, out_dtypes, inputs...);
242+
}
160243
} // namespace internal
161244

162245
/// DEPRECATED: prefer the variant with out_dtypes in the template argument.
@@ -168,19 +251,23 @@ inline void apply_unitensor_elementwise_fn(
168251
SupportedTensorDtypes a_dtypes,
169252
const Tensor& out,
170253
SupportedTensorDtypes out_dtypes) {
171-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
254+
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMMON, op_name>(
172255
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
173256
}
174257

175-
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
258+
template <
259+
typename CTYPE_COMMON,
260+
const char* op_name,
261+
SupportedTensorDtypes out_dtypes,
262+
typename Op>
176263
inline void apply_unitensor_elementwise_fn(
177264
const Op& compute_fun,
178265
KernelRuntimeContext& ctx,
179266
const Tensor& a,
180267
SupportedTensorDtypes a_dtypes,
181268
const Tensor& out) {
182-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
183-
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
269+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>(
270+
compute_fun, ctx, out, std::make_pair(&a, a_dtypes));
184271
}
185272

186273
/**
@@ -196,7 +283,7 @@ inline void apply_bitensor_elementwise_fn(
196283
SupportedTensorDtypes b_dtypes,
197284
const Tensor& out,
198285
SupportedTensorDtypes out_dtypes) {
199-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
286+
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMMON, op_name>(
200287
compute_fun,
201288
ctx,
202289
out,
@@ -210,7 +297,11 @@ inline void apply_bitensor_elementwise_fn(
210297
* perform a computation and write to the corresponding element of the output.
211298
* Tensor broadcasting is applied wherever it is required.
212299
*/
213-
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
300+
template <
301+
typename CTYPE_COMMON,
302+
const char* op_name,
303+
SupportedTensorDtypes out_dtypes,
304+
typename Op>
214305
inline void apply_bitensor_elementwise_fn(
215306
const Op& compute_fun,
216307
KernelRuntimeContext& ctx,
@@ -219,11 +310,10 @@ inline void apply_bitensor_elementwise_fn(
219310
const Tensor& b,
220311
SupportedTensorDtypes b_dtypes,
221312
const Tensor& out) {
222-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
313+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>(
223314
compute_fun,
224315
ctx,
225316
out,
226-
out_dtypes,
227317
std::make_pair(&a, a_dtypes),
228318
std::make_pair(&b, b_dtypes));
229319
}
@@ -243,7 +333,7 @@ inline void apply_tritensor_elementwise_fn(
243333
SupportedTensorDtypes c_dtypes,
244334
const Tensor& out,
245335
SupportedTensorDtypes out_dtypes) {
246-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
336+
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMMON, op_name>(
247337
compute_fun,
248338
ctx,
249339
out,
@@ -273,7 +363,11 @@ inline void apply_tritensor_elementwise_fn(
273363
* static constexpr const char op_name[] = "my_op";
274364
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
275365
*/
276-
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
366+
template <
367+
typename CTYPE_COMMON,
368+
const char* op_name,
369+
SupportedTensorDtypes out_dtypes,
370+
typename Op>
277371
inline void apply_tritensor_elementwise_fn(
278372
const Op& compute_fun,
279373
KernelRuntimeContext& ctx,
@@ -284,11 +378,10 @@ inline void apply_tritensor_elementwise_fn(
284378
const Tensor& c,
285379
SupportedTensorDtypes c_dtypes,
286380
const Tensor& out) {
287-
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
381+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>(
288382
compute_fun,
289383
ctx,
290384
out,
291-
out_dtypes,
292385
std::make_pair(&a, a_dtypes),
293386
std::make_pair(&b, b_dtypes),
294387
std::make_pair(&c, c_dtypes));

0 commit comments

Comments
 (0)