|
6 | 6 | * LICENSE file in the root directory of this source tree.
|
7 | 7 | */
|
8 | 8 |
|
9 |
| -#include <executorch/runtime/kernel/kernel_includes.h> |
10 |
| - |
11 | 9 | #include <executorch/kernels/optimized/blas/CPUBlas.h>
|
| 10 | +#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h> |
| 11 | +#include <executorch/runtime/kernel/kernel_includes.h> |
12 | 12 |
|
13 | 13 | // Performs a batch matrix-matrix product of matrices stored in input and mat2.
|
14 | 14 |
|
@@ -136,33 +136,32 @@ Error resize_out_tensor(const Tensor& self, const Tensor& mat2, Tensor& out) {
|
136 | 136 |
|
137 | 137 | // bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
|
138 | 138 | Tensor& opt_bmm_out(
|
139 |
| - KernelRuntimeContext& context, |
| 139 | + KernelRuntimeContext& ctx, |
140 | 140 | const Tensor& self,
|
141 | 141 | const Tensor& mat2,
|
142 | 142 | Tensor& out) {
|
143 |
| - (void)context; |
| 143 | + (void)ctx; |
144 | 144 |
|
145 | 145 | ET_KERNEL_CHECK(
|
146 |
| - context, |
| 146 | + ctx, |
147 | 147 | resize_out_tensor(self, mat2, out) == Error::Ok,
|
148 | 148 | InvalidArgument,
|
149 | 149 | out);
|
150 | 150 | ET_KERNEL_CHECK(
|
151 |
| - context, check_bmm_out_args(self, mat2, out), InvalidArgument, out); |
152 |
| - |
153 |
| -#define BMM_TENSOR(ctype, dtype) \ |
154 |
| - case ScalarType::dtype: \ |
155 |
| - bmm_kernel<ctype>(self, mat2, out); \ |
156 |
| - break; |
157 |
| - |
158 |
| - auto scalar_type = self.scalar_type(); |
159 |
| - switch (scalar_type) { |
160 |
| - ET_FORALL_REAL_TYPES_AND(Half, BMM_TENSOR) |
161 |
| - default: |
162 |
| - ET_CHECK_MSG( |
163 |
| - false, "Unhandled dtype %" PRId8, static_cast<int8_t>(scalar_type)); |
| 151 | + ctx, check_bmm_out_args(self, mat2, out), InvalidArgument, out); |
| 152 | + |
| 153 | + constexpr auto name = "bmm.out"; |
| 154 | + auto self_type = self.scalar_type(); |
| 155 | + |
| 156 | + if (executorch::runtime::isComplexType(self_type)) { |
| 157 | + ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() { |
| 158 | + internal::bmm_out_impl<CTYPE>(self, mat2, out); |
| 159 | + }); |
| 160 | + } else { |
| 161 | + ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() { |
| 162 | + bmm_kernel<CTYPE>(self, mat2, out); |
| 163 | + }); |
164 | 164 | }
|
165 |
| -#undef BMM_TENSOR |
166 | 165 |
|
167 | 166 | return out;
|
168 | 167 | }
|
|
0 commit comments