21
21
multi_return_metadata ,
22
22
MultiReturn ,
23
23
resnet18 ,
24
+ resnet18_dynamo ,
24
25
Simple ,
25
26
)
26
27
except ImportError :
30
31
multi_return_metadata ,
31
32
MultiReturn ,
32
33
resnet18 ,
34
+ resnet18_dynamo ,
33
35
Simple ,
34
36
)
35
37
@@ -60,34 +62,39 @@ def save(
60
62
name ,
61
63
model ,
62
64
model_jit = None ,
65
+ model_dynamo = None ,
63
66
eg = None ,
64
67
featurestore_meta = None ,
65
68
text_in_extra_file = None ,
66
69
binary_in_extra_file = None ,
67
70
):
68
- with PackageExporter (str (p / name )) as e :
69
- e .mock ("iopath.**" )
70
- e .intern ("**" )
71
- e .save_pickle ("model" , "model.pkl" , model )
72
- if eg :
73
- e .save_pickle ("model" , "example.pkl" , eg )
74
- if featurestore_meta :
75
- # TODO(whc) can this name come from buck somehow,
76
- # so it's consistent with predictor_config_constants::METADATA_FILE_NAME()?
77
- e .save_text ("extra_files" , "metadata.json" , featurestore_meta )
78
- if text_in_extra_file :
79
- e .save_text ("extra_files" , "text" , text_in_extra_file )
80
- if binary_in_extra_file :
81
- e .save_binary ("extra_files" , "binary" , binary_in_extra_file )
82
-
71
+ def package_model (name , model ):
72
+ with PackageExporter (str (p / name )) as e :
73
+ e .mock ("iopath.**" )
74
+ e .intern ("**" )
75
+ e .save_pickle ("model" , "model.pkl" , model )
76
+ if eg :
77
+ e .save_pickle ("model" , "example.pkl" , eg )
78
+ if featurestore_meta :
79
+ # TODO(whc) can this name come from buck somehow,
80
+ # so it's consistent with predictor_config_constants::METADATA_FILE_NAME()?
81
+ e .save_text ("extra_files" , "metadata.json" , featurestore_meta )
82
+ if text_in_extra_file :
83
+ e .save_text ("extra_files" , "text" , text_in_extra_file )
84
+ if binary_in_extra_file :
85
+ e .save_binary ("extra_files" , "binary" , binary_in_extra_file )
86
+
87
+ package_model (name , model )
88
+ if model_dynamo :
89
+ package_model (name + "_dynamo" , model_dynamo )
83
90
if model_jit :
84
91
model_jit .save (str (p / (name + "_jit" )))
92
+
85
93
86
94
87
95
parser = argparse .ArgumentParser (description = "Generate Examples" )
88
96
parser .add_argument ("--install_dir" , help = "Root directory for all output files" )
89
97
90
-
91
98
if __name__ == "__main__" :
92
99
args = parser .parse_args ()
93
100
if args .install_dir is None :
@@ -98,9 +105,10 @@ def save(
98
105
99
106
resnet = resnet18 ()
100
107
resnet .eval ()
108
+ resnet_dynamo = resnet18_dynamo ()
101
109
resnet_eg = torch .rand (1 , 3 , 224 , 224 )
102
110
resnet_traced = torch .jit .trace (resnet , resnet_eg )
103
- save ("resnet" , resnet , resnet_traced , (resnet_eg ,))
111
+ save ("resnet" , resnet , resnet_traced , resnet_dynamo , (resnet_eg ,))
104
112
105
113
simple = Simple (10 , 20 )
106
114
save (
@@ -117,6 +125,7 @@ def save(
117
125
"multi_return" ,
118
126
multi_return ,
119
127
torch .jit .script (multi_return ),
128
+ None ,
120
129
(torch .rand (10 , 20 ),),
121
130
multi_return_metadata ,
122
131
)
@@ -149,4 +158,4 @@ def save(
149
158
e .add_dependency ("tensorrt" )
150
159
e .mock ("iopath.**" )
151
160
e .intern ("**" )
152
- e .save_pickle ("make_trt_module" , "model.pkl" , make_trt_module )
161
+ e .save_pickle ("make_trt_module" , "model.pkl" , make_trt_module )
0 commit comments