Skip to content

Refactor elementwise_util: create variants with out_dtypes in template argument list #9387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 31 commits into
base: gh/swolchok/417/head
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 100 additions & 13 deletions kernels/portable/cpu/util/elementwise_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,8 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
}

namespace internal {
template <
typename CTYPE_COMPUTE,
const char* op_name,
typename Op,
typename... Args>
inline void apply_elementwise_fn(
template <typename CTYPE_COMPUTE, typename Op, typename... Args>
inline bool validate_elementwise_fn_inputs(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& out,
Expand All @@ -65,7 +61,6 @@ inline void apply_elementwise_fn(
static_assert(
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
...));
constexpr auto kNumInputs = sizeof...(inputs);
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
const auto check_input_dtype = [](auto input, auto compute_type) {
return internal::check_tensor_dtype(
Expand All @@ -75,7 +70,30 @@ inline void apply_elementwise_fn(
ctx,
(check_input_dtype(inputs, compute_type) && ...) &&
internal::check_tensor_dtype(out, out_dtypes, compute_type),
InvalidArgument, );
InvalidArgument,
false);

return true;
}

template <
typename CTYPE_COMPUTE,
const char* op_name,
typename Op,
typename... Args>
inline void apply_elementwise_fn(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& out,
SupportedTensorDtypes out_dtypes,
Args... inputs) {
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
compute_fun, ctx, out, out_dtypes, inputs...);
if (!inputs_valid) {
return;
}

constexpr auto kNumInputs = sizeof...(inputs);

struct InputInfo {
load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
Expand Down Expand Up @@ -120,6 +138,7 @@ inline void apply_elementwise_fn(
});
}

/// DEPRECATED: prefer the variant with out_dtypes in the template argument list.
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
inline void apply_unitensor_elementwise_fn(
const Op& compute_fun,
Expand All @@ -132,19 +151,83 @@ inline void apply_unitensor_elementwise_fn(
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
}

/**
 * Single-input elementwise entry point that takes out_dtypes as a template
 * argument rather than as a trailing function argument.
 *
 * Pairs the input tensor `a` with `a_dtypes` and forwards everything to
 * internal::apply_elementwise_fn, which still receives out_dtypes as an
 * ordinary function argument.
 */
template <
    typename CTYPE_COMPUTE,
    const char* op_name,
    SupportedTensorDtypes out_dtypes,
    typename Op>
inline void apply_unitensor_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& a,
    SupportedTensorDtypes a_dtypes,
    const Tensor& out) {
  // Each input travels as a (tensor pointer, supported dtypes) pair.
  const auto a_input = std::make_pair(&a, a_dtypes);
  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
      compute_fun, ctx, out, out_dtypes, a_input);
}

/**
 * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
 *
 * Two-input elementwise entry point that receives out_dtypes as its last
 * function argument. Pairs each input tensor with its supported dtype set and
 * forwards to internal::apply_elementwise_fn.
 */
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
inline void apply_bitensor_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& a,
    SupportedTensorDtypes a_dtypes,
    const Tensor& b,
    SupportedTensorDtypes b_dtypes,
    const Tensor& out,
    SupportedTensorDtypes out_dtypes) {
  // Each input travels as a (tensor pointer, supported dtypes) pair; the
  // pairs are passed in the same a-then-b order as the parameters.
  const auto a_input = std::make_pair(&a, a_dtypes);
  const auto b_input = std::make_pair(&b, b_dtypes);
  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
      compute_fun, ctx, out, out_dtypes, a_input, b_input);
}

/**
* Useful for bi-tensor elementwise operators. For each element of the inputs,
* perform a computation and write to the corresponding element of the output.
* Tensor broadcasting is applied wherever it is required.
*/
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
template <
    typename CTYPE_COMPUTE,
    const char* op_name,
    SupportedTensorDtypes out_dtypes,
    typename Op>
inline void apply_bitensor_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
    const Tensor& a,
    SupportedTensorDtypes a_dtypes,
    const Tensor& b,
    SupportedTensorDtypes b_dtypes,
    const Tensor& out) {
  // out_dtypes arrives as a template argument here, but the shared
  // implementation still takes it as a regular function argument.
  // Each input travels as a (tensor pointer, supported dtypes) pair, passed
  // in the same a-then-b order as the parameters.
  const auto a_input = std::make_pair(&a, a_dtypes);
  const auto b_input = std::make_pair(&b, b_dtypes);
  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
      compute_fun, ctx, out, out_dtypes, a_input, b_input);
}

/**
* DEPRECATED: prefer the variant with out_dtypes in the template argument list.
*/
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
inline void apply_tritensor_elementwise_fn(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& a,
SupportedTensorDtypes a_dtypes,
const Tensor& b,
SupportedTensorDtypes b_dtypes,
const Tensor& c,
SupportedTensorDtypes c_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
Expand All @@ -153,7 +236,8 @@ inline void apply_bitensor_elementwise_fn(
out,
out_dtypes,
std::make_pair(&a, a_dtypes),
std::make_pair(&b, b_dtypes));
std::make_pair(&b, b_dtypes),
std::make_pair(&c, c_dtypes));
}

/**
Expand All @@ -176,7 +260,11 @@ inline void apply_bitensor_elementwise_fn(
* static constexpr const char op_name[] = "my_op";
* apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>.
*/
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
template <
typename CTYPE_COMPUTE,
const char* op_name,
SupportedTensorDtypes out_dtypes,
typename Op>
inline void apply_tritensor_elementwise_fn(
const Op& compute_fun,
KernelRuntimeContext& ctx,
Expand All @@ -186,8 +274,7 @@ inline void apply_tritensor_elementwise_fn(
SupportedTensorDtypes b_dtypes,
const Tensor& c,
SupportedTensorDtypes c_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
const Tensor& out) {
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
compute_fun,
ctx,
Expand Down
Loading