@@ -60,10 +60,9 @@ using op_call_result =
60
60
61
61
template <
62
62
typename CTYPE_COMMON,
63
- const char * op_name,
64
63
typename Op,
65
- typename ... Args>
66
- inline void apply_elementwise_fn (
64
+ typename ... Args>
65
+ inline bool validate_elementwise_fn_inputs (
67
66
const Op& compute_fun,
68
67
KernelRuntimeContext& ctx,
69
68
const Tensor& out,
@@ -72,7 +71,6 @@ inline void apply_elementwise_fn(
72
71
static_assert (
73
72
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
74
73
...));
75
- constexpr auto kNumInputs = sizeof ...(inputs);
76
74
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
77
75
const auto check_input_dtype = [](auto input, auto compute_type) {
78
76
return internal::check_tensor_dtype (
@@ -82,7 +80,33 @@ inline void apply_elementwise_fn(
82
80
ctx,
83
81
(check_input_dtype (inputs, compute_type) && ...) &&
84
82
internal::check_tensor_dtype (out, out_dtypes, compute_type),
85
- InvalidArgument, );
83
+ InvalidArgument, false );
84
+
85
+ return true ;
86
+ }
87
+
88
+ template <
89
+ typename CTYPE_COMMON,
90
+ const char * op_name,
91
+ typename Op,
92
+ typename ... Args>
93
+ inline void apply_elementwise_fn (
94
+ const Op& compute_fun,
95
+ KernelRuntimeContext& ctx,
96
+ const Tensor& out,
97
+ SupportedTensorDtypes out_dtypes,
98
+ Args... inputs) {
99
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
100
+ compute_fun,
101
+ ctx,
102
+ out,
103
+ out_dtypes,
104
+ inputs...);
105
+ if (!inputs_valid) {
106
+ return ;
107
+ }
108
+
109
+ constexpr auto kNumInputs = sizeof ...(inputs);
86
110
87
111
struct InputInfo {
88
112
load_to_common_fn<CTYPE_COMMON> load_to_common;
@@ -135,6 +159,7 @@ inline void apply_elementwise_fn(
135
159
}
136
160
} // namespace internal
137
161
162
+ /// DEPRECATED: prefer the variant with out_dtypes in the template argument list.
138
163
template <typename CTYPE_COMMON, const char * op_name, typename Op>
139
164
inline void apply_unitensor_elementwise_fn (
140
165
const Op& compute_fun,
@@ -147,19 +172,75 @@ inline void apply_unitensor_elementwise_fn(
147
172
compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
148
173
}
149
174
175
+ template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
176
+ inline void apply_unitensor_elementwise_fn (
177
+ const Op& compute_fun,
178
+ KernelRuntimeContext& ctx,
179
+ const Tensor& a,
180
+ SupportedTensorDtypes a_dtypes,
181
+ const Tensor& out) {
182
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
183
+ compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
184
+ }
185
+
186
+ /* *
187
+ * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
188
+ */
189
+ template <typename CTYPE_COMMON, const char * op_name, typename Op>
190
+ inline void apply_bitensor_elementwise_fn (
191
+ const Op& compute_fun,
192
+ KernelRuntimeContext& ctx,
193
+ const Tensor& a,
194
+ SupportedTensorDtypes a_dtypes,
195
+ const Tensor& b,
196
+ SupportedTensorDtypes b_dtypes,
197
+ const Tensor& out,
198
+ SupportedTensorDtypes out_dtypes) {
199
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
200
+ compute_fun,
201
+ ctx,
202
+ out,
203
+ out_dtypes,
204
+ std::make_pair (&a, a_dtypes),
205
+ std::make_pair (&b, b_dtypes));
206
+ }
207
+
150
208
/* *
151
209
* Useful for bi-tensor elementwise operators. For each element of the inputs,
152
210
* perform a computation and write to the corresponding element of the output.
153
211
* Tensor broadcasting is applied wherever it is required.
154
212
*/
155
- template <typename CTYPE_COMMON, const char * op_name, typename Op>
213
+ template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
156
214
inline void apply_bitensor_elementwise_fn (
157
215
const Op& compute_fun,
158
216
KernelRuntimeContext& ctx,
159
217
const Tensor& a,
160
218
SupportedTensorDtypes a_dtypes,
161
219
const Tensor& b,
162
220
SupportedTensorDtypes b_dtypes,
221
+ const Tensor& out) {
222
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
223
+ compute_fun,
224
+ ctx,
225
+ out,
226
+ out_dtypes,
227
+ std::make_pair (&a, a_dtypes),
228
+ std::make_pair (&b, b_dtypes));
229
+ }
230
+
231
+ /* *
232
+ * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
233
+ */
234
+ template <typename CTYPE_COMMON, const char * op_name, typename Op>
235
+ inline void apply_tritensor_elementwise_fn (
236
+ const Op& compute_fun,
237
+ KernelRuntimeContext& ctx,
238
+ const Tensor& a,
239
+ SupportedTensorDtypes a_dtypes,
240
+ const Tensor& b,
241
+ SupportedTensorDtypes b_dtypes,
242
+ const Tensor& c,
243
+ SupportedTensorDtypes c_dtypes,
163
244
const Tensor& out,
164
245
SupportedTensorDtypes out_dtypes) {
165
246
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
@@ -168,7 +249,8 @@ inline void apply_bitensor_elementwise_fn(
168
249
out,
169
250
out_dtypes,
170
251
std::make_pair (&a, a_dtypes),
171
- std::make_pair (&b, b_dtypes));
252
+ std::make_pair (&b, b_dtypes),
253
+ std::make_pair (&c, c_dtypes));
172
254
}
173
255
174
256
/* *
@@ -191,7 +273,7 @@ inline void apply_bitensor_elementwise_fn(
191
273
* static constexpr const char op_name[] = "my_op";
192
274
* apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>.
193
275
*/
194
- template <typename CTYPE_COMMON, const char * op_name, typename Op>
276
+ template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
195
277
inline void apply_tritensor_elementwise_fn (
196
278
const Op& compute_fun,
197
279
KernelRuntimeContext& ctx,
@@ -201,8 +283,7 @@ inline void apply_tritensor_elementwise_fn(
201
283
SupportedTensorDtypes b_dtypes,
202
284
const Tensor& c,
203
285
SupportedTensorDtypes c_dtypes,
204
- const Tensor& out,
205
- SupportedTensorDtypes out_dtypes) {
286
+ const Tensor& out) {
206
287
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
207
288
compute_fun,
208
289
ctx,
0 commit comments