From e4d3203ee5dc97cdbe8eb854da7e53fab5b4571b Mon Sep 17 00:00:00 2001 From: Jack <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 23 Apr 2025 14:26:18 -0700 Subject: [PATCH] =?UTF-8?q?Revert=20"Migrate=20elementwise=5Futil=20caller?= =?UTF-8?q?s=20to=20the=20variants=20with=20out=5Fdtypes=20in=20t=E2=80=A6?= =?UTF-8?q?"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit b01c7de65d9e5faf03e69afb466d5da4282d355d. --- kernels/portable/cpu/op_add.cpp | 20 +++++-------- kernels/portable/cpu/op_addmm.cpp | 10 +++---- kernels/portable/cpu/op_atan2.cpp | 10 +++---- kernels/portable/cpu/op_clamp.cpp | 18 ++++-------- kernels/portable/cpu/op_copy.cpp | 20 +++++-------- kernels/portable/cpu/op_div.cpp | 31 ++++++++------------ kernels/portable/cpu/op_elu.cpp | 11 +++---- kernels/portable/cpu/op_floor_divide.cpp | 9 ++---- kernels/portable/cpu/op_fmod.cpp | 18 ++++-------- kernels/portable/cpu/op_maximum.cpp | 8 ++--- kernels/portable/cpu/op_minimum.cpp | 9 ++---- kernels/portable/cpu/op_mul.cpp | 10 +++---- kernels/portable/cpu/op_pow.cpp | 27 ++++++----------- kernels/portable/cpu/op_remainder.cpp | 18 ++++-------- kernels/portable/cpu/op_rsub.cpp | 10 +++---- kernels/portable/cpu/op_sigmoid.cpp | 11 +++---- kernels/portable/cpu/op_sub.cpp | 20 +++++-------- kernels/portable/cpu/op_where.cpp | 14 ++++----- kernels/portable/cpu/pattern/bitwise_op.h | 18 ++++-------- kernels/portable/cpu/pattern/comparison_op.h | 18 ++++-------- kernels/portable/cpu/pattern/logical_op.h | 9 ++---- 21 files changed, 118 insertions(+), 201 deletions(-) diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp index 555341b3447..adb9d4ea723 100644 --- a/kernels/portable/cpu/op_add.cpp +++ b/kernels/portable/cpu/op_add.cpp @@ -52,11 +52,8 @@ Tensor& add_out( ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( - [val_alpha](const auto val_a, const auto val_b) { + utils::apply_bitensor_elementwise_fn( + [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { return val_a + val_alpha * val_b; }, ctx, @@ -64,7 +61,8 @@ Tensor& add_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; @@ -102,11 +100,8 @@ Tensor& add_scalar_out( static constexpr const char op_name[] = "add.Scalar_out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::SAME_AS_COMMON>( - [b, alpha](const auto val_a) { + utils::apply_unitensor_elementwise_fn( + [b, alpha](const CTYPE_COMPUTE val_a) { CTYPE_COMPUTE val_b = utils::scalar_to(b); CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); return val_a + val_alpha * val_b; @@ -114,7 +109,8 @@ Tensor& add_scalar_out( ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/op_addmm.cpp b/kernels/portable/cpu/op_addmm.cpp index 440a8b2c0fa..d1df5818cd8 100644 --- a/kernels/portable/cpu/op_addmm.cpp +++ b/kernels/portable/cpu/op_addmm.cpp @@ -88,11 +88,8 @@ Tensor& addmm_out( n, p); - utils::apply_bitensor_elementwise_fn< - CTYPE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( - [alpha_val, beta_val](const auto val_a, const auto val_b) { + utils::apply_bitensor_elementwise_fn( + [alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) { return val_a * alpha_val + val_b * beta_val; }, ctx, @@ -100,7 +97,8 @@ Tensor& addmm_out( utils::SupportedTensorDtypes::REALHBF16, in, utils::SupportedTensorDtypes::REALHBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); } }); diff --git a/kernels/portable/cpu/op_atan2.cpp b/kernels/portable/cpu/op_atan2.cpp index 33d66cf2ad7..19267ef49dd 100644 --- a/kernels/portable/cpu/op_atan2.cpp +++ b/kernels/portable/cpu/op_atan2.cpp @@ -55,11 +55,8 @@ Tensor& atan2_out( static constexpr const char op_name[] = "atan2.out"; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::FLOATHBF16>( - [](const auto val_a, const auto val_b) { + utils::apply_bitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { return std::atan2(val_a, val_b); }, ctx, @@ -67,7 +64,8 @@ Tensor& atan2_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::FLOATHBF16); }); return out; diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index 6974789eccf..c1c40a38f34 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -134,12 +134,8 @@ Tensor& clamp_out( static constexpr const char op_name[] = "clamp.out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::SAME_AS_COMMON>( + utils::apply_unitensor_elementwise_fn( [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) { - // TODO: rewrite this to be vectorization-capable. CTYPE_COMPUTE val_out = val_in; if (has_min) { val_out = utils::max_override( @@ -154,7 +150,8 @@ Tensor& clamp_out( ctx, in, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; @@ -213,15 +210,11 @@ Tensor& clamp_tensor_out( static constexpr const char op_name[] = "clamp.Tensor_out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_tritensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( + utils::apply_tritensor_elementwise_fn( [has_min, has_max]( const CTYPE_COMPUTE val_in, const CTYPE_COMPUTE val_min, const CTYPE_COMPUTE val_max) { - // TODO: rewrite this to be vectorization-capable. CTYPE_COMPUTE val_out = val_in; if (has_min) { val_out = utils::max_override(val_out, val_min); @@ -238,7 +231,8 @@ Tensor& clamp_tensor_out( utils::SupportedTensorDtypes::REALHBBF16, max, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp index 30fff4d2c10..19b0c3a2f6a 100644 --- a/kernels/portable/cpu/op_copy.cpp +++ b/kernels/portable/cpu/op_copy.cpp @@ -47,17 +47,15 @@ Tensor& copy_out( static constexpr const char op_name[] = "copy.out"; ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( - [](ET_UNUSED const auto _, const auto val_src) { return val_src; }, + utils::apply_bitensor_elementwise_fn( + [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; }, ctx, in, utils::SupportedTensorDtypes::REALHBBF16, src, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; @@ -82,17 +80,15 @@ Tensor& copy_( static constexpr const char op_name[] = "copy_"; ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( - [](ET_UNUSED const auto _, const auto val_src) { return val_src; }, + utils::apply_bitensor_elementwise_fn( + [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; }, ctx, in, utils::SupportedTensorDtypes::REALHBBF16, src, utils::SupportedTensorDtypes::REALHBBF16, - in); + in, + utils::SupportedTensorDtypes::REALHBBF16); }); return in; diff --git a/kernels/portable/cpu/op_div.cpp b/kernels/portable/cpu/op_div.cpp index 70f9479c464..94cd9ea5011 100644 --- a/kernels/portable/cpu/op_div.cpp +++ b/kernels/portable/cpu/op_div.cpp @@ -58,17 +58,17 @@ Tensor& div_out( static constexpr const char op_name[] = "div.out"; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::FLOATHBF16>( - [](const auto val_a, const auto val_b) { return val_a / val_b; }, + utils::apply_bitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a / val_b; + }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::FLOATHBF16); }); return out; @@ -122,13 +122,9 @@ Tensor& div_out_mode( bool div_by_zero_error = false; ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( + utils::apply_bitensor_elementwise_fn( [mode_is_trunc, &div_by_zero_error]( const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - // TODO: rewrite this to be vectorization-capable. if (is_integral_type::value) { if (val_b == 0) { div_by_zero_error = true; @@ -150,7 +146,8 @@ Tensor& div_out_mode( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); }); ET_KERNEL_CHECK_MSG( @@ -191,15 +188,13 @@ Tensor& div_scalar_out( ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::SAME_AS_COMMON>( - [val_b](const auto val_a) { return val_a / val_b; }, + utils::apply_unitensor_elementwise_fn( + [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/op_elu.cpp b/kernels/portable/cpu/op_elu.cpp index d6533642860..d4846fb1bfb 100644 --- a/kernels/portable/cpu/op_elu.cpp +++ b/kernels/portable/cpu/op_elu.cpp @@ -44,12 +44,8 @@ Tensor& elu_out( ET_EXTRACT_SCALAR(scale, math_scale); ET_EXTRACT_SCALAR(input_scale, math_input_scale); const auto negcoef = math_alpha * math_scale; - utils::apply_unitensor_elementwise_fn< - CTYPE, - op_name, - utils::SupportedTensorDtypes::SAME_AS_COMMON>( - [negcoef, math_scale, math_input_scale](const auto x) { - // TODO: rewrite this to be vectorization-capable. + utils::apply_unitensor_elementwise_fn( + [negcoef, math_scale, math_input_scale](auto x) { return MathT(x) <= MathT(0) ? std::expm1(MathT(x) * math_input_scale) * negcoef : MathT(x) * math_scale; @@ -57,7 +53,8 @@ Tensor& elu_out( ctx, in, utils::SupportedTensorDtypes::FLOATHBF16, - out); + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; } diff --git a/kernels/portable/cpu/op_floor_divide.cpp b/kernels/portable/cpu/op_floor_divide.cpp index 50723c3fa0a..85eb612ea1e 100644 --- a/kernels/portable/cpu/op_floor_divide.cpp +++ b/kernels/portable/cpu/op_floor_divide.cpp @@ -53,13 +53,9 @@ Tensor& floor_divide_out( bool div_by_zero_error = false; ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( + utils::apply_bitensor_elementwise_fn( [&div_by_zero_error]( const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - // TODO: rewrite this to be vectorization-capable. if (is_integral_type::value) { if (val_b == 0) { div_by_zero_error = true; @@ -73,7 +69,8 @@ Tensor& floor_divide_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); }); ET_KERNEL_CHECK_MSG( diff --git a/kernels/portable/cpu/op_fmod.cpp b/kernels/portable/cpu/op_fmod.cpp index 96a971b166a..1e8cba0f1ae 100644 --- a/kernels/portable/cpu/op_fmod.cpp +++ b/kernels/portable/cpu/op_fmod.cpp @@ -55,13 +55,9 @@ Tensor& fmod_Tensor_out( bool div_by_zero_error = false; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( + utils::apply_bitensor_elementwise_fn( [&div_by_zero_error]( const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - // TODO: rewrite this to be vectorization-capable. CTYPE_COMPUTE value = 0; if (is_integral_type::value) { if (val_b == 0) { @@ -77,7 +73,8 @@ Tensor& fmod_Tensor_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); }); ET_KERNEL_CHECK_MSG( @@ -134,19 +131,16 @@ Tensor& fmod_Scalar_out( ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( + utils::apply_unitensor_elementwise_fn( [val_b](const CTYPE_COMPUTE val_a) { - // TODO: rewrite this to be vectorization-capable. CTYPE_COMPUTE value = std::fmod(val_a, val_b); return value; }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); }); return out; diff --git a/kernels/portable/cpu/op_maximum.cpp b/kernels/portable/cpu/op_maximum.cpp index 3a84095a4df..5cf3b5a19f8 100644 --- a/kernels/portable/cpu/op_maximum.cpp +++ b/kernels/portable/cpu/op_maximum.cpp @@ -45,10 +45,7 @@ Tensor& maximum_out( static constexpr const char op_name[] = "maximum.out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( + utils::apply_bitensor_elementwise_fn( [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { return utils::max_override(val_a, val_b); }, @@ -57,7 +54,8 @@ Tensor& maximum_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; diff --git a/kernels/portable/cpu/op_minimum.cpp b/kernels/portable/cpu/op_minimum.cpp index 5c0e79eb9bb..e2c641bdb22 100644 --- a/kernels/portable/cpu/op_minimum.cpp +++ b/kernels/portable/cpu/op_minimum.cpp @@ -45,12 +45,8 @@ Tensor& minimum_out( static constexpr const char op_name[] = "minimum.out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( + utils::apply_bitensor_elementwise_fn( [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - // TODO: rewrite this to be vectorization-capable. return utils::min_override(val_a, val_b); }, ctx, @@ -58,7 +54,8 @@ Tensor& minimum_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp index 6156227732d..114e60ff171 100644 --- a/kernels/portable/cpu/op_mul.cpp +++ b/kernels/portable/cpu/op_mul.cpp @@ -97,15 +97,13 @@ Tensor& mul_scalar_out( ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::SAME_AS_COMMON>( - [val_b](const auto val_a) { return val_a * val_b; }, + utils::apply_unitensor_elementwise_fn( + [val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/op_pow.cpp b/kernels/portable/cpu/op_pow.cpp index 4d2673cb72d..81319b03d9f 100644 --- a/kernels/portable/cpu/op_pow.cpp +++ b/kernels/portable/cpu/op_pow.cpp @@ -53,12 +53,8 @@ Tensor& pow_Tensor_Tensor_out( static constexpr const char op_name[] = "pow.Tensor_Tensor_out"; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( + utils::apply_bitensor_elementwise_fn( [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - // TODO: rewrite this to be vectorization-capable. return std::pow(val_a, val_b); }, ctx, @@ -66,7 +62,8 @@ Tensor& pow_Tensor_Tensor_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); }); return out; @@ -107,16 +104,13 @@ Tensor& pow_Tensor_Scalar_out( ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( - // TODO: rewrite this to be vectorization-capable. + utils::apply_unitensor_elementwise_fn( [val_b](const CTYPE_COMPUTE val_a) { return std::pow(val_a, val_b); }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); }); return out; @@ -157,16 +151,13 @@ Tensor& pow_Scalar_out( ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_a = utils::scalar_to(a); - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( - // TODO: rewrite this to be vectorization-capable. + utils::apply_unitensor_elementwise_fn( [val_a](const CTYPE_COMPUTE val_b) { return std::pow(val_a, val_b); }, ctx, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); }); return out; diff --git a/kernels/portable/cpu/op_remainder.cpp b/kernels/portable/cpu/op_remainder.cpp index 01a5d72de01..d34c34a0380 100644 --- a/kernels/portable/cpu/op_remainder.cpp +++ b/kernels/portable/cpu/op_remainder.cpp @@ -53,13 +53,9 @@ Tensor& remainder_Tensor_out( bool div_by_zero_error = false; ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( + utils::apply_bitensor_elementwise_fn( [&div_by_zero_error]( const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - // TODO: rewrite this to be vectorization-capable. CTYPE_COMPUTE value = 0; if (is_integral_type::value) { if (val_b == 0) { @@ -75,7 +71,8 @@ Tensor& remainder_Tensor_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); }); ET_KERNEL_CHECK_MSG( @@ -129,18 +126,15 @@ Tensor& remainder_Scalar_out( ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( + utils::apply_unitensor_elementwise_fn( [val_b](const CTYPE_COMPUTE val_a) { - // TODO: rewrite this to be vectorization-capable. return utils::remainder_override(val_a, val_b); }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); }); return out; diff --git a/kernels/portable/cpu/op_rsub.cpp b/kernels/portable/cpu/op_rsub.cpp index 6a0a77b6596..46af021efda 100644 --- a/kernels/portable/cpu/op_rsub.cpp +++ b/kernels/portable/cpu/op_rsub.cpp @@ -52,17 +52,15 @@ Tensor& rsub_scalar_out( ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::SAME_AS_COMMON>( - [val_b, val_alpha](const auto val_a) { + utils::apply_unitensor_elementwise_fn( + [val_b, val_alpha](const CTYPE_COMPUTE val_a) { return val_b - val_alpha * val_a; }, ctx, a, utils::SupportedTensorDtypes::REALHBF16, - out); + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/op_sigmoid.cpp b/kernels/portable/cpu/op_sigmoid.cpp index acb743a2db6..09cfed524f9 100644 --- a/kernels/portable/cpu/op_sigmoid.cpp +++ b/kernels/portable/cpu/op_sigmoid.cpp @@ -45,12 +45,8 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { static constexpr const char op_name[] = "sigmoid.out"; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::FLOATHBF16>( - [](const auto val_in) -> CTYPE_COMPUTE { - // TODO: rewrite this to be vectorization-capable + utils::apply_unitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_in) { CTYPE_COMPUTE out_val = static_cast(1.0) / (static_cast(1.0) + exp(-val_in)); return out_val; @@ -58,7 +54,8 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { ctx, in, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::FLOATHBF16); }); return out; diff --git a/kernels/portable/cpu/op_sub.cpp b/kernels/portable/cpu/op_sub.cpp index aa90df8dee4..6217f82c3b1 100644 --- a/kernels/portable/cpu/op_sub.cpp +++ b/kernels/portable/cpu/op_sub.cpp @@ -56,11 +56,8 @@ Tensor& sub_out( ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBF16>( - [val_alpha](const auto val_a, const auto val_b) { + utils::apply_bitensor_elementwise_fn( + [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { return val_a - val_alpha * val_b; }, ctx, @@ -68,7 +65,8 @@ Tensor& sub_out( utils::SupportedTensorDtypes::REALHBF16, b, utils::SupportedTensorDtypes::REALHBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBF16); }); return out; @@ -112,17 +110,15 @@ Tensor& sub_scalar_out( ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::SAME_AS_COMMON>( - [val_b, val_alpha](const auto val_a) { + utils::apply_unitensor_elementwise_fn( + [val_b, val_alpha](const CTYPE_COMPUTE val_a) { return val_a - val_alpha * val_b; }, ctx, a, utils::SupportedTensorDtypes::REALHBF16, - out); + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp index 692e296ee00..b455c45c2d1 100644 --- a/kernels/portable/cpu/op_where.cpp +++ b/kernels/portable/cpu/op_where.cpp @@ -43,13 +43,10 @@ Tensor& where_out( static constexpr const char op_name[] = "where.self_out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_tritensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::SAME_AS_COMMON>( - [](const auto val_a, const auto val_b, const auto val_c) { - return val_c ? val_a : val_b; - }, + utils::apply_tritensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, + const CTYPE_COMPUTE val_b, + const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, @@ -57,7 +54,8 @@ Tensor& where_out( utils::SupportedTensorDtypes::REALHBBF16, cond, utils::SupportedTensorDtypes::BOOL_OR_BYTE, - out); + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/pattern/bitwise_op.h b/kernels/portable/cpu/pattern/bitwise_op.h index f78ce796e6c..6e4c111b8f2 100644 --- a/kernels/portable/cpu/pattern/bitwise_op.h +++ b/kernels/portable/cpu/pattern/bitwise_op.h @@ -80,18 +80,15 @@ Tensor& bitwise_tensor_out( ET_SWITCH_INT_TYPES_AND( Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( - // TODO: rewrite this to be vectorization-capable. + utils::apply_bitensor_elementwise_fn( BitwiseFnForOp::value, ctx, a, utils::SupportedTensorDtypes::INTB, b, utils::SupportedTensorDtypes::INTB, - out); + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; @@ -124,19 +121,16 @@ Tensor& bitwise_scalar_out( ET_SWITCH_INT_TYPES_AND( Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( + utils::apply_unitensor_elementwise_fn( [val_b](const CTYPE_COMPUTE val_a) { - // TODO: rewrite this to be vectorization-capable. return BitwiseFnForOp::value( val_a, val_b); }, ctx, a, utils::SupportedTensorDtypes::INTB, - out); + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; diff --git a/kernels/portable/cpu/pattern/comparison_op.h b/kernels/portable/cpu/pattern/comparison_op.h index 643d7623922..e0d9bf4dcab 100644 --- a/kernels/portable/cpu/pattern/comparison_op.h +++ b/kernels/portable/cpu/pattern/comparison_op.h @@ -91,18 +91,15 @@ Tensor& comparison_tensor_out( ScalarType compute_type = utils::get_compute_type(common_type); ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( - // TODO: rewrite this to be vectorization-capable. + utils::apply_bitensor_elementwise_fn( ComparisonFnForOp::value, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; @@ -130,18 +127,15 @@ Tensor& comparison_scalar_out( ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn< - CTYPE_COMPUTE, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( + utils::apply_unitensor_elementwise_fn( [val_b](const CTYPE_COMPUTE val_a) { - // TODO: rewrite this to be vectorization-capable. return ComparisonFnForOp::value(val_a, val_b); }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; diff --git a/kernels/portable/cpu/pattern/logical_op.h b/kernels/portable/cpu/pattern/logical_op.h index 4547d3df51b..017822a85a6 100644 --- a/kernels/portable/cpu/pattern/logical_op.h +++ b/kernels/portable/cpu/pattern/logical_op.h @@ -34,18 +34,15 @@ Tensor& logical_tensor_out( InvalidArgument, out); - utils::apply_bitensor_elementwise_fn< - bool, - op_name, - utils::SupportedTensorDtypes::REALHBBF16>( - // TODO: rewrite this to be vectorization-capable. + utils::apply_bitensor_elementwise_fn( fn, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out); + out, + utils::SupportedTensorDtypes::REALHBBF16); return out; }