diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp
index 1ee73d342c..114e60ff17 100644
--- a/kernels/portable/cpu/op_mul.cpp
+++ b/kernels/portable/cpu/op_mul.cpp
@@ -52,7 +52,10 @@ Tensor& mul_out(
       out);
 
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
         [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
           return val_a * val_b;
         },
@@ -61,8 +64,7 @@ Tensor& mul_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        out);
   });
 
   return out;
diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h
index e3cac54908..2286ca50be 100644
--- a/kernels/portable/cpu/util/dtype_util.h
+++ b/kernels/portable/cpu/util/dtype_util.h
@@ -290,6 +290,25 @@ bool check_tensor_dtype(
     SupportedTensorDtypes dtypes,
     const ScalarType compute_type);
 
+/// Return the one output type we are willing to emit specialized code
+/// to handle, given a compute type of CTYPE_COMMON and supported
+/// output types of out_dtypes.
+template <typename CTYPE_COMPUTE>
+inline constexpr ScalarType specialized_output_scalar_type(
+    SupportedTensorDtypes out_dtypes) {
+  switch (out_dtypes) {
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return ScalarType::Bool;
+    case SupportedTensorDtypes::REALHBBF16:
+    case SupportedTensorDtypes::REALHBF16:
+    case SupportedTensorDtypes::FLOATHBF16:
+    case SupportedTensorDtypes::INTB:
+    case SupportedTensorDtypes::SAME_AS_COMPUTE:
+    case SupportedTensorDtypes::SAME_AS_COMMON:
+      return CppTypeToScalarType<CTYPE_COMPUTE>::value;
+  }
+}
+
 } // namespace internal
 } // namespace utils
 } // namespace native
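
Note: specialized_output_scalar_type picks, at compile time, the single output dtype that gets a specialized code path: Bool for BOOL_OR_BYTE, and the compute dtype itself for every other bucket. The sketch below is a self-contained illustration using stand-in *_demo types (the real ScalarType, SupportedTensorDtypes, and CppTypeToScalarType come from the ExecuTorch headers touched by this diff); it spells out what the helper yields for mul_out's float/REALHBBF16 case.

// Illustration only: stand-in types, not the ExecuTorch definitions.
enum class ScalarTypeDemo { Bool, Float };
enum class SupportedTensorDtypesDemo { BOOL_OR_BYTE, REALHBBF16 };

template <typename T>
struct CppTypeToScalarTypeDemo;
template <>
struct CppTypeToScalarTypeDemo<float> {
  static constexpr ScalarTypeDemo value = ScalarTypeDemo::Float;
};

template <typename CTYPE_COMPUTE>
constexpr ScalarTypeDemo specialized_output_scalar_type_demo(
    SupportedTensorDtypesDemo out_dtypes) {
  switch (out_dtypes) {
    case SupportedTensorDtypesDemo::BOOL_OR_BYTE:
      return ScalarTypeDemo::Bool;  // only Bool output is specialized
    default:
      return CppTypeToScalarTypeDemo<CTYPE_COMPUTE>::value;  // compute dtype
  }
}

// With float compute (as in mul_out above), REALHBBF16 specializes for Float,
// while BOOL_OR_BYTE specializes for Bool.
static_assert(
    specialized_output_scalar_type_demo<float>(
        SupportedTensorDtypesDemo::REALHBBF16) == ScalarTypeDemo::Float,
    "");
static_assert(
    specialized_output_scalar_type_demo<float>(
        SupportedTensorDtypesDemo::BOOL_OR_BYTE) == ScalarTypeDemo::Bool,
    "");
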
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 4e0718bc52..e30b8af7d8 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -51,6 +51,44 @@ inline int64_t scalar_to(const Scalar& s) {
 }
 
 namespace internal {
+template <
+    typename CTYPE_COMPUTE,
+    typename CTYPE_OUT,
+    typename Op,
+    typename... Args>
+inline void dtype_specialized_elementwise_fn_impl(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    Args... inputs) {
+  constexpr auto kNumInputs = sizeof...(inputs);
+  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));
+
+  ::executorch::extension::parallel_for(
+      0,
+      out.numel(),
+      ::executorch::extension::internal::GRAIN_SIZE,
+      [&](const auto begin, const auto end) {
+        std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
+            inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
+
+        CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
+
+        const auto range =
+            BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
+        auto begin_it = range.begin();
+        begin_it += begin;
+        for (; (*begin_it)[0] < end; ++begin_it) {
+          const auto& indexes = *begin_it;
+          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+          for (const auto idx : c10::irange(kNumInputs)) {
+            loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
+          }
+          data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);
+        }
+      });
+}
+
 template <typename CTYPE_COMPUTE, typename Op, typename... Args>
 inline bool validate_elementwise_fn_inputs(
     const Op& compute_fun,
@@ -81,18 +119,12 @@ template <
     const char* op_name,
     typename Op,
     typename... Args>
-inline void apply_elementwise_fn(
+inline void apply_elementwise_fn_generic_impl(
     const Op& compute_fun,
     KernelRuntimeContext& ctx,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes,
     Args... inputs) {
-  const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
-      compute_fun, ctx, out, out_dtypes, inputs...);
-  if (!inputs_valid) {
-    return;
-  }
-
   constexpr auto kNumInputs = sizeof...(inputs);
 
   struct InputInfo {
@@ -138,6 +170,63 @@ inline void apply_elementwise_fn(
       });
 }
 
+template <
+    typename CTYPE_COMPUTE,
+    const char* op_name,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn_runtime_out_dtypes(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes,
+    Args... inputs) {
+  const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+  if (!inputs_valid) {
+    return;
+  }
+
+  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+}
+
+template <
+    typename CTYPE_COMPUTE,
+    const char* op_name,
+    SupportedTensorDtypes out_dtypes,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    Args... inputs) {
+  const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+  if (!inputs_valid) {
+    return;
+  }
+
+  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
+  const bool all_inputs_compute_dtype =
+      ((inputs.first->scalar_type() == compute_type) && ...);
+
+  constexpr ScalarType out_specialized_scalar_type =
+      specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
+  if (all_inputs_compute_dtype &&
+      out.scalar_type() == out_specialized_scalar_type) {
+    using CTYPE_OUT =
+        typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
+    dtype_specialized_elementwise_fn_impl<CTYPE_COMPUTE, CTYPE_OUT>(
+        compute_fun, ctx, out, inputs...);
+    return;
+  }
+
+  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+}
+
 /// DEPRECATED: prefer the variant with out_dtypes in the template argument.
 template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
 inline void apply_unitensor_elementwise_fn(
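
Note: the hunk above splits the old single entry point into a runtime-out_dtypes wrapper and a template-out_dtypes apply_elementwise_fn that can branch at compile time. When every input already carries the compute dtype and the output has the one specialized dtype, it takes the typed fast path (dtype_specialized_elementwise_fn_impl); otherwise it falls back to the generic, converting implementation. Below is a minimal, ExecuTorch-free sketch of that dispatch shape; every name in it is invented for illustration.

#include <cstdio>

enum class Dtype { Float, Half };

struct Buf {
  Dtype dtype;  // stands in for Tensor::scalar_type()
};

void generic_path(const Buf&, const Buf&, Buf&) {
  std::puts("generic path: per-element load/convert/store hooks");
}

void specialized_path(const Buf&, const Buf&, Buf&) {
  std::puts("specialized path: raw typed pointers, no conversion");
}

// Mirrors the new apply_elementwise_fn: the output dtype is a compile-time
// template parameter, so the specialized branch is instantiated for exactly
// one output type; the runtime check only decides which branch to take.
template <Dtype kSpecializedOut>
void apply(const Buf& a, const Buf& b, Buf& out) {
  const bool all_inputs_compute_dtype =
      a.dtype == Dtype::Float && b.dtype == Dtype::Float;
  if (all_inputs_compute_dtype && out.dtype == kSpecializedOut) {
    specialized_path(a, b, out);
    return;
  }
  generic_path(a, b, out);
}

int main() {
  Buf a{Dtype::Float}, b{Dtype::Float};
  Buf out{Dtype::Float}, half_out{Dtype::Half};
  apply<Dtype::Float>(a, b, out);       // takes the specialized path
  apply<Dtype::Float>(a, b, half_out);  // falls back to the generic path
  return 0;
}
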
@@ -147,7 +236,7 @@ inline void apply_unitensor_elementwise_fn(
     SupportedTensorDtypes a_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
 }
 
@@ -162,8 +251,8 @@ inline void apply_unitensor_elementwise_fn(
     const Tensor& a,
     SupportedTensorDtypes a_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
-      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
+      compute_fun, ctx, out, std::make_pair(&a, a_dtypes));
 }
 
 /**
@@ -179,7 +268,7 @@ inline void apply_bitensor_elementwise_fn(
     SupportedTensorDtypes b_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
       compute_fun,
       ctx,
       out,
@@ -206,11 +295,10 @@ inline void apply_bitensor_elementwise_fn(
     const Tensor& b,
     SupportedTensorDtypes b_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
       compute_fun,
       ctx,
       out,
-      out_dtypes,
       std::make_pair(&a, a_dtypes),
       std::make_pair(&b, b_dtypes));
 }
@@ -230,7 +318,7 @@ inline void apply_tritensor_elementwise_fn(
     SupportedTensorDtypes c_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
       compute_fun,
       ctx,
       out,
@@ -275,11 +363,10 @@ inline void apply_tritensor_elementwise_fn(
     const Tensor& c,
     SupportedTensorDtypes c_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
       compute_fun,
       ctx,
       out,
-      out_dtypes,
       std::make_pair(&a, a_dtypes),
       std::make_pair(&b, b_dtypes),
       std::make_pair(&c, c_dtypes));
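
Note: call sites in other portable ops migrate the same way as op_mul.cpp above: the supported output dtypes move into the third template argument, and the trailing runtime out_dtypes argument is dropped. A hypothetical sketch of such a migrated call site follows (an imaginary add-like kernel; ET_SWITCH_REALB_TYPES, the utils:: helpers, and the dtype lists are the same ones mul_out uses, but the op itself is made up).

// Fragment only: this would sit inside a kernel body like mul_out's, with
// a, b, out, ctx, compute_type, and op_name already defined as in op_mul.cpp.
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
  utils::apply_bitensor_elementwise_fn<
      CTYPE_COMPUTE,
      op_name,
      utils::SupportedTensorDtypes::REALHBBF16>(
      [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
        return val_a + val_b;  // element-wise body of the imaginary op
      },
      ctx,
      a,
      utils::SupportedTensorDtypes::REALHBBF16,
      b,
      utils::SupportedTensorDtypes::REALHBBF16,
      out);  // no trailing SupportedTensorDtypes argument anymore
});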