Skip to content

Commit 2b3ee97

Browse files
pssrawatkeyprocedure
authored andcommitted
Add view_as_real_copy.out
Differential Revision: D72294238 Pull Request resolved: pytorch#10207
1 parent 8b6ef57 commit 2b3ee97

File tree

9 files changed

+194
-0
lines changed

9 files changed

+194
-0
lines changed

kernels/aten/functions.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,8 @@
423423

424424
- op: var.out
425425

426+
- op: view_as_real_copy.out
427+
426428
- op: view_copy.out
427429

428430
- op: where.self_out
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
#include <executorch/runtime/platform/assert.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
namespace native {
16+
17+
using Tensor = executorch::aten::Tensor;
18+
19+
namespace {
20+
21+
template <typename SELF_CTYPE, typename OUT_CTYPE>
22+
inline void _to_impl(const Tensor& self, Tensor& out) {
23+
auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
24+
auto out_data = out.mutable_data_ptr<OUT_CTYPE>();
25+
26+
for (size_t i = 0, e = self.numel(); i < e; i++) {
27+
auto val_in = self_data[i];
28+
out_data[2 * i] = static_cast<OUT_CTYPE>(val_in.real_);
29+
out_data[2 * i + 1] = static_cast<OUT_CTYPE>(val_in.imag_);
30+
}
31+
}
32+
33+
} // namespace
34+
35+
// view_as_real_copy(Tensor self) -> Tensor
36+
Tensor& view_as_real_copy_out(
37+
KernelRuntimeContext& ctx,
38+
const Tensor& self,
39+
Tensor& out) {
40+
(void)ctx;
41+
42+
// Get the output shape
43+
Tensor::SizesType expected_output_size[kTensorDimensionLimit];
44+
get_view_as_real_copy_out_target_size(self, expected_output_size);
45+
46+
// Resize for dynamic shape
47+
ET_KERNEL_CHECK_MSG(
48+
ctx,
49+
resize_tensor(
50+
out, {expected_output_size, static_cast<size_t>(out.dim())}) ==
51+
Error::Ok,
52+
InvalidArgument,
53+
out,
54+
"Failed to resize output tensor.");
55+
56+
// The input tensor must be complex type
57+
ET_KERNEL_CHECK_MSG(
58+
ctx,
59+
executorch::runtime::isComplexType(self.scalar_type()),
60+
InvalidArgument,
61+
out,
62+
"Input tensor must be complex type");
63+
64+
ET_KERNEL_CHECK(
65+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
66+
67+
constexpr auto op_name = "view_as_real_copy.out";
68+
69+
ET_SWITCH_COMPLEXH_TYPES(self.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
70+
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
71+
_to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
72+
});
73+
});
74+
75+
return out;
76+
}
77+
78+
} // namespace native
79+
} // namespace executor
80+
} // namespace torch

kernels/portable/cpu/util/copy_ops_util.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -1018,5 +1018,14 @@ void get_unfold_copy_out_target_size(
10181018
*out_ndim = self.dim() + 1;
10191019
}
10201020

1021+
void get_view_as_real_copy_out_target_size(
1022+
const Tensor& self,
1023+
executorch::aten::SizesType* out_sizes) {
1024+
for (auto i : c10::irange(self.dim())) {
1025+
out_sizes[i] = self.size(i);
1026+
}
1027+
out_sizes[self.dim()] = 2;
1028+
}
1029+
10211030
} // namespace executor
10221031
} // namespace torch

kernels/portable/cpu/util/copy_ops_util.h

+4
Original file line numberDiff line numberDiff line change
@@ -247,5 +247,9 @@ void get_unfold_copy_out_target_size(
247247
executorch::aten::SizesType* out_sizes,
248248
size_t* out_ndim);
249249

250+
void get_view_as_real_copy_out_target_size(
251+
const Tensor& self,
252+
executorch::aten::SizesType* out_sizes);
253+
250254
} // namespace executor
251255
} // namespace torch

kernels/portable/functions.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,11 @@
957957
- arg_meta: null
958958
kernel_name: torch::executor::var_out
959959

960+
- op: view_as_real_copy.out
961+
kernels:
962+
- arg_meta: null
963+
kernel_name: torch::executor::view_as_real_copy_out
964+
960965
- op: view_copy.out
961966
kernels:
962967
- arg_meta: null

kernels/test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ set(all_test_sources
242242
"op_upsample_bilinear2d_test.cpp"
243243
"op_upsample_nearest2d_test.cpp"
244244
"op_var_test.cpp"
245+
"op_view_as_real_copy_test.cpp"
245246
"op_view_copy_test.cpp"
246247
"op_where_test.cpp"
247248
"op_zeros_test.cpp"
+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14+
15+
#include <gtest/gtest.h>
16+
17+
using namespace ::testing;
18+
using executorch::aten::ScalarType;
19+
using executorch::aten::Tensor;
20+
using torch::executor::testing::TensorFactory;
21+
22+
class OpViewAsRealTest : public OperatorTest {
23+
protected:
24+
Tensor& view_as_real_copy_out(const Tensor& self, Tensor& out) {
25+
return torch::executor::aten::view_as_real_copy_outf(context_, self, out);
26+
}
27+
28+
template <typename CTYPE, ScalarType DTYPE>
29+
void run_complex_smoke_test() {
30+
TensorFactory<DTYPE> tf;
31+
constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);
32+
TensorFactory<REAL_DTYPE> tf_out;
33+
34+
Tensor in = tf.make(
35+
{2, 2},
36+
{CTYPE(3, 4), CTYPE(-1.7, 7.4), CTYPE(5, -12), CTYPE(8.3, 0.1)});
37+
Tensor out = tf_out.zeros({2, 2, 2});
38+
Tensor expected =
39+
tf_out.make({2, 2, 2}, {3, 4, -1.7, 7.4, 5, -12, 8.3, 0.1});
40+
Tensor ret = view_as_real_copy_out(in, out);
41+
42+
EXPECT_TENSOR_EQ(out, ret);
43+
EXPECT_TENSOR_EQ(out, expected);
44+
}
45+
46+
// Tests on tensors with 0 size
47+
template <typename CTYPE, ScalarType DTYPE>
48+
void test_empty_input() {
49+
TensorFactory<DTYPE> tf;
50+
constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);
51+
TensorFactory<REAL_DTYPE> tf_out;
52+
53+
Tensor in = tf.make(/*sizes=*/{3, 0, 4}, /*data=*/{});
54+
Tensor out = tf_out.zeros({3, 0, 4, 2});
55+
Tensor expected = tf_out.make(/*sizes=*/{3, 0, 4, 2}, /*data=*/{});
56+
Tensor ret = view_as_real_copy_out(in, out);
57+
58+
EXPECT_TENSOR_EQ(out, ret);
59+
EXPECT_TENSOR_EQ(out, expected);
60+
}
61+
62+
// Tests on 0-dim input tensors
63+
template <typename CTYPE, ScalarType DTYPE>
64+
void zero_dim_input() {
65+
TensorFactory<DTYPE> tf;
66+
constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);
67+
TensorFactory<REAL_DTYPE> tf_out;
68+
69+
Tensor in = tf.make(/*sizes=*/{}, {CTYPE(0, 0)});
70+
Tensor out = tf_out.zeros({2});
71+
Tensor expected = tf_out.zeros(/*sizes=*/{2});
72+
Tensor ret = view_as_real_copy_out(in, out);
73+
74+
EXPECT_TENSOR_EQ(out, ret);
75+
EXPECT_TENSOR_EQ(out, expected);
76+
}
77+
};
78+
79+
TEST_F(OpViewAsRealTest, ComplexSmokeTest) {
80+
#define RUN_SMOKE_TEST(ctype, dtype) \
81+
run_complex_smoke_test<ctype, ScalarType::dtype>(); \
82+
test_empty_input<ctype, ScalarType::dtype>(); \
83+
zero_dim_input<ctype, ScalarType::dtype>();
84+
ET_FORALL_COMPLEXH_TYPES(RUN_SMOKE_TEST);
85+
#undef RUN_SMOKE_TEST
86+
}

kernels/test/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def define_common_targets():
331331
_common_op_test("op_upsample_bilinear2d_test", ["aten", "portable"])
332332
_common_op_test("op_upsample_nearest2d_test", ["aten", "portable"])
333333
_common_op_test("op_var_test", ["aten", "portable"])
334+
_common_op_test("op_view_as_real_copy_test", ["aten", "portable"])
334335
_common_op_test("op_view_copy_test", ["aten", "portable"])
335336
_common_op_test("op_where_test", ["aten", "portable"])
336337
_common_op_test("op_zeros_test", ["aten", "portable"])

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

+6
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,12 @@ ATEN_OPS = (
12681268
"//executorch/kernels/portable/cpu/util:reduce_util",
12691269
],
12701270
),
1271+
op_target(
1272+
name = "op_view_as_real_copy",
1273+
deps = [
1274+
"//executorch/kernels/portable/cpu/util:copy_ops_util",
1275+
],
1276+
),
12711277
op_target(
12721278
name = "op_view_copy",
12731279
deps = [

0 commit comments

Comments
 (0)