Skip to content

Commit a8a1c5c

Browse files
committed
Refactor elementwise_util: create variants with out_dtypes in template argument list
ghstack-source-id: 33626ffef1bfd507c29f06dd3f7f7572dfda502b ghstack-comment-id: 2735017483 Pull Request resolved: #9387
1 parent 9fe1283 commit a8a1c5c

File tree

1 file changed

+91
-10
lines changed

1 file changed

+91
-10
lines changed

kernels/portable/cpu/util/elementwise_util.h

+91-10
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,9 @@ using op_call_result =
6060

6161
template <
6262
typename CTYPE_COMMON,
63-
const char* op_name,
6463
typename Op,
65-
typename... Args>
66-
inline void apply_elementwise_fn(
64+
typename... Args>
65+
inline bool validate_elementwise_fn_inputs(
6766
const Op& compute_fun,
6867
KernelRuntimeContext& ctx,
6968
const Tensor& out,
@@ -72,7 +71,6 @@ inline void apply_elementwise_fn(
7271
static_assert(
7372
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
7473
...));
75-
constexpr auto kNumInputs = sizeof...(inputs);
7674
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
7775
const auto check_input_dtype = [](auto input, auto compute_type) {
7876
return internal::check_tensor_dtype(
@@ -82,7 +80,33 @@ inline void apply_elementwise_fn(
8280
ctx,
8381
(check_input_dtype(inputs, compute_type) && ...) &&
8482
internal::check_tensor_dtype(out, out_dtypes, compute_type),
85-
InvalidArgument, );
83+
InvalidArgument, false);
84+
85+
return true;
86+
}
87+
88+
template <
89+
typename CTYPE_COMMON,
90+
const char* op_name,
91+
typename Op,
92+
typename... Args>
93+
inline void apply_elementwise_fn(
94+
const Op& compute_fun,
95+
KernelRuntimeContext& ctx,
96+
const Tensor& out,
97+
SupportedTensorDtypes out_dtypes,
98+
Args... inputs) {
99+
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
100+
compute_fun,
101+
ctx,
102+
out,
103+
out_dtypes,
104+
inputs...);
105+
if (!inputs_valid) {
106+
return;
107+
}
108+
109+
constexpr auto kNumInputs = sizeof...(inputs);
86110

87111
struct InputInfo {
88112
load_to_common_fn<CTYPE_COMMON> load_to_common;
@@ -135,6 +159,7 @@ inline void apply_elementwise_fn(
135159
}
136160
} // namespace internal
137161

162+
/// DEPRECATED: prefer the variant with out_dtypes in the template argument list.
138163
template <typename CTYPE_COMMON, const char* op_name, typename Op>
139164
inline void apply_unitensor_elementwise_fn(
140165
const Op& compute_fun,
@@ -147,19 +172,75 @@ inline void apply_unitensor_elementwise_fn(
147172
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
148173
}
149174

175+
/**
 * Single-input elementwise helper whose supported output dtypes are supplied
 * as the `out_dtypes` template argument rather than as a runtime parameter.
 *
 * Forwards `compute_fun`, `ctx`, `out`, `out_dtypes`, and the pair
 * `(&a, a_dtypes)` to `internal::apply_elementwise_fn<CTYPE_COMMON, op_name>`.
 * Returns nothing.
 */
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
176+
inline void apply_unitensor_elementwise_fn(
177+
const Op& compute_fun,
178+
KernelRuntimeContext& ctx,
179+
const Tensor& a,
180+
SupportedTensorDtypes a_dtypes,
181+
const Tensor& out) {
// The template argument out_dtypes is passed on as an ordinary function
// argument; the input tensor travels as a (pointer, dtype-set) pair.
182+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
183+
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
184+
}
185+
186+
/**
187+
* DEPRECATED: prefer the variant with out_dtypes in the template argument list.
188+
*/
189+
// Two-input variant that receives out_dtypes as a runtime function parameter
// (last argument) instead of a template argument.
template <typename CTYPE_COMMON, const char* op_name, typename Op>
190+
inline void apply_bitensor_elementwise_fn(
191+
const Op& compute_fun,
192+
KernelRuntimeContext& ctx,
193+
const Tensor& a,
194+
SupportedTensorDtypes a_dtypes,
195+
const Tensor& b,
196+
SupportedTensorDtypes b_dtypes,
197+
const Tensor& out,
198+
SupportedTensorDtypes out_dtypes) {
// Forward everything to the shared internal implementation; each input
// tensor is passed as a (pointer, dtype-set) pair, in the order a then b.
199+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
200+
compute_fun,
201+
ctx,
202+
out,
203+
out_dtypes,
204+
std::make_pair(&a, a_dtypes),
205+
std::make_pair(&b, b_dtypes));
206+
}
207+
150208
/**
151209
* Useful for bi-tensor elementwise operators. For each element of the inputs,
152210
* perform a computation and write to the corresponding element of the output.
153211
* Tensor broadcasting is applied wherever it is required.
154212
*/
155-
template <typename CTYPE_COMMON, const char* op_name, typename Op>
213+
// Variant taking out_dtypes as a template argument; callers pass only the
// output tensor, not a runtime set of supported output dtypes.
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
156214
inline void apply_bitensor_elementwise_fn(
157215
const Op& compute_fun,
158216
KernelRuntimeContext& ctx,
159217
const Tensor& a,
160218
SupportedTensorDtypes a_dtypes,
161219
const Tensor& b,
162220
SupportedTensorDtypes b_dtypes,
221+
const Tensor& out) {
// The template argument out_dtypes is handed to the internal helper as a
// regular argument; inputs are passed as (pointer, dtype-set) pairs, a then b.
222+
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
223+
compute_fun,
224+
ctx,
225+
out,
226+
out_dtypes,
227+
std::make_pair(&a, a_dtypes),
228+
std::make_pair(&b, b_dtypes));
229+
}
230+
231+
/**
232+
* DEPRECATED: prefer the variant with out_dtypes in the template argument list.
233+
*/
234+
template <typename CTYPE_COMMON, const char* op_name, typename Op>
235+
inline void apply_tritensor_elementwise_fn(
236+
const Op& compute_fun,
237+
KernelRuntimeContext& ctx,
238+
const Tensor& a,
239+
SupportedTensorDtypes a_dtypes,
240+
const Tensor& b,
241+
SupportedTensorDtypes b_dtypes,
242+
const Tensor& c,
243+
SupportedTensorDtypes c_dtypes,
163244
const Tensor& out,
164245
SupportedTensorDtypes out_dtypes) {
165246
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
@@ -168,7 +249,8 @@ inline void apply_bitensor_elementwise_fn(
168249
out,
169250
out_dtypes,
170251
std::make_pair(&a, a_dtypes),
171-
std::make_pair(&b, b_dtypes));
252+
std::make_pair(&b, b_dtypes),
253+
std::make_pair(&c, c_dtypes));
172254
}
173255

174256
/**
@@ -191,7 +273,7 @@ inline void apply_bitensor_elementwise_fn(
191273
* static constexpr const char op_name[] = "my_op";
192274
* apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>.
193275
*/
194-
template <typename CTYPE_COMMON, const char* op_name, typename Op>
276+
template <typename CTYPE_COMMON, const char* op_name, SupportedTensorDtypes out_dtypes, typename Op>
195277
inline void apply_tritensor_elementwise_fn(
196278
const Op& compute_fun,
197279
KernelRuntimeContext& ctx,
@@ -201,8 +283,7 @@ inline void apply_tritensor_elementwise_fn(
201283
SupportedTensorDtypes b_dtypes,
202284
const Tensor& c,
203285
SupportedTensorDtypes c_dtypes,
204-
const Tensor& out,
205-
SupportedTensorDtypes out_dtypes) {
286+
const Tensor& out) {
206287
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
207288
compute_fun,
208289
ctx,

0 commit comments

Comments
 (0)