@@ -26,26 +26,34 @@ using torch::executor::util::FileDataLoader;
26
26
27
27
class MethodMetaTest : public ::testing::Test {
28
28
protected:
29
- void SetUp () override {
30
- // Create a loader for the serialized ModuleAdd program.
31
- const char * path = std::getenv (" ET_MODULE_ADD_PATH" );
29
+ void load_program (const char * path, const char * module_name) {
30
+ // Create a loader for the serialized program.
32
31
Result<FileDataLoader> loader = FileDataLoader::from (path);
33
32
ASSERT_EQ (loader.error (), Error::Ok);
34
- loader_ = std::make_unique<FileDataLoader>(std::move (loader.get ()));
33
+ loaders_.insert (
34
+ {module_name,
35
+ std::make_unique<FileDataLoader>(std::move (loader.get ()))});
35
36
36
37
// Use it to load the program.
37
38
Result<Program> program = Program::load (
38
- loader_.get (), Program::Verification::InternalConsistency);
39
+ loaders_[module_name].get (),
40
+ Program::Verification::InternalConsistency);
39
41
ASSERT_EQ (program.error (), Error::Ok);
40
- program_ = std::make_unique<Program>(std::move (program.get ()));
42
+ programs_.insert (
43
+ {module_name, std::make_unique<Program>(std::move (program.get ()))});
44
+ }
45
+
46
+ void SetUp () override {
47
+ load_program (std::getenv (" ET_MODULE_ADD_PATH" ), " add" );
48
+ load_program (std::getenv (" ET_MODULE_STATEFUL_PATH" ), " stateful" );
41
49
}
42
50
43
51
private:
44
52
// Must outlive program_, but tests shouldn't need to touch it.
45
- std::unique_ptr<FileDataLoader> loader_ ;
53
+ std::unordered_map<std::string, std:: unique_ptr<FileDataLoader>> loaders_ ;
46
54
47
55
protected:
48
- std::unique_ptr<Program> program_ ;
56
+ std::unordered_map<std::string, std:: unique_ptr<Program>> programs_ ;
49
57
};
50
58
51
59
namespace {
@@ -67,7 +75,7 @@ void check_tensor(const TensorInfo& tensor_info) {
67
75
} // namespace
68
76
69
77
TEST_F (MethodMetaTest, MethodMetaApi) {
70
- Result<MethodMeta> method_meta = program_ ->method_meta (" forward" );
78
+ Result<MethodMeta> method_meta = programs_[ " add " ] ->method_meta (" forward" );
71
79
ASSERT_EQ (method_meta.error (), Error::Ok);
72
80
73
81
// Appropriate amount of inputs
@@ -97,11 +105,12 @@ TEST_F(MethodMetaTest, MethodMetaApi) {
97
105
98
106
// Missing method fails
99
107
EXPECT_EQ (
100
- program_->method_meta (" not_a_method" ).error (), Error::InvalidArgument);
108
+ programs_[" add" ]->method_meta (" not_a_method" ).error (),
109
+ Error::InvalidArgument);
101
110
}
102
111
103
112
TEST_F (MethodMetaTest, TensorInfoApi) {
104
- Result<MethodMeta> method_meta = program_ ->method_meta (" forward" );
113
+ Result<MethodMeta> method_meta = programs_[ " add " ] ->method_meta (" forward" );
105
114
ASSERT_EQ (method_meta.error (), Error::Ok);
106
115
107
116
// Input 1
@@ -138,3 +147,19 @@ TEST_F(MethodMetaTest, TensorInfoApi) {
138
147
EXPECT_EQ (
139
148
method_meta->output_tensor_meta (-1 ).error (), Error::InvalidArgument);
140
149
}
150
+
151
+ TEST_F (MethodMetaTest, MethodMetaAttribute) {
152
+ Result<MethodMeta> method_meta =
153
+ programs_[" stateful" ]->method_meta (" forward" );
154
+ ASSERT_EQ (method_meta.error (), Error::Ok);
155
+
156
+ ASSERT_EQ (method_meta->num_attributes (), 1 );
157
+ auto state = method_meta->attribute_tensor_meta (0 );
158
+ ASSERT_TRUE (state.ok ());
159
+
160
+ ASSERT_EQ (state->name (), " state" );
161
+ ASSERT_FALSE (state->is_memory_planned ());
162
+
163
+ auto bad_access = method_meta->attribute_tensor_meta (1 );
164
+ ASSERT_EQ (bad_access.error (), Error::InvalidArgument);
165
+ }
0 commit comments