Skip to content

Commit 979e8e9

Browse files
committed
Refactor elementwise_util: create variants with out_dtypes in template argument list
ghstack-source-id: 03f1860 ghstack-comment-id: 2735017483 Pull Request resolved: #9387
1 parent 3b7c86c commit 979e8e9

File tree

1 file changed

+91
-10
lines changed

1 file changed

+91
-10
lines changed

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,9 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
5353
namespace internal {
5454
template <
5555
typename CTYPE_COMMON,
56-
const char* op_name,
5756
typename Op,
58-
typename... Args>
59-
inline void apply_elementwise_fn(
57+
typename... Args>
58+
inline bool validate_elementwise_fn_inputs(
6059
const Op& compute_fun,
6160
KernelRuntimeContext& ctx,
6261
const Tensor& out,
@@ -65,7 +64,6 @@ inline void apply_elementwise_fn(
6564
static_assert(
6665
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
6766
...));
68-
constexpr auto kNumInputs = sizeof...(inputs);
6967
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
7068
const auto check_input_dtype = [](auto input, auto compute_type) {
7169
return internal::check_tensor_dtype(
@@ -75,7 +73,33 @@ inline void apply_elementwise_fn(
7573
ctx,
7674
(check_input_dtype(inputs, compute_type) && ...) &&
7775
internal::check_tensor_dtype(out, out_dtypes, compute_type),
78-
InvalidArgument, );
76+
InvalidArgument, false);
77+
78+
return true;
79+
}
80+
81+
template <
82+
typename CTYPE_COMMON,
83+
const char* op_name,
84+
typename Op,
85+
typename... Args>
86+
inline void apply_elementwise_fn(
87+
const Op& compute_fun,
88+
KernelRuntimeContext& ctx,
89+
const Tensor& out,
90+
SupportedTensorDtypes out_dtypes,
91+
Args... inputs) {
92+
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
93+
compute_fun,
94+
ctx,
95+
out,
96+
out_dtypes,
97+
inputs...);
98+
if (!inputs_valid) {
99+
return;
100+
}
101+
102+
constexpr auto kNumInputs = sizeof...(inputs);
79103

80104
struct InputInfo {
81105
load_to_common_fn<CTYPE_COMMON> load_to_common;
@@ -120,6 +144,7 @@ inline void apply_elementwise_fn(
120144
}
121145
} // namespace internal
122146

147+
/// DEPRECATED: prefer the variant with out_dtypes in the template argument.
123148
template <typename CTYPE_COMMON, const char* op_name, typename Op>
124149
inline void apply_unitensor_elementwise_fn(
125150
const Op& compute_fun,
@@ -132,19 +157,75 @@ inline void apply_unitensor_elementwise_fn(
132157
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
133158
}
134159

160+
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
161+
inline void apply_unitensor_elementwise_fn(
162+
const Op& compute_fun,
163+
KernelRuntimeContext& ctx,
164+
const Tensor& a,
165+
SupportedTensorDtypes a_dtypes,
166+
const Tensor& out) {
167+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
168+
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
169+
}
170+
171+
/**
172+
* DEPRECATED: prefer the variant with out_dtypes in the template argument list.
173+
*/
174+
template <typename CTYPE_COMMON, const char* op_name, typename Op>
175+
inline void apply_bitensor_elementwise_fn(
176+
const Op& compute_fun,
177+
KernelRuntimeContext& ctx,
178+
const Tensor& a,
179+
SupportedTensorDtypes a_dtypes,
180+
const Tensor& b,
181+
SupportedTensorDtypes b_dtypes,
182+
const Tensor& out,
183+
SupportedTensorDtypes out_dtypes) {
184+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
185+
compute_fun,
186+
ctx,
187+
out,
188+
out_dtypes,
189+
std::make_pair(&a, a_dtypes),
190+
std::make_pair(&b, b_dtypes));
191+
}
192+
135193
/**
136194
* Useful for bi-tensor elementwise operators. For each element of the inputs,
137195
* perform a computation and write to the corresponding element of the output.
138196
* Tensor broadcasting is applied wherever it is required.
139197
*/
140-
template <typename CTYPE_COMMON, const char* op_name, typename Op>
198+
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
141199
inline void apply_bitensor_elementwise_fn(
142200
const Op& compute_fun,
143201
KernelRuntimeContext& ctx,
144202
const Tensor& a,
145203
SupportedTensorDtypes a_dtypes,
146204
const Tensor& b,
147205
SupportedTensorDtypes b_dtypes,
206+
const Tensor& out) {
207+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
208+
compute_fun,
209+
ctx,
210+
out,
211+
out_dtypes,
212+
std::make_pair(&a, a_dtypes),
213+
std::make_pair(&b, b_dtypes));
214+
}
215+
216+
/**
217+
* DEPRECATED: prefer the variant with out_dtypes in the template argument list.
218+
*/
219+
template <typename CTYPE_COMMON, const char* op_name, typename Op>
220+
inline void apply_tritensor_elementwise_fn(
221+
const Op& compute_fun,
222+
KernelRuntimeContext& ctx,
223+
const Tensor& a,
224+
SupportedTensorDtypes a_dtypes,
225+
const Tensor& b,
226+
SupportedTensorDtypes b_dtypes,
227+
const Tensor& c,
228+
SupportedTensorDtypes c_dtypes,
148229
const Tensor& out,
149230
SupportedTensorDtypes out_dtypes) {
150231
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
@@ -153,7 +234,8 @@ inline void apply_bitensor_elementwise_fn(
153234
out,
154235
out_dtypes,
155236
std::make_pair(&a, a_dtypes),
156-
std::make_pair(&b, b_dtypes));
237+
std::make_pair(&b, b_dtypes),
238+
std::make_pair(&c, c_dtypes));
157239
}
158240

159241
/**
@@ -176,7 +258,7 @@ inline void apply_bitensor_elementwise_fn(
176258
* static constexpr const char op_name[] = "my_op";
177259
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
178260
*/
179-
template <typename CTYPE_COMMON, const char* op_name, typename Op>
261+
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
180262
inline void apply_tritensor_elementwise_fn(
181263
const Op& compute_fun,
182264
KernelRuntimeContext& ctx,
@@ -186,8 +268,7 @@ inline void apply_tritensor_elementwise_fn(
186268
SupportedTensorDtypes b_dtypes,
187269
const Tensor& c,
188270
SupportedTensorDtypes c_dtypes,
189-
const Tensor& out,
190-
SupportedTensorDtypes out_dtypes) {
271+
const Tensor& out) {
191272
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
192273
compute_fun,
193274
ctx,

0 commit comments

Comments
 (0)