@@ -8,12 +8,13 @@
from torch.jit.mobile import (
    _load_for_lite_interpreter,
    _get_model_bytecode_version,
+    _get_model_ops_and_info,
    _backport_for_mobile_to_buffer,
    _backport_for_mobile)
from torch.testing._internal.common_utils import TestCase, run_tests
from pathlib import Path

-pytorch_test_dri = Path(__file__).resolve().parents[1]
+pytorch_test_dir = Path(__file__).resolve().parents[1]

# script_module_v4.ptl and script_module_v5.ptl source code
# class TestModule(torch.nn.Module):
@@ -97,6 +98,38 @@
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
'''

+SCRIPT_MODULE_V6_BYTECODE_PKL = '''
+(6,
+ ('__torch__.*.TestModule.forward',
+  (('instructions',
+    (('STOREN', 1, 2),
+     ('DROPR', 1, 0),
+     ('LOADC', 0, 0),
+     ('LOADC', 1, 0),
+     ('MOVE', 2, 0),
+     ('OP', 0, 0),
+     ('OP', 1, 0),
+     ('RET', 0, 0))),
+   ('operators', (('aten::add', 'int', 2), ('aten::add', 'Scalar', 2))),
+   ('constants',
+    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
+       0,
+       (2, 4),
+       (4, 1),
+       False,
+       collections.OrderedDict()),
+     1)),
+   ('types', ()),
+   ('register_size', 2)),
+  (('arguments',
+    ((('name', 'self'),
+      ('type', '__torch__.*.TestModule'),
+      ('default_value', None)),
+     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
+   ('returns',
+    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
+'''
+
SCRIPT_MODULE_BYTECODE_PKL = {
    4: {
        "bytecode_pkl": SCRIPT_MODULE_V4_BYTECODE_PKL,
@@ -113,7 +146,7 @@ def check_model_version(model_path, expect_version):
            actual_version = _get_model_bytecode_version(model_path)
            assert(actual_version == expect_version)
        for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
-            model_path = pytorch_test_dri / "cpp" / "jit" / model_info["model_name"]
+            model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"]
            check_model_version(model_path, version)

    def test_bytecode_values_for_all_backport_functions(self):
@@ -130,7 +163,7 @@ def test_bytecode_values_for_all_backport_functions(self):
        while current_from_version > MINIMUM_TO_VERSION:
            # Load model v5 and run forward method
            model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version]["model_name"]
-            input_model_path = pytorch_test_dri / "cpp" / "jit" / model_name
+            input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name

            # A temporary model file will be exported to this path, and run through bytecode.pkl
            # content check.
@@ -205,7 +238,7 @@ def forward(self, y: int):
    # Check just the test_backport_bytecode_from_file_to_file mechanism but not the function implementations
    def test_backport_bytecode_from_file_to_file(self):
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
-        script_module_v5_path = pytorch_test_dri / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
+        script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
            maximum_checked_in_model_version]["model_name"]

        if (maximum_checked_in_model_version > MINIMUM_TO_VERSION):
@@ -241,7 +274,7 @@ def test_backport_bytecode_from_file_to_file(self):
    # Check just the _backport_for_mobile_to_buffer mechanism but not the function implementations
    def test_backport_bytecode_from_file_to_buffer(self):
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
-        script_module_v5_path = pytorch_test_dri / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
+        script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
            maximum_checked_in_model_version]["model_name"]

        if (maximum_checked_in_model_version > MINIMUM_TO_VERSION):
@@ -264,5 +297,12 @@ def test_backport_bytecode_from_file_to_buffer(self):
            torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)


+    def test_get_model_ops_and_info(self):
+        # TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists
+        script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl"
+        ops_v6 = _get_model_ops_and_info(script_module_v6)
+        assert(ops_v6["aten::add.int"].num_schema_args == 2)
+        assert(ops_v6["aten::add.Scalar"].num_schema_args == 2)
+
if __name__ == '__main__':
    run_tests()
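
For reference, a minimal sketch (not part of the diff above) of how the _get_model_ops_and_info API exercised by test_get_model_ops_and_info might be used to list a model's root operators. The model path is hypothetical, and only the behavior demonstrated in the test is assumed: a mapping from operator name (e.g. "aten::add.int") to an info object exposing num_schema_args.

from torch.jit.mobile import _get_model_ops_and_info

model_path = "script_module_v6.ptl"  # hypothetical mobile model file
ops = _get_model_ops_and_info(model_path)
for op_name, op_info in ops.items():  # assumes a dict-like return; the test indexes it by op name
    print(op_name, op_info.num_schema_args)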