-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
Copy pathnode_input_output.cpp
152 lines (119 loc) · 5.83 KB
/
node_input_output.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "openvino/op/add.hpp"
#include "openvino/op/parameter.hpp"
using namespace ov;
using namespace std;
// Input<Node> handles taken from a mutable Add must report the owning node,
// their own port index, element type, static and partial shape, and the
// Output<Node> they are connected to; asking for a third input must throw.
TEST(node_input_output, input_create) {
    const Shape expected_shape{1, 2, 3, 4};
    const auto lhs = make_shared<ov::op::v0::Parameter>(element::f32, expected_shape);
    const auto rhs = make_shared<ov::op::v0::Parameter>(element::f32, expected_shape);
    const auto sum = make_shared<op::v1::Add>(lhs, rhs);

    // Runs every per-port expectation for one input. Both calls share these
    // source lines, so the port index is streamed into each failure message.
    const auto expect_port = [&](size_t port_index, const std::shared_ptr<ov::op::v0::Parameter>& source) {
        const std::string ctx = "input #" + std::to_string(port_index);
        const auto port = sum->input(port_index);
        EXPECT_EQ(port.get_node(), sum.get()) << ctx;
        EXPECT_EQ(port.get_index(), port_index) << ctx;
        EXPECT_EQ(port.get_element_type(), element::f32) << ctx;
        EXPECT_EQ(port.get_shape(), expected_shape) << ctx;
        EXPECT_TRUE(port.get_partial_shape().same_scheme(PartialShape(expected_shape))) << ctx;
        EXPECT_EQ(port.get_source_output(), Output<Node>(source, 0)) << ctx;
    };

    expect_port(0, lhs);
    expect_port(1, rhs);

    // Add has exactly two inputs, so index 2 is out of range.
    EXPECT_THROW(sum->input(2), ov::Exception);
}
// Same contract as input_create, but the Add is held through a pointer to
// const, so input() yields Input<const Node> handles. Each handle must still
// report its node, index, type, shapes and source Output<Node>. An
// out-of-range index must throw.
TEST(node_input_output, input_create_const) {
    const Shape expected_shape{1, 2, 3, 4};
    const auto lhs = make_shared<ov::op::v0::Parameter>(element::f32, expected_shape);
    const auto rhs = make_shared<ov::op::v0::Parameter>(element::f32, expected_shape);
    const auto sum = make_shared<const op::v1::Add>(lhs, rhs);

    // Runs every per-port expectation for one input. The index is streamed into
    // failure messages so the two invocations can be told apart.
    const auto expect_port = [&](size_t port_index, const std::shared_ptr<ov::op::v0::Parameter>& source) {
        const std::string ctx = "input #" + std::to_string(port_index);
        const auto port = sum->input(port_index);
        EXPECT_EQ(port.get_node(), sum.get()) << ctx;
        EXPECT_EQ(port.get_index(), port_index) << ctx;
        EXPECT_EQ(port.get_element_type(), element::f32) << ctx;
        EXPECT_EQ(port.get_shape(), expected_shape) << ctx;
        EXPECT_TRUE(port.get_partial_shape().same_scheme(PartialShape(expected_shape))) << ctx;
        EXPECT_EQ(port.get_source_output(), Output<Node>(source, 0)) << ctx;
    };

    expect_port(0, lhs);
    expect_port(1, rhs);

    // Add has exactly two inputs, so index 2 is out of range.
    EXPECT_THROW(sum->input(2), ov::Exception);
}
// Output<Node> of a mutable Add:
// - set_names() installs a name set and add_names() extends it.
// - get_any_name() is expected to report "a" both times.
// - The port accessors describe the Add's only output.
// - Output index 1 must throw.
TEST(node_input_output, output_create) {
    using NameSet = std::unordered_set<std::string>;

    const Shape expected_shape{1, 2, 3, 4};
    const auto lhs = make_shared<ov::op::v0::Parameter>(element::f32, expected_shape);
    const auto rhs = make_shared<ov::op::v0::Parameter>(element::f32, expected_shape);
    const auto sum = make_shared<op::v1::Add>(lhs, rhs);
    auto result = sum->output(0);

    // Initial names.
    result.set_names({"a", "b"});
    EXPECT_EQ(result.get_names(), NameSet({"a", "b"}));
    EXPECT_EQ(result.get_any_name(), "a");

    // add_names() keeps the previous names and appends the new ones.
    result.add_names({"c", "d"});
    EXPECT_EQ(result.get_names(), NameSet({"a", "b", "c", "d"}));
    EXPECT_EQ(result.get_any_name(), "a");

    // Port accessors.
    EXPECT_EQ(result.get_node(), sum.get());
    EXPECT_EQ(result.get_index(), 0u);
    EXPECT_EQ(result.get_element_type(), element::f32);
    EXPECT_EQ(result.get_shape(), expected_shape);
    EXPECT_TRUE(result.get_partial_shape().same_scheme(PartialShape(expected_shape)));

    // Add produces a single output.
    EXPECT_THROW(sum->output(1), ov::Exception);
}
// Output<const Node> of an Add held through a pointer to const. The test
// expects:
// - no tensor names on the output;
// - the owning node and index 0;
// - element type f32 and shape {1, 2, 3, 4};
// - output index 1 to throw.
TEST(node_input_output, output_create_const) {
    auto x = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 2, 3, 4});
    auto y = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 2, 3, 4});
    auto add = make_shared<const op::v1::Add>(x, y);
    auto add_out_0 = add->output(0);
    // empty() states the intent directly and avoids a size_t-vs-int comparison
    // inside EXPECT_EQ (a -Wsign-compare source).
    EXPECT_TRUE(add_out_0.get_names().empty());
    EXPECT_EQ(add_out_0.get_node(), add.get());
    EXPECT_EQ(add_out_0.get_index(), 0u);
    EXPECT_EQ(add_out_0.get_element_type(), element::f32);
    EXPECT_EQ(add_out_0.get_shape(), (Shape{1, 2, 3, 4}));
    EXPECT_TRUE(add_out_0.get_partial_shape().same_scheme(PartialShape{1, 2, 3, 4}));
    EXPECT_THROW(add->output(1), ov::Exception);
}
// Runtime info written through a mutable Output<Node> must be:
// - readable back through that handle;
// - present on the output's tensor;
// - visible through a read-only Output<const Node> of the same port.
// An output of a different node (add_const) is expected to keep an empty rt_info.
TEST(node_input_output, output_rt_info) {
    auto x = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 2, 3, 4});
    auto y = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 2, 3, 4});
    auto add = make_shared<op::v1::Add>(x, y);
    auto add_const = make_shared<const op::v1::Add>(x, y);
    Output<Node> output = add->output(0);
    Output<const Node> output_const = add_const->output(0);
    auto& rt = output.get_rt_info();
    rt["test"] = nullptr;
    EXPECT_TRUE(output.get_rt_info().count("test"));
    EXPECT_TRUE(output.get_tensor_ptr()->get_rt_info().count("test"));

    // output_const belongs to another node, so its emptiness says nothing about
    // const access to `add`. Check a const view of the very same port as well.
    const op::v1::Add& add_readonly = *add;
    const Output<const Node> same_port_view = add_readonly.output(0);
    EXPECT_TRUE(same_port_view.get_rt_info().count("test"));

    EXPECT_TRUE(output_const.get_rt_info().empty());
}
// Rewiring an Add: set_argument() swaps a single input and set_arguments()
// replaces all of them. The input count must stay 2, and each input's shape
// must follow whichever Parameter is currently connected.
TEST(node_input_output, input_set_argument) {
    const auto p1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1});
    const auto p2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2});
    const auto p3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{3});
    const auto sum = make_shared<op::v1::Add>(p1, p2);

    // Asserts the arity and the shapes currently observed on inputs 0 and 1.
    const auto expect_input_shapes = [&sum](const Shape& first, const Shape& second) {
        EXPECT_EQ(sum->get_input_size(), 2u);
        EXPECT_EQ(sum->input(0).get_shape(), first);
        EXPECT_EQ(sum->input(1).get_shape(), second);
    };

    expect_input_shapes(Shape{1}, Shape{2});

    sum->set_argument(1, p3);
    expect_input_shapes(Shape{1}, Shape{3});

    sum->set_arguments(NodeVector{p3, p1});
    expect_input_shapes(Shape{3}, Shape{1});
}
// Building an Input or Output handle (mutable or const flavour) around a null
// node pointer must be rejected with ov::Exception.
TEST(node_input_output, create_wrong_input_output) {
    // Typed null pointers select the same raw-pointer constructors that a bare
    // nullptr literal resolves to.
    ov::Node* const no_node = nullptr;
    const ov::Node* const no_const_node = nullptr;

    EXPECT_THROW(ov::Output<ov::Node>(no_node, 0), ov::Exception);
    EXPECT_THROW(ov::Output<const ov::Node>(no_const_node, 0), ov::Exception);

    EXPECT_THROW(ov::Input<ov::Node>(no_node, 0), ov::Exception);
    EXPECT_THROW(ov::Input<const ov::Node>(no_const_node, 0), ov::Exception);
}