@@ -53,10 +53,9 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
53
53
namespace internal {
54
54
template <
55
55
typename CTYPE_COMMON,
56
- const char * op_name,
57
56
typename Op,
58
- typename ... Args>
59
- inline void apply_elementwise_fn (
57
+ typename ... Args>
58
+ inline bool validate_elementwise_fn_inputs (
60
59
const Op& compute_fun,
61
60
KernelRuntimeContext& ctx,
62
61
const Tensor& out,
@@ -65,7 +64,6 @@ inline void apply_elementwise_fn(
65
64
static_assert (
66
65
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
67
66
...));
68
- constexpr auto kNumInputs = sizeof ...(inputs);
69
67
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
70
68
const auto check_input_dtype = [](auto input, auto compute_type) {
71
69
return internal::check_tensor_dtype (
@@ -75,7 +73,33 @@ inline void apply_elementwise_fn(
75
73
ctx,
76
74
(check_input_dtype (inputs, compute_type) && ...) &&
77
75
internal::check_tensor_dtype (out, out_dtypes, compute_type),
78
- InvalidArgument, );
76
+ InvalidArgument, false );
77
+
78
+ return true ;
79
+ }
80
+
81
+ template <
82
+ typename CTYPE_COMMON,
83
+ const char * op_name,
84
+ typename Op,
85
+ typename ... Args>
86
+ inline void apply_elementwise_fn (
87
+ const Op& compute_fun,
88
+ KernelRuntimeContext& ctx,
89
+ const Tensor& out,
90
+ SupportedTensorDtypes out_dtypes,
91
+ Args... inputs) {
92
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
93
+ compute_fun,
94
+ ctx,
95
+ out,
96
+ out_dtypes,
97
+ inputs...);
98
+ if (!inputs_valid) {
99
+ return ;
100
+ }
101
+
102
+ constexpr auto kNumInputs = sizeof ...(inputs);
79
103
80
104
struct InputInfo {
81
105
load_to_common_fn<CTYPE_COMMON> load_to_common;
@@ -120,6 +144,7 @@ inline void apply_elementwise_fn(
120
144
}
121
145
} // namespace internal
122
146
147
+ // / DEPRECATED: prefer the variant with out_dtypes in the template argument.
123
148
template <typename CTYPE_COMMON, const char * op_name, typename Op>
124
149
inline void apply_unitensor_elementwise_fn (
125
150
const Op& compute_fun,
@@ -132,19 +157,75 @@ inline void apply_unitensor_elementwise_fn(
132
157
compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
133
158
}
134
159
160
+ template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
161
+ inline void apply_unitensor_elementwise_fn (
162
+ const Op& compute_fun,
163
+ KernelRuntimeContext& ctx,
164
+ const Tensor& a,
165
+ SupportedTensorDtypes a_dtypes,
166
+ const Tensor& out) {
167
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
168
+ compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
169
+ }
170
+
171
+ /* *
172
+ * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
173
+ */
174
+ template <typename CTYPE_COMMON, const char * op_name, typename Op>
175
+ inline void apply_bitensor_elementwise_fn (
176
+ const Op& compute_fun,
177
+ KernelRuntimeContext& ctx,
178
+ const Tensor& a,
179
+ SupportedTensorDtypes a_dtypes,
180
+ const Tensor& b,
181
+ SupportedTensorDtypes b_dtypes,
182
+ const Tensor& out,
183
+ SupportedTensorDtypes out_dtypes) {
184
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
185
+ compute_fun,
186
+ ctx,
187
+ out,
188
+ out_dtypes,
189
+ std::make_pair (&a, a_dtypes),
190
+ std::make_pair (&b, b_dtypes));
191
+ }
192
+
135
193
/* *
136
194
* Useful for bi-tensor elementwise operators. For each element of the inputs,
137
195
* perform a computation and write to the corresponding element of the output.
138
196
* Tensor broadcasting is applied wherever it is required.
139
197
*/
140
- template <typename CTYPE_COMMON, const char * op_name, typename Op>
198
+ template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
141
199
inline void apply_bitensor_elementwise_fn (
142
200
const Op& compute_fun,
143
201
KernelRuntimeContext& ctx,
144
202
const Tensor& a,
145
203
SupportedTensorDtypes a_dtypes,
146
204
const Tensor& b,
147
205
SupportedTensorDtypes b_dtypes,
206
+ const Tensor& out) {
207
+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
208
+ compute_fun,
209
+ ctx,
210
+ out,
211
+ out_dtypes,
212
+ std::make_pair (&a, a_dtypes),
213
+ std::make_pair (&b, b_dtypes));
214
+ }
215
+
216
+ /* *
217
+ * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
218
+ */
219
+ template <typename CTYPE_COMMON, const char * op_name, typename Op>
220
+ inline void apply_tritensor_elementwise_fn (
221
+ const Op& compute_fun,
222
+ KernelRuntimeContext& ctx,
223
+ const Tensor& a,
224
+ SupportedTensorDtypes a_dtypes,
225
+ const Tensor& b,
226
+ SupportedTensorDtypes b_dtypes,
227
+ const Tensor& c,
228
+ SupportedTensorDtypes c_dtypes,
148
229
const Tensor& out,
149
230
SupportedTensorDtypes out_dtypes) {
150
231
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
@@ -153,7 +234,8 @@ inline void apply_bitensor_elementwise_fn(
153
234
out,
154
235
out_dtypes,
155
236
std::make_pair (&a, a_dtypes),
156
- std::make_pair (&b, b_dtypes));
237
+ std::make_pair (&b, b_dtypes),
238
+ std::make_pair (&c, c_dtypes));
157
239
}
158
240
159
241
/* *
@@ -176,7 +258,7 @@ inline void apply_bitensor_elementwise_fn(
176
258
* static constexpr const char op_name[] = "my_op";
177
259
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
178
260
*/
179
- template <typename CTYPE_COMMON, const char * op_name, typename Op>
261
+ template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
180
262
inline void apply_tritensor_elementwise_fn (
181
263
const Op& compute_fun,
182
264
KernelRuntimeContext& ctx,
@@ -186,8 +268,7 @@ inline void apply_tritensor_elementwise_fn(
186
268
SupportedTensorDtypes b_dtypes,
187
269
const Tensor& c,
188
270
SupportedTensorDtypes c_dtypes,
189
- const Tensor& out,
190
- SupportedTensorDtypes out_dtypes) {
271
+ const Tensor& out) {
191
272
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
192
273
compute_fun,
193
274
ctx,
0 commit comments