@@ -26,26 +26,35 @@ using torch::executor::util::FileDataLoader;
26
26
27
27
class MethodMetaTest : public ::testing::Test {
28
28
protected:
29
- void SetUp () override {
30
- // Create a loader for the serialized ModuleAdd program.
31
- const char * path = std::getenv (" ET_MODULE_ADD_PATH" );
29
+ void load_program (const char * path, const char * module_name) {
30
+ // Create a loader for the serialized program.
32
31
Result<FileDataLoader> loader = FileDataLoader::from (path);
33
32
ASSERT_EQ (loader.error (), Error::Ok);
34
- loader_ = std::make_unique<FileDataLoader>(std::move (loader.get ()));
33
+ loaders_.insert (
34
+ {module_name,
35
+ std::make_unique<FileDataLoader>(std::move (loader.get ()))});
35
36
36
37
// Use it to load the program.
37
38
Result<Program> program = Program::load (
38
- loader_.get (), Program::Verification::InternalConsistency);
39
+ loaders_[module_name].get (),
40
+ Program::Verification::InternalConsistency);
39
41
ASSERT_EQ (program.error (), Error::Ok);
40
- program_ = std::make_unique<Program>(std::move (program.get ()));
42
+ programs_.insert (
43
+ {module_name, std::make_unique<Program>(std::move (program.get ()))});
44
+ }
45
+
46
+ void SetUp () override {
47
+
48
+ load_program (std::getenv (" ET_MODULE_ADD_PATH" ), " add" );
49
+ load_program (std::getenv (" ET_MODULE_STATEFUL_PATH" ), " stateful" );
41
50
}
42
51
43
52
private:
44
53
// Must outlive program_, but tests shouldn't need to touch it.
45
- std::unique_ptr<FileDataLoader> loader_ ;
54
+ std::unordered_map<std::string, std:: unique_ptr<FileDataLoader>> loaders_ ;
46
55
47
56
protected:
48
- std::unique_ptr<Program> program_ ;
57
+ std::unordered_map<std::string, std:: unique_ptr<Program>> programs_ ;
49
58
};
50
59
51
60
namespace {
@@ -67,7 +76,7 @@ void check_tensor(const TensorInfo& tensor_info) {
67
76
} // namespace
68
77
69
78
TEST_F (MethodMetaTest, MethodMetaApi) {
70
- Result<MethodMeta> method_meta = program_ ->method_meta (" forward" );
79
+ Result<MethodMeta> method_meta = programs_[ " add " ] ->method_meta (" forward" );
71
80
ASSERT_EQ (method_meta.error (), Error::Ok);
72
81
73
82
// Appropriate amount of inputs
@@ -97,11 +106,11 @@ TEST_F(MethodMetaTest, MethodMetaApi) {
97
106
98
107
// Missing method fails
99
108
EXPECT_EQ (
100
- program_ ->method_meta (" not_a_method" ).error (), Error::InvalidArgument);
109
+ programs_[ " add " ] ->method_meta (" not_a_method" ).error (), Error::InvalidArgument);
101
110
}
102
111
103
112
TEST_F (MethodMetaTest, TensorInfoApi) {
104
- Result<MethodMeta> method_meta = program_ ->method_meta (" forward" );
113
+ Result<MethodMeta> method_meta = programs_[ " add " ] ->method_meta (" forward" );
105
114
ASSERT_EQ (method_meta.error (), Error::Ok);
106
115
107
116
// Input 1
@@ -138,3 +147,18 @@ TEST_F(MethodMetaTest, TensorInfoApi) {
138
147
EXPECT_EQ (
139
148
method_meta->output_tensor_meta (-1 ).error (), Error::InvalidArgument);
140
149
}
150
+
151
+ TEST_F (MethodMetaTest, MethodMetaAttribute) {
152
+ Result<MethodMeta> method_meta = programs_[" stateful" ]->method_meta (" forward" );
153
+ ASSERT_EQ (method_meta.error (), Error::Ok);
154
+
155
+ ASSERT_EQ (method_meta->num_attributes (), 1 );
156
+ auto state = method_meta->attribute_tensor_meta (0 );
157
+ ASSERT_TRUE (state.ok ());
158
+
159
+ ASSERT_EQ (state->name (), " state" );
160
+ ASSERT_FALSE (state->is_memory_planned ());
161
+
162
+ auto bad_access = method_meta->attribute_tensor_meta (1 );
163
+ ASSERT_EQ (bad_access.error (), Error::InvalidArgument);
164
+ }
0 commit comments