Commit e2cd7f3

elementwise_util: don't cast the result of compute_fun back to the common type
The compute function might return an entirely different type. For example, if we were applying a trigonometric function like acos to an input of type bool expecting an output of type float, we would get bad results: acos(0) ≈ 1.57, but casting through bool would truncate that to 1.

Note that we don't need the pair of ET_CHECK_MSG calls removed here, because we already check tensor dtypes on entry to the elementwise util functions; the checks were also inconvenient because we now call get_store_common_to_tensor_fn without the actual common type.

ghstack-source-id: cfcbe8b142102d2b028c44229c109fda63491b0d
ghstack-comment-id: 2735017325
Pull Request resolved: #9385
1 parent: 5b57eb1 · commit: e2cd7f3

File tree

2 files changed: +19 −16 lines

- kernels/portable/cpu/util/dtype_util.h
- kernels/portable/cpu/util/elementwise_util.h
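Before the diffs, a standalone sketch of the failure mode the commit message describes. This is not code from the commit; it assumes nothing beyond standard C++ and just shows how routing a float-valued result back through the common type (bool here) truncates it:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const bool input = false;  // common input type is bool
  // Correct behavior: keep the op's actual (float) result.
  const float correct = std::acos(static_cast<float>(input));  // ~1.5708
  // Old behavior: cast the result back through CTYPE_COMMON (bool) first,
  // which collapses ~1.5708 to true, then to 1.0f.
  const float truncated =
      static_cast<bool>(std::acos(static_cast<float>(input)));
  std::printf("correct=%f truncated=%f\n", correct, truncated);
  return 0;
}
```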

Diff for: kernels/portable/cpu/util/dtype_util.h

+0 −12

```diff
@@ -86,12 +86,6 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
 template <typename CTYPE_COMMON, const char* op_name>
 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute(
     const Tensor& t) {
-  constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-  ET_CHECK_MSG(
-      t.scalar_type() == common_scalar_type,
-      "Unhandled dtype %s for %s",
-      ::executorch::runtime::toString(common_scalar_type),
-      op_name);
   return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
 }
 
@@ -179,12 +173,6 @@ get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
 template <typename CTYPE_COMMON, const char* op_name>
 store_common_to_tensor_fn<CTYPE_COMMON>
 get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
-  constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-  ET_CHECK_MSG(
-      t.scalar_type() == common_scalar_type,
-      "Unhandled dtype %s for %s",
-      ::executorch::runtime::toString(common_scalar_type),
-      op_name);
   return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
 }
```
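For context, after this change the `_same_as_compute` getters do nothing but return the identity load/store helper. A plausible sketch of those helpers follows; this is an assumption for illustration (the real `load_and_convert`/`convert_and_store` in dtype_util.h may differ in detail, and both template arguments are `CTYPE_COMMON` at these call sites anyway):

```cpp
// Assumed shape: read a From from a type-erased tensor buffer,
// convert it to To.
template <typename To, typename From>
To load_and_convert(const void* ptr) {
  return static_cast<To>(*reinterpret_cast<const From*>(ptr));
}

// Assumed shape: convert the computed value to the tensor's element
// type and store it into the type-erased output buffer.
template <typename To, typename From>
void convert_and_store(From value, void* ptr) {
  *reinterpret_cast<To*>(ptr) = static_cast<To>(value);
}
```

With both template arguments equal, the only work the removed ET_CHECK_MSG did was re-validate a dtype already checked at the entry points, which is why dropping it is safe.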

Diff for: kernels/portable/cpu/util/elementwise_util.h

+19 −4

```diff
@@ -51,6 +51,13 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
 }
 
 namespace internal {
+template <typename Ignore, typename T>
+using ignore_first_yield_second = T;
+
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+using op_call_result =
+    std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;
+
 template <
     typename CTYPE_COMMON,
     const char* op_name,
@@ -89,9 +96,16 @@ inline void apply_elementwise_fn(
         inputs.first->element_size(),
     })...};
 
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
+  // NOTE: the result of compute_fun is not necessarily CTYPE_COMMON!
+  // For example, consider the possibility that compute_fun is a
+  // trigonometric function like acos, the common input type is bool,
+  // and the output type is float -- we would truncate acos(0) ~= 1.57
+  // to just 1. Conveniently, it costs us nothing at runtime to handle
+  // this correctly.
+  const auto store_compute_result_to_out =
+      internal::get_store_common_to_tensor_fn<
+          op_call_result<CTYPE_COMMON, Op, Args...>,
+          op_name>(out, out_dtypes);
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
   const auto out_element_size = out.element_size();
 
@@ -114,7 +128,8 @@ inline void apply_elementwise_fn(
           .data_ptr[indexes[idx + 1] * input_info.element_size]);
     }
     auto result = std::apply(compute_fun, loaded_inputs);
-    store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
+    store_compute_result_to_out(
+        result, &data_out[indexes[0] * out_element_size]);
   }
 });
}
```
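To make the new `op_call_result` alias concrete: it instantiates `Op` as if called with one `CTYPE_COMMON` argument per input tensor and yields the op's actual return type. A minimal standalone sketch follows; `AcosOp` and `Arg1` are hypothetical stand-ins for illustration, not names from the codebase:

```cpp
#include <type_traits>

// Aliases copied from the diff above.
template <typename Ignore, typename T>
using ignore_first_yield_second = T;

template <typename CTYPE_COMMON, typename Op, typename... Args>
using op_call_result =
    std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;

// Hypothetical unary acos-like op: takes the common type, returns float.
struct AcosOp {
  float operator()(bool x) const { return x ? 0.0f : 1.5707964f; }
};
struct Arg1 {};  // one Args placeholder per input tensor

// With CTYPE_COMMON = bool, the alias picks up the op's real return type:
static_assert(
    std::is_same_v<op_call_result<bool, AcosOp, Arg1>, float>,
    "the store function is selected for float, not bool");
```

This is why the new code can hand the op's actual result type, rather than `CTYPE_COMMON`, to `get_store_common_to_tensor_fn`.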
