diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp index adb9d4ea723..555341b3447 100644 --- a/kernels/portable/cpu/op_add.cpp +++ b/kernels/portable/cpu/op_add.cpp @@ -52,8 +52,11 @@ Tensor& add_out( ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_bitensor_elementwise_fn( - [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( + [val_alpha](const auto val_a, const auto val_b) { return val_a + val_alpha * val_b; }, ctx, @@ -61,8 +64,7 @@ Tensor& add_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); }); return out; @@ -100,8 +102,11 @@ Tensor& add_scalar_out( static constexpr const char op_name[] = "add.Scalar_out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_unitensor_elementwise_fn( - [b, alpha](const CTYPE_COMPUTE val_a) { + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::SAME_AS_COMMON>( + [b, alpha](const auto val_a) { CTYPE_COMPUTE val_b = utils::scalar_to(b); CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); return val_a + val_alpha * val_b; @@ -109,8 +114,7 @@ Tensor& add_scalar_out( ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); + out); }); return out; diff --git a/kernels/portable/cpu/op_addmm.cpp b/kernels/portable/cpu/op_addmm.cpp index d1df5818cd8..440a8b2c0fa 100644 --- a/kernels/portable/cpu/op_addmm.cpp +++ b/kernels/portable/cpu/op_addmm.cpp @@ -88,8 +88,11 @@ Tensor& addmm_out( n, p); - utils::apply_bitensor_elementwise_fn( - [alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) { + utils::apply_bitensor_elementwise_fn< + CTYPE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( + [alpha_val, beta_val](const auto val_a, const auto val_b) { return val_a * alpha_val + val_b * beta_val; }, ctx, @@ -97,8 +100,7 @@ Tensor& addmm_out( utils::SupportedTensorDtypes::REALHBF16, in, utils::SupportedTensorDtypes::REALHBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); } }); diff --git a/kernels/portable/cpu/op_atan2.cpp b/kernels/portable/cpu/op_atan2.cpp index 19267ef49dd..33d66cf2ad7 100644 --- a/kernels/portable/cpu/op_atan2.cpp +++ b/kernels/portable/cpu/op_atan2.cpp @@ -55,8 +55,11 @@ Tensor& atan2_out( static constexpr const char op_name[] = "atan2.out"; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( - [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::FLOATHBF16>( + [](const auto val_a, const auto val_b) { return std::atan2(val_a, val_b); }, ctx, @@ -64,8 +67,7 @@ Tensor& atan2_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::FLOATHBF16); + out); }); return out; diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index c1c40a38f34..6974789eccf 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -134,8 +134,12 @@ Tensor& clamp_out( static constexpr const char op_name[] = "clamp.out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_unitensor_elementwise_fn( + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::SAME_AS_COMMON>( [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) { + // TODO: rewrite this to be vectorization-capable. CTYPE_COMPUTE val_out = val_in; if (has_min) { val_out = utils::max_override( @@ -150,8 +154,7 @@ Tensor& clamp_out( ctx, in, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); + out); }); return out; @@ -210,11 +213,15 @@ Tensor& clamp_tensor_out( static constexpr const char op_name[] = "clamp.Tensor_out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_tritensor_elementwise_fn( + utils::apply_tritensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( [has_min, has_max]( const CTYPE_COMPUTE val_in, const CTYPE_COMPUTE val_min, const CTYPE_COMPUTE val_max) { + // TODO: rewrite this to be vectorization-capable. CTYPE_COMPUTE val_out = val_in; if (has_min) { val_out = utils::max_override(val_out, val_min); @@ -231,8 +238,7 @@ Tensor& clamp_tensor_out( utils::SupportedTensorDtypes::REALHBBF16, max, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); }); return out; diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp index 19b0c3a2f6a..30fff4d2c10 100644 --- a/kernels/portable/cpu/op_copy.cpp +++ b/kernels/portable/cpu/op_copy.cpp @@ -47,15 +47,17 @@ Tensor& copy_out( static constexpr const char op_name[] = "copy.out"; ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() { - utils::apply_bitensor_elementwise_fn( - [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; }, + utils::apply_bitensor_elementwise_fn< + CTYPE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( + [](ET_UNUSED const auto _, const auto val_src) { return val_src; }, ctx, in, utils::SupportedTensorDtypes::REALHBBF16, src, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); }); return out; @@ -80,15 +82,17 @@ Tensor& copy_( static constexpr const char op_name[] = "copy_"; ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() { - utils::apply_bitensor_elementwise_fn( - [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; }, + utils::apply_bitensor_elementwise_fn< + CTYPE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( + [](ET_UNUSED const auto _, const auto val_src) { return val_src; }, ctx, in, utils::SupportedTensorDtypes::REALHBBF16, src, utils::SupportedTensorDtypes::REALHBBF16, - in, - utils::SupportedTensorDtypes::REALHBBF16); + in); }); return in; diff --git a/kernels/portable/cpu/op_div.cpp b/kernels/portable/cpu/op_div.cpp index 94cd9ea5011..70f9479c464 100644 --- a/kernels/portable/cpu/op_div.cpp +++ b/kernels/portable/cpu/op_div.cpp @@ -58,17 +58,17 @@ Tensor& div_out( static constexpr const char op_name[] = "div.out"; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( - [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - return val_a / val_b; - }, + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::FLOATHBF16>( + [](const auto val_a, const auto val_b) { return val_a / val_b; }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::FLOATHBF16); + out); }); return out; @@ -122,9 +122,13 @@ Tensor& div_out_mode( bool div_by_zero_error = false; ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( [mode_is_trunc, &div_by_zero_error]( const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + // TODO: rewrite this to be vectorization-capable. if (is_integral_type::value) { if (val_b == 0) { div_by_zero_error = true; @@ -146,8 +150,7 @@ Tensor& div_out_mode( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); }); ET_KERNEL_CHECK_MSG( @@ -188,13 +191,15 @@ Tensor& div_scalar_out( ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn( - [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; }, + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::SAME_AS_COMMON>( + [val_b](const auto val_a) { return val_a / val_b; }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); + out); }); return out; diff --git a/kernels/portable/cpu/op_elu.cpp b/kernels/portable/cpu/op_elu.cpp index d4846fb1bfb..d6533642860 100644 --- a/kernels/portable/cpu/op_elu.cpp +++ b/kernels/portable/cpu/op_elu.cpp @@ -44,8 +44,12 @@ Tensor& elu_out( ET_EXTRACT_SCALAR(scale, math_scale); ET_EXTRACT_SCALAR(input_scale, math_input_scale); const auto negcoef = math_alpha * math_scale; - utils::apply_unitensor_elementwise_fn( - [negcoef, math_scale, math_input_scale](auto x) { + utils::apply_unitensor_elementwise_fn< + CTYPE, + op_name, + utils::SupportedTensorDtypes::SAME_AS_COMMON>( + [negcoef, math_scale, math_input_scale](const auto x) { + // TODO: rewrite this to be vectorization-capable. return MathT(x) <= MathT(0) ? std::expm1(MathT(x) * math_input_scale) * negcoef : MathT(x) * math_scale; @@ -53,8 +57,7 @@ Tensor& elu_out( ctx, in, utils::SupportedTensorDtypes::FLOATHBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); + out); }); return out; } diff --git a/kernels/portable/cpu/op_floor_divide.cpp b/kernels/portable/cpu/op_floor_divide.cpp index 85eb612ea1e..50723c3fa0a 100644 --- a/kernels/portable/cpu/op_floor_divide.cpp +++ b/kernels/portable/cpu/op_floor_divide.cpp @@ -53,9 +53,13 @@ Tensor& floor_divide_out( bool div_by_zero_error = false; ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( [&div_by_zero_error]( const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + // TODO: rewrite this to be vectorization-capable. if (is_integral_type::value) { if (val_b == 0) { div_by_zero_error = true; @@ -69,8 +73,7 @@ Tensor& floor_divide_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); }); ET_KERNEL_CHECK_MSG( diff --git a/kernels/portable/cpu/op_fmod.cpp b/kernels/portable/cpu/op_fmod.cpp index 1e8cba0f1ae..96a971b166a 100644 --- a/kernels/portable/cpu/op_fmod.cpp +++ b/kernels/portable/cpu/op_fmod.cpp @@ -55,9 +55,13 @@ Tensor& fmod_Tensor_out( bool div_by_zero_error = false; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( [&div_by_zero_error]( const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + // TODO: rewrite this to be vectorization-capable. CTYPE_COMPUTE value = 0; if (is_integral_type::value) { if (val_b == 0) { @@ -73,8 +77,7 @@ Tensor& fmod_Tensor_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); }); ET_KERNEL_CHECK_MSG( @@ -131,16 +134,19 @@ Tensor& fmod_Scalar_out( ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn( + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( [val_b](const CTYPE_COMPUTE val_a) { + // TODO: rewrite this to be vectorization-capable. CTYPE_COMPUTE value = std::fmod(val_a, val_b); return value; }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); }); return out; diff --git a/kernels/portable/cpu/op_maximum.cpp b/kernels/portable/cpu/op_maximum.cpp index 5cf3b5a19f8..3a84095a4df 100644 --- a/kernels/portable/cpu/op_maximum.cpp +++ b/kernels/portable/cpu/op_maximum.cpp @@ -45,7 +45,10 @@ Tensor& maximum_out( static constexpr const char op_name[] = "maximum.out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { return utils::max_override(val_a, val_b); }, @@ -54,8 +57,7 @@ Tensor& maximum_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); }); return out; diff --git a/kernels/portable/cpu/op_minimum.cpp b/kernels/portable/cpu/op_minimum.cpp index e2c641bdb22..5c0e79eb9bb 100644 --- a/kernels/portable/cpu/op_minimum.cpp +++ b/kernels/portable/cpu/op_minimum.cpp @@ -45,8 +45,12 @@ Tensor& minimum_out( static constexpr const char op_name[] = "minimum.out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + // TODO: rewrite this to be vectorization-capable. return utils::min_override(val_a, val_b); }, ctx, @@ -54,8 +58,7 @@ Tensor& minimum_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); }); return out; diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp index 1ee73d342ca..6156227732d 100644 --- a/kernels/portable/cpu/op_mul.cpp +++ b/kernels/portable/cpu/op_mul.cpp @@ -52,7 +52,10 @@ Tensor& mul_out( out); ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { return val_a * val_b; }, @@ -61,8 +64,7 @@ Tensor& mul_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); }); return out; @@ -95,13 +97,15 @@ Tensor& mul_scalar_out( ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn( - [val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; }, + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::SAME_AS_COMMON>( + [val_b](const auto val_a) { return val_a * val_b; }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); + out); }); return out; diff --git a/kernels/portable/cpu/op_pow.cpp b/kernels/portable/cpu/op_pow.cpp index 81319b03d9f..4d2673cb72d 100644 --- a/kernels/portable/cpu/op_pow.cpp +++ b/kernels/portable/cpu/op_pow.cpp @@ -53,8 +53,12 @@ Tensor& pow_Tensor_Tensor_out( static constexpr const char op_name[] = "pow.Tensor_Tensor_out"; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + // TODO: rewrite this to be vectorization-capable. return std::pow(val_a, val_b); }, ctx, @@ -62,8 +66,7 @@ Tensor& pow_Tensor_Tensor_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); }); return out; @@ -104,13 +107,16 @@ Tensor& pow_Tensor_Scalar_out( ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn( + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( + // TODO: rewrite this to be vectorization-capable. [val_b](const CTYPE_COMPUTE val_a) { return std::pow(val_a, val_b); }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); }); return out; @@ -151,13 +157,16 @@ Tensor& pow_Scalar_out( ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_a = utils::scalar_to(a); - utils::apply_unitensor_elementwise_fn( + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( + // TODO: rewrite this to be vectorization-capable. [val_a](const CTYPE_COMPUTE val_b) { return std::pow(val_a, val_b); }, ctx, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); }); return out; diff --git a/kernels/portable/cpu/op_remainder.cpp b/kernels/portable/cpu/op_remainder.cpp index d34c34a0380..01a5d72de01 100644 --- a/kernels/portable/cpu/op_remainder.cpp +++ b/kernels/portable/cpu/op_remainder.cpp @@ -53,9 +53,13 @@ Tensor& remainder_Tensor_out( bool div_by_zero_error = false; ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( [&div_by_zero_error]( const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + // TODO: rewrite this to be vectorization-capable. CTYPE_COMPUTE value = 0; if (is_integral_type::value) { if (val_b == 0) { @@ -71,8 +75,7 @@ Tensor& remainder_Tensor_out( utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); }); ET_KERNEL_CHECK_MSG( @@ -126,15 +129,18 @@ Tensor& remainder_Scalar_out( ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn( + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( [val_b](const CTYPE_COMPUTE val_a) { + // TODO: rewrite this to be vectorization-capable. return utils::remainder_override(val_a, val_b); }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); }); return out; diff --git a/kernels/portable/cpu/op_rsub.cpp b/kernels/portable/cpu/op_rsub.cpp index 46af021efda..6a0a77b6596 100644 --- a/kernels/portable/cpu/op_rsub.cpp +++ b/kernels/portable/cpu/op_rsub.cpp @@ -52,15 +52,17 @@ Tensor& rsub_scalar_out( ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_unitensor_elementwise_fn( - [val_b, val_alpha](const CTYPE_COMPUTE val_a) { + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::SAME_AS_COMMON>( + [val_b, val_alpha](const auto val_a) { return val_b - val_alpha * val_a; }, ctx, a, utils::SupportedTensorDtypes::REALHBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); + out); }); return out; diff --git a/kernels/portable/cpu/op_sigmoid.cpp b/kernels/portable/cpu/op_sigmoid.cpp index 09cfed524f9..acb743a2db6 100644 --- a/kernels/portable/cpu/op_sigmoid.cpp +++ b/kernels/portable/cpu/op_sigmoid.cpp @@ -45,8 +45,12 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { static constexpr const char op_name[] = "sigmoid.out"; ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_unitensor_elementwise_fn( - [](const CTYPE_COMPUTE val_in) { + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::FLOATHBF16>( + [](const auto val_in) -> CTYPE_COMPUTE { + // TODO: rewrite this to be vectorization-capable CTYPE_COMPUTE out_val = static_cast(1.0) / (static_cast(1.0) + exp(-val_in)); return out_val; @@ -54,8 +58,7 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { ctx, in, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::FLOATHBF16); + out); }); return out; diff --git a/kernels/portable/cpu/op_sub.cpp b/kernels/portable/cpu/op_sub.cpp index 6217f82c3b1..aa90df8dee4 100644 --- a/kernels/portable/cpu/op_sub.cpp +++ b/kernels/portable/cpu/op_sub.cpp @@ -56,8 +56,11 @@ Tensor& sub_out( ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_bitensor_elementwise_fn( - [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBF16>( + [val_alpha](const auto val_a, const auto val_b) { return val_a - val_alpha * val_b; }, ctx, @@ -65,8 +68,7 @@ Tensor& sub_out( utils::SupportedTensorDtypes::REALHBF16, b, utils::SupportedTensorDtypes::REALHBF16, - out, - utils::SupportedTensorDtypes::REALHBF16); + out); }); return out; @@ -110,15 +112,17 @@ Tensor& sub_scalar_out( ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); - utils::apply_unitensor_elementwise_fn( - [val_b, val_alpha](const CTYPE_COMPUTE val_a) { + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::SAME_AS_COMMON>( + [val_b, val_alpha](const auto val_a) { return val_a - val_alpha * val_b; }, ctx, a, utils::SupportedTensorDtypes::REALHBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); + out); }); return out; diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp index b455c45c2d1..692e296ee00 100644 --- a/kernels/portable/cpu/op_where.cpp +++ b/kernels/portable/cpu/op_where.cpp @@ -43,10 +43,13 @@ Tensor& where_out( static constexpr const char op_name[] = "where.self_out"; ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_tritensor_elementwise_fn( - [](const CTYPE_COMPUTE val_a, - const CTYPE_COMPUTE val_b, - const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; }, + utils::apply_tritensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::SAME_AS_COMMON>( + [](const auto val_a, const auto val_b, const auto val_c) { + return val_c ? val_a : val_b; + }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, @@ -54,8 +57,7 @@ Tensor& where_out( utils::SupportedTensorDtypes::REALHBBF16, cond, utils::SupportedTensorDtypes::BOOL_OR_BYTE, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); + out); }); return out; diff --git a/kernels/portable/cpu/pattern/bitwise_op.h b/kernels/portable/cpu/pattern/bitwise_op.h index 6e4c111b8f2..f78ce796e6c 100644 --- a/kernels/portable/cpu/pattern/bitwise_op.h +++ b/kernels/portable/cpu/pattern/bitwise_op.h @@ -80,15 +80,18 @@ Tensor& bitwise_tensor_out( ET_SWITCH_INT_TYPES_AND( Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( + // TODO: rewrite this to be vectorization-capable. BitwiseFnForOp::value, ctx, a, utils::SupportedTensorDtypes::INTB, b, utils::SupportedTensorDtypes::INTB, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); }); return out; @@ -121,16 +124,19 @@ Tensor& bitwise_scalar_out( ET_SWITCH_INT_TYPES_AND( Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn( + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( [val_b](const CTYPE_COMPUTE val_a) { + // TODO: rewrite this to be vectorization-capable. return BitwiseFnForOp::value( val_a, val_b); }, ctx, a, utils::SupportedTensorDtypes::INTB, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); }); return out; diff --git a/kernels/portable/cpu/pattern/comparison_op.h b/kernels/portable/cpu/pattern/comparison_op.h index e0d9bf4dcab..643d7623922 100644 --- a/kernels/portable/cpu/pattern/comparison_op.h +++ b/kernels/portable/cpu/pattern/comparison_op.h @@ -91,15 +91,18 @@ Tensor& comparison_tensor_out( ScalarType compute_type = utils::get_compute_type(common_type); ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( + // TODO: rewrite this to be vectorization-capable. ComparisonFnForOp::value, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); }); return out; @@ -127,15 +130,18 @@ Tensor& comparison_scalar_out( ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { const CTYPE_COMPUTE val_b = utils::scalar_to(b); - utils::apply_unitensor_elementwise_fn( + utils::apply_unitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( [val_b](const CTYPE_COMPUTE val_a) { + // TODO: rewrite this to be vectorization-capable. return ComparisonFnForOp::value(val_a, val_b); }, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); }); return out; diff --git a/kernels/portable/cpu/pattern/logical_op.h b/kernels/portable/cpu/pattern/logical_op.h index 017822a85a6..4547d3df51b 100644 --- a/kernels/portable/cpu/pattern/logical_op.h +++ b/kernels/portable/cpu/pattern/logical_op.h @@ -34,15 +34,18 @@ Tensor& logical_tensor_out( InvalidArgument, out); - utils::apply_bitensor_elementwise_fn( + utils::apply_bitensor_elementwise_fn< + bool, + op_name, + utils::SupportedTensorDtypes::REALHBBF16>( + // TODO: rewrite this to be vectorization-capable. fn, ctx, a, utils::SupportedTensorDtypes::REALHBBF16, b, utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); + out); return out; } diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h index e3cac54908e..2286ca50bee 100644 --- a/kernels/portable/cpu/util/dtype_util.h +++ b/kernels/portable/cpu/util/dtype_util.h @@ -290,6 +290,25 @@ bool check_tensor_dtype( SupportedTensorDtypes dtypes, const ScalarType compute_type); +/// Return the one output type we are willing to emit specialized code +/// to handle, given a compute type of CTYPE_COMMON and supported +/// output types of out_dtypes. +template +inline constexpr ScalarType specialized_output_scalar_type( + SupportedTensorDtypes out_dtypes) { + switch (out_dtypes) { + case SupportedTensorDtypes::BOOL_OR_BYTE: + return ScalarType::Bool; + case SupportedTensorDtypes::REALHBBF16: + case SupportedTensorDtypes::REALHBF16: + case SupportedTensorDtypes::FLOATHBF16: + case SupportedTensorDtypes::INTB: + case SupportedTensorDtypes::SAME_AS_COMPUTE: + case SupportedTensorDtypes::SAME_AS_COMMON: + return CppTypeToScalarType::value; + } +} + } // namespace internal } // namespace utils } // namespace native diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 4e0718bc522..e30b8af7d89 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -51,6 +51,44 @@ inline int64_t scalar_to(const Scalar& s) { } namespace internal { +template < + typename CTYPE_COMPUTE, + typename CTYPE_OUT, + typename Op, + typename... Args> +inline void dtype_specialized_elementwise_fn_impl( + const Op& compute_fun, + KernelRuntimeContext& ctx, + const Tensor& out, + Args... inputs) { + constexpr auto kNumInputs = sizeof...(inputs); + ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...)); + + ::executorch::extension::parallel_for( + 0, + out.numel(), + ::executorch::extension::internal::GRAIN_SIZE, + [&](const auto begin, const auto end) { + std::array inputs_data_ptrs = { + inputs.first->template const_data_ptr()...}; + + CTYPE_OUT* const data_out = out.mutable_data_ptr(); + + const auto range = + BroadcastIndexesRange(out, (*inputs.first)...); + auto begin_it = range.begin(); + begin_it += begin; + for (; (*begin_it)[0] < end; ++begin_it) { + const auto& indexes = *begin_it; + std::array loaded_inputs; + for (const auto idx : c10::irange(kNumInputs)) { + loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]]; + } + data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs); + } + }); +} + template inline bool validate_elementwise_fn_inputs( const Op& compute_fun, @@ -81,18 +119,12 @@ template < const char* op_name, typename Op, typename... Args> -inline void apply_elementwise_fn( +inline void apply_elementwise_fn_generic_impl( const Op& compute_fun, KernelRuntimeContext& ctx, const Tensor& out, SupportedTensorDtypes out_dtypes, Args... inputs) { - const bool inputs_valid = validate_elementwise_fn_inputs( - compute_fun, ctx, out, out_dtypes, inputs...); - if (!inputs_valid) { - return; - } - constexpr auto kNumInputs = sizeof...(inputs); struct InputInfo { @@ -138,6 +170,63 @@ inline void apply_elementwise_fn( }); } +template < + typename CTYPE_COMPUTE, + const char* op_name, + typename Op, + typename... Args> +inline void apply_elementwise_fn_runtime_out_dtypes( + const Op& compute_fun, + KernelRuntimeContext& ctx, + const Tensor& out, + SupportedTensorDtypes out_dtypes, + Args... inputs) { + const bool inputs_valid = validate_elementwise_fn_inputs( + compute_fun, ctx, out, out_dtypes, inputs...); + if (!inputs_valid) { + return; + } + + apply_elementwise_fn_generic_impl( + compute_fun, ctx, out, out_dtypes, inputs...); +} + +template < + typename CTYPE_COMPUTE, + const char* op_name, + SupportedTensorDtypes out_dtypes, + typename Op, + typename... Args> +inline void apply_elementwise_fn( + const Op& compute_fun, + KernelRuntimeContext& ctx, + const Tensor& out, + Args... inputs) { + const bool inputs_valid = validate_elementwise_fn_inputs( + compute_fun, ctx, out, out_dtypes, inputs...); + if (!inputs_valid) { + return; + } + + constexpr auto compute_type = CppTypeToScalarType::value; + const bool all_inputs_compute_dtype = + ((inputs.first->scalar_type() == compute_type) && ...); + + constexpr ScalarType out_specialized_scalar_type = + specialized_output_scalar_type(out_dtypes); + if (all_inputs_compute_dtype && + out.scalar_type() == out_specialized_scalar_type) { + using CTYPE_OUT = + typename ScalarTypeToCppType::type; + dtype_specialized_elementwise_fn_impl( + compute_fun, ctx, out, inputs...); + return; + } + + apply_elementwise_fn_generic_impl( + compute_fun, ctx, out, out_dtypes, inputs...); +} + /// DEPRECATED: prefer the variant with out_dtypes in the template argument. template inline void apply_unitensor_elementwise_fn( @@ -147,7 +236,7 @@ inline void apply_unitensor_elementwise_fn( SupportedTensorDtypes a_dtypes, const Tensor& out, SupportedTensorDtypes out_dtypes) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn_runtime_out_dtypes( compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes)); } @@ -162,8 +251,8 @@ inline void apply_unitensor_elementwise_fn( const Tensor& a, SupportedTensorDtypes a_dtypes, const Tensor& out) { - internal::apply_elementwise_fn( - compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes)); + internal::apply_elementwise_fn( + compute_fun, ctx, out, std::make_pair(&a, a_dtypes)); } /** @@ -179,7 +268,7 @@ inline void apply_bitensor_elementwise_fn( SupportedTensorDtypes b_dtypes, const Tensor& out, SupportedTensorDtypes out_dtypes) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn_runtime_out_dtypes( compute_fun, ctx, out, @@ -206,11 +295,10 @@ inline void apply_bitensor_elementwise_fn( const Tensor& b, SupportedTensorDtypes b_dtypes, const Tensor& out) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn( compute_fun, ctx, out, - out_dtypes, std::make_pair(&a, a_dtypes), std::make_pair(&b, b_dtypes)); } @@ -230,7 +318,7 @@ inline void apply_tritensor_elementwise_fn( SupportedTensorDtypes c_dtypes, const Tensor& out, SupportedTensorDtypes out_dtypes) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn_runtime_out_dtypes( compute_fun, ctx, out, @@ -275,11 +363,10 @@ inline void apply_tritensor_elementwise_fn( const Tensor& c, SupportedTensorDtypes c_dtypes, const Tensor& out) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn( compute_fun, ctx, out, - out_dtypes, std::make_pair(&a, a_dtypes), std::make_pair(&b, b_dtypes), std::make_pair(&c, c_dtypes));