Skip to content

RFC: Specialize for non-mixed-dtype in elementwise_util #9388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: gh/swolchok/382/head
Choose a base branch
from
8 changes: 5 additions & 3 deletions kernels/portable/cpu/op_mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ Tensor& mul_out(
out);

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a * val_b;
},
Expand All @@ -61,8 +64,7 @@ Tensor& mul_out(
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
out);
});

return out;
Expand Down
19 changes: 19 additions & 0 deletions kernels/portable/cpu/util/dtype_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,25 @@ bool check_tensor_dtype(
SupportedTensorDtypes dtypes,
const ScalarType compute_type);

/// Return the one output type we are willing to emit specialized code
/// to handle, given a compute type of CTYPE_COMMON and supported
/// output types of out_dtypes.
template <typename CTYPE_COMMON>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be CTYPE_COMPUTE, right?

inline constexpr ScalarType specialized_output_scalar_type(
SupportedTensorDtypes out_dtypes) {
switch (out_dtypes) {
case SupportedTensorDtypes::BOOL_OR_BYTE:
return ScalarType::Bool;
case SupportedTensorDtypes::REALHBBF16:
case SupportedTensorDtypes::REALHBF16:
case SupportedTensorDtypes::FLOATHBF16:
case SupportedTensorDtypes::INTB:
case SupportedTensorDtypes::SAME_AS_COMPUTE:
case SupportedTensorDtypes::SAME_AS_COMMON:
return CppTypeToScalarType<CTYPE_COMMON>::value;
}
}

} // namespace internal
} // namespace utils
} // namespace native
Expand Down
119 changes: 103 additions & 16 deletions kernels/portable/cpu/util/elementwise_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,44 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
}

namespace internal {
// Fast path for elementwise ops: every input is already stored in
// CTYPE_COMPUTE and the output is stored in CTYPE_OUT, so we can read and
// write tensor memory directly with no per-element dtype conversion.
// Each element of `inputs` is a pair whose .first is a const Tensor*
// (the paired dtype tag is unused here; dtypes were validated by the
// caller).
template <
    typename CTYPE_COMPUTE,
    typename CTYPE_OUT,
    typename Op,
    typename... Args>
inline void dtype_specialized_elementwise_fn_impl(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& out,
    Args... inputs) {
  constexpr auto kNumInputs = sizeof...(inputs);
  // Debug-only sanity check: every input's element size must match the
  // compute type we are about to reinterpret its storage as.
  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));

  // Parallelize over the flattened output; each worker handles the
  // half-open output-index range [begin, end).
  ::executorch::extension::parallel_for(
      0,
      out.numel(),
      ::executorch::extension::internal::GRAIN_SIZE,
      [&](const auto begin, const auto end) {
        std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
            inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};

        CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

        // BroadcastIndexesRange yields, per output element, an index array
        // where [0] is the output index and [i + 1] is the corresponding
        // (possibly broadcast) index into input i.
        const auto range =
            BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
        auto begin_it = range.begin();
        begin_it += begin;
        // Loop until the output index reaches this chunk's end; assumes the
        // range iterates in increasing output-index order (presumably
        // guaranteed by BroadcastIndexesRange -- confirm in its header).
        for (; (*begin_it)[0] < end; ++begin_it) {
          const auto& indexes = *begin_it;
          // Gather one element from each input, then apply the op and
          // store the result at the output index.
          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
          for (const auto idx : c10::irange(kNumInputs)) {
            loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
          }
          data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);
        }
      });
}

template <typename CTYPE_COMPUTE, typename Op, typename... Args>
inline bool validate_elementwise_fn_inputs(
const Op& compute_fun,
Expand Down Expand Up @@ -81,18 +119,12 @@ template <
const char* op_name,
typename Op,
typename... Args>
inline void apply_elementwise_fn(
inline void apply_elementwise_fn_generic_impl(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& out,
SupportedTensorDtypes out_dtypes,
Args... inputs) {
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
compute_fun, ctx, out, out_dtypes, inputs...);
if (!inputs_valid) {
return;
}

constexpr auto kNumInputs = sizeof...(inputs);

struct InputInfo {
Expand Down Expand Up @@ -138,6 +170,63 @@ inline void apply_elementwise_fn(
});
}

template <
typename CTYPE_COMPUTE,
const char* op_name,
typename Op,
typename... Args>
inline void apply_elementwise_fn_runtime_out_dtypes(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& out,
SupportedTensorDtypes out_dtypes,
Args... inputs) {
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
compute_fun, ctx, out, out_dtypes, inputs...);
if (!inputs_valid) {
return;
}

apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
compute_fun, ctx, out, out_dtypes, inputs...);
}

template <
typename CTYPE_COMPUTE,
const char* op_name,
SupportedTensorDtypes out_dtypes,
typename Op,
typename... Args>
inline void apply_elementwise_fn(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& out,
Args... inputs) {
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
compute_fun, ctx, out, out_dtypes, inputs...);
if (!inputs_valid) {
return;
}

constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
const bool all_inputs_compute_dtype =
((inputs.first->scalar_type() == compute_type) && ...);

constexpr ScalarType out_specialized_scalar_type =
specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here, CTYPE_COMPUTE

if (all_inputs_compute_dtype &&
out.scalar_type() == out_specialized_scalar_type) {
using CTYPE_OUT =
typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
dtype_specialized_elementwise_fn_impl<CTYPE_COMPUTE, CTYPE_OUT>(
compute_fun, ctx, out, inputs...);
return;
}

apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
compute_fun, ctx, out, out_dtypes, inputs...);
}

/// DEPRECATED: prefer the variant with out_dtypes in the template argument.
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
inline void apply_unitensor_elementwise_fn(
Expand All @@ -147,7 +236,7 @@ inline void apply_unitensor_elementwise_fn(
SupportedTensorDtypes a_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
}

Expand All @@ -162,8 +251,8 @@ inline void apply_unitensor_elementwise_fn(
const Tensor& a,
SupportedTensorDtypes a_dtypes,
const Tensor& out) {
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
compute_fun, ctx, out, std::make_pair(&a, a_dtypes));
}

/**
Expand All @@ -179,7 +268,7 @@ inline void apply_bitensor_elementwise_fn(
SupportedTensorDtypes b_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
compute_fun,
ctx,
out,
Expand All @@ -206,11 +295,10 @@ inline void apply_bitensor_elementwise_fn(
const Tensor& b,
SupportedTensorDtypes b_dtypes,
const Tensor& out) {
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
compute_fun,
ctx,
out,
out_dtypes,
std::make_pair(&a, a_dtypes),
std::make_pair(&b, b_dtypes));
}
Expand All @@ -230,7 +318,7 @@ inline void apply_tritensor_elementwise_fn(
SupportedTensorDtypes c_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
compute_fun,
ctx,
out,
Expand Down Expand Up @@ -275,11 +363,10 @@ inline void apply_tritensor_elementwise_fn(
const Tensor& c,
SupportedTensorDtypes c_dtypes,
const Tensor& out) {
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
compute_fun,
ctx,
out,
out_dtypes,
std::make_pair(&a, a_dtypes),
std::make_pair(&b, b_dtypes),
std::make_pair(&c, c_dtypes));
Expand Down
Loading