Skip to content

Commit ae2d822

Browse files
Complex Support: bmm
Differential Revision: D72986238 Pull Request resolved: #10197
1 parent c803f30 commit ae2d822

File tree

6 files changed

+127
-40
lines changed

6 files changed

+127
-40
lines changed

kernels/optimized/cpu/op_bmm.cpp

+18-19
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/runtime/kernel/kernel_includes.h>
10-
119
#include <executorch/kernels/optimized/blas/CPUBlas.h>
10+
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
1212

1313
// Performs a batch matrix-matrix product of matrices stored in input and mat2.
1414

@@ -136,33 +136,32 @@ Error resize_out_tensor(const Tensor& self, const Tensor& mat2, Tensor& out) {
136136

137137
// bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
138138
Tensor& opt_bmm_out(
139-
KernelRuntimeContext& context,
139+
KernelRuntimeContext& ctx,
140140
const Tensor& self,
141141
const Tensor& mat2,
142142
Tensor& out) {
143-
(void)context;
143+
(void)ctx;
144144

145145
ET_KERNEL_CHECK(
146-
context,
146+
ctx,
147147
resize_out_tensor(self, mat2, out) == Error::Ok,
148148
InvalidArgument,
149149
out);
150150
ET_KERNEL_CHECK(
151-
context, check_bmm_out_args(self, mat2, out), InvalidArgument, out);
152-
153-
#define BMM_TENSOR(ctype, dtype) \
154-
case ScalarType::dtype: \
155-
bmm_kernel<ctype>(self, mat2, out); \
156-
break;
157-
158-
auto scalar_type = self.scalar_type();
159-
switch (scalar_type) {
160-
ET_FORALL_REAL_TYPES_AND(Half, BMM_TENSOR)
161-
default:
162-
ET_CHECK_MSG(
163-
false, "Unhandled dtype %" PRId8, static_cast<int8_t>(scalar_type));
151+
ctx, check_bmm_out_args(self, mat2, out), InvalidArgument, out);
152+
153+
constexpr auto name = "bmm.out";
154+
auto self_type = self.scalar_type();
155+
156+
if (executorch::runtime::isComplexType(self_type)) {
157+
ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() {
158+
internal::bmm_out_impl<CTYPE>(self, mat2, out);
159+
});
160+
} else {
161+
ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() {
162+
bmm_kernel<CTYPE>(self, mat2, out);
163+
});
164164
}
165-
#undef BMM_TENSOR
166165

167166
return out;
168167
}

kernels/optimized/cpu/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ _OPTIMIZED_ATEN_OPS = (
1515
name = "op_bmm",
1616
deps = [
1717
"//executorch/kernels/optimized:libblas",
18+
"//executorch/kernels/portable/cpu/util:matmul_ops_util",
1819
],
1920
),
2021
op_target(

kernels/portable/cpu/op_bmm.cpp

+11-19
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
*/
88

99
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
10-
#include <executorch/kernels/portable/cpu/vec_ops.h>
1110
#include <executorch/runtime/kernel/kernel_includes.h>
1211

1312
namespace torch {
@@ -37,26 +36,19 @@ Tensor& bmm_out(
3736
InvalidArgument,
3837
out);
3938

40-
ET_SWITCH_REAL_TYPES_AND(
41-
Half, in.scalar_type(), ctx, "bmm.out", CTYPE, [&]() {
42-
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
43-
const CTYPE* mat2_data = mat2.const_data_ptr<CTYPE>();
44-
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
39+
constexpr auto name = "bmm.out";
4540

46-
int64_t batch_size = in.size(0);
47-
int64_t m = in.size(1);
48-
int64_t n = in.size(2);
49-
int64_t p = mat2.size(2);
41+
auto in_type = in.scalar_type();
5042

51-
for (int i = 0; i < batch_size; ++i) {
52-
const CTYPE* in_data_offset = in_data + i * m * n;
53-
const CTYPE* mat2_data_offset = mat2_data + i * n * p;
54-
CTYPE* out_data_offset = out_data + i * m * p;
55-
56-
vec_matmul<CTYPE>(
57-
out_data_offset, in_data_offset, mat2_data_offset, m, n, p);
58-
}
59-
});
43+
if (executorch::runtime::isComplexType(in_type)) {
44+
ET_SWITCH_COMPLEXH_TYPES(in_type, ctx, name, CTYPE, [&]() {
45+
internal::bmm_out_impl<CTYPE>(in, mat2, out);
46+
});
47+
} else {
48+
ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
49+
internal::bmm_out_impl<CTYPE>(in, mat2, out);
50+
});
51+
}
6052

6153
return out;
6254
}

kernels/portable/cpu/util/matmul_ops_util.h

+31
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,36 @@ void get_linear_out_target_size(
4545
Tensor::SizesType* out_sizes,
4646
size_t* out_ndim);
4747

48+
namespace internal {
49+
50+
template <typename CTYPE>
51+
void bmm_out_impl(const Tensor& in, const Tensor& mat2, Tensor& out) {
52+
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
53+
const CTYPE* mat2_data = mat2.const_data_ptr<CTYPE>();
54+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
55+
56+
int64_t batch_size = in.size(0);
57+
int64_t m = in.size(1);
58+
int64_t n = in.size(2);
59+
int64_t p = mat2.size(2);
60+
61+
for (int b = 0; b < batch_size; ++b) {
62+
const CTYPE* in_data_offset = in_data + b * m * n;
63+
const CTYPE* mat2_data_offset = mat2_data + b * n * p;
64+
CTYPE* out_data_offset = out_data + b * m * p;
65+
66+
for (const auto i : c10::irange(m)) {
67+
for (const auto j : c10::irange(p)) {
68+
CTYPE sum = static_cast<CTYPE>(0.0);
69+
for (const auto k : c10::irange(n)) {
70+
sum += in_data_offset[i * n + k] * mat2_data_offset[k * p + j];
71+
}
72+
out_data_offset[i * p + j] = sum;
73+
}
74+
}
75+
}
76+
}
77+
78+
} // namespace internal
4879
} // namespace executor
4980
} // namespace torch

kernels/test/op_bmm_test.cpp

+66-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,61 @@ class OpBmmOutTest : public OperatorTest {
4343

4444
EXPECT_TENSOR_EQ(out, expected);
4545
}
46+
47+
template <typename CTYPE, ScalarType DTYPE>
48+
void test_complex_dtype() {
49+
TensorFactory<DTYPE> tf;
50+
Tensor x = tf.make(
51+
{2, 2, 3},
52+
{CTYPE(1, 1),
53+
CTYPE(2, 2),
54+
CTYPE(3, 3),
55+
CTYPE(4, 4),
56+
CTYPE(5, 5),
57+
CTYPE(6, 6),
58+
CTYPE(7, 7),
59+
CTYPE(8, 8),
60+
CTYPE(9, 9),
61+
CTYPE(10, 10),
62+
CTYPE(11, 11),
63+
CTYPE(12, 12)});
64+
Tensor y = tf.make(
65+
{2, 3, 2},
66+
{CTYPE(2, 1),
67+
CTYPE(4, 2),
68+
CTYPE(6, 3),
69+
CTYPE(8, 4),
70+
CTYPE(10, 5),
71+
CTYPE(12, 6),
72+
CTYPE(14, 7),
73+
CTYPE(16, 8),
74+
CTYPE(18, 9),
75+
CTYPE(20, 10),
76+
CTYPE(22, 11),
77+
CTYPE(24, 12)});
78+
Tensor out = tf.make(
79+
{2, 2, 2},
80+
{CTYPE(0, 0),
81+
CTYPE(0, 0),
82+
CTYPE(0, 0),
83+
CTYPE(0, 0),
84+
CTYPE(0, 0),
85+
CTYPE(0, 0),
86+
CTYPE(0, 0),
87+
CTYPE(0, 0)});
88+
Tensor expected = tf.make(
89+
{2, 2, 2},
90+
{CTYPE(22, 66),
91+
CTYPE(28, 84),
92+
CTYPE(49, 147),
93+
CTYPE(64, 192),
94+
CTYPE(220, 660),
95+
CTYPE(244, 732),
96+
CTYPE(301, 903),
97+
CTYPE(334, 1002)});
98+
op_bmm_out(x, y, out);
99+
EXPECT_TENSOR_CLOSE(out, expected);
100+
}
46101
};
47102

48103
TEST_F(OpBmmOutTest, OutputDim) {
@@ -132,7 +187,7 @@ TEST_F(OpBmmOutTest, OutputDimFloat) {
132187

133188
/// A generic smoke test that works for any dtype that supports ones() and
134189
/// zeros().
135-
TEST_F(OpBmmOutTest, AllDtypesSupported) {
190+
TEST_F(OpBmmOutTest, AllRealDtypesSupported) {
136191
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
137192
ET_FORALL_REAL_TYPES(TEST_ENTRY);
138193
#undef TEST_ENTRY
@@ -141,6 +196,16 @@ TEST_F(OpBmmOutTest, AllDtypesSupported) {
141196
// for those types.
142197
}
143198

199+
TEST_F(OpBmmOutTest, AllComplexDtypesSupported) {
200+
#define TEST_ENTRY(ctype, dtype) test_complex_dtype<ctype, ScalarType::dtype>();
201+
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
202+
ET_FORALL_COMPLEX_TYPES(TEST_ENTRY);
203+
} else {
204+
ET_FORALL_COMPLEXH_TYPES(TEST_ENTRY);
205+
}
206+
#undef TEST_ENTRY
207+
}
208+
144209
TEST_F(OpBmmOutTest, EmptyInputWithEmptyOutTensorPasses) {
145210
TensorFactory<ScalarType::Int> tf;
146211

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

-1
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,6 @@ ATEN_OPS = (
372372
name = "op_bmm",
373373
deps = [
374374
"//executorch/kernels/portable/cpu/util:matmul_ops_util",
375-
":vec_ops",
376375
],
377376
),
378377
op_target(

0 commit comments

Comments
 (0)