From 3f1b775fe481d9d9d88896c913f7033dc3cfd21d Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 28 Mar 2025 09:51:43 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/util/dtype_util.h | 35 ++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h index b5cd980b085..eb1ee83111e 100644 --- a/kernels/portable/cpu/util/dtype_util.h +++ b/kernels/portable/cpu/util/dtype_util.h @@ -228,7 +228,7 @@ enum class SupportedTensorDtypes { namespace internal { template -load_to_compute_fn get_load_to_compute_fn( +load_to_compute_fn get_load_to_compute_fn_impl( const Tensor& t, SupportedTensorDtypes dtypes) { switch (dtypes) { @@ -252,7 +252,7 @@ load_to_compute_fn get_load_to_compute_fn( } template -store_compute_to_tensor_fn get_store_compute_to_tensor_fn( +store_compute_to_tensor_fn get_store_compute_to_tensor_fn_impl( const Tensor& t, SupportedTensorDtypes dtypes) { switch (dtypes) { @@ -285,6 +285,37 @@ store_compute_to_tensor_fn get_store_compute_to_tensor_fn( return nullptr; } +#ifndef EXECUTORCH_SELECTIVE_BUILD_DTYPE +constexpr const char kGenericElementwiseOpName[] = "generic_elementwise_op"; +#endif // EXECUTORCH_SELECTIVE_BUILD_DTYPE + +template +load_to_compute_fn get_load_to_compute_fn( + const Tensor& t, + SupportedTensorDtypes dtypes) { + return get_load_to_compute_fn_impl< + CTYPE_COMPUTE, +#ifdef EXECUTORCH_SELECTIVE_BUILD_DTYPE + op_name +#else // EXECUTORCH_SELECTIVE_BUILD_DTYPE + kGenericElementwiseOpName +#endif // EXECUTORCH_SELECTIVE_BUILD_DTYPE + >(t, dtypes); +} + +template +store_compute_to_tensor_fn get_store_compute_to_tensor_fn( + const Tensor& t, + SupportedTensorDtypes dtypes) { + return get_store_compute_to_tensor_fn_impl< + CTYPE_COMPUTE, +#ifdef EXECUTORCH_SELECTIVE_BUILD_DTYPE + op_name +#else // EXECUTORCH_SELECTIVE_BUILD_DTYPE + kGenericElementwiseOpName +#endif // EXECUTORCH_SELECTIVE_BUILD_DTYPE + >(t, dtypes); +} bool check_tensor_dtype( const Tensor t, SupportedTensorDtypes dtypes,