21
21
multi_return_metadata ,
22
22
MultiReturn ,
23
23
resnet18 ,
24
+ resnet18_dynamo ,
24
25
Simple ,
25
26
)
26
27
except ImportError :
30
31
multi_return_metadata ,
31
32
MultiReturn ,
32
33
resnet18 ,
34
+ resnet18_dynamo ,
33
35
Simple ,
34
36
)
35
37
@@ -60,34 +62,39 @@ def save(
60
62
name ,
61
63
model ,
62
64
model_jit = None ,
65
+ model_dynamo = None ,
63
66
eg = None ,
64
67
featurestore_meta = None ,
65
68
text_in_extra_file = None ,
66
69
binary_in_extra_file = None ,
67
70
):
68
- with PackageExporter (str (p / name )) as e :
69
- e .mock ("iopath.**" )
70
- e .intern ("**" )
71
- e .save_pickle ("model" , "model.pkl" , model )
72
- if eg :
73
- e .save_pickle ("model" , "example.pkl" , eg )
74
- if featurestore_meta :
75
- # TODO(whc) can this name come from buck somehow,
76
- # so it's consistent with predictor_config_constants::METADATA_FILE_NAME()?
77
- e .save_text ("extra_files" , "metadata.json" , featurestore_meta )
78
- if text_in_extra_file :
79
- e .save_text ("extra_files" , "text" , text_in_extra_file )
80
- if binary_in_extra_file :
81
- e .save_binary ("extra_files" , "binary" , binary_in_extra_file )
82
-
71
+ def package_model (name , model ):
72
+ with PackageExporter (str (p / name )) as e :
73
+ e .mock ("iopath.**" )
74
+ e .intern ("**" )
75
+ e .save_pickle ("model" , "model.pkl" , model )
76
+ if eg :
77
+ e .save_pickle ("model" , "example.pkl" , eg )
78
+ if featurestore_meta :
79
+ # TODO(whc) can this name come from buck somehow,
80
+ # so it's consistent with predictor_config_constants::METADATA_FILE_NAME()?
81
+ e .save_text ("extra_files" , "metadata.json" , featurestore_meta )
82
+ if text_in_extra_file :
83
+ e .save_text ("extra_files" , "text" , text_in_extra_file )
84
+ if binary_in_extra_file :
85
+ e .save_binary ("extra_files" , "binary" , binary_in_extra_file )
86
+
87
+ package_model (name , model )
88
+ if model_dynamo :
89
+ package_model (name + "_dynamo" , model_dynamo )
83
90
if model_jit :
84
91
model_jit .save (str (p / (name + "_jit" )))
92
+
85
93
86
94
87
95
parser = argparse .ArgumentParser (description = "Generate Examples" )
88
96
parser .add_argument ("--install_dir" , help = "Root directory for all output files" )
89
97
90
-
91
98
if __name__ == "__main__" :
92
99
args = parser .parse_args ()
93
100
if args .install_dir is None :
@@ -98,9 +105,11 @@ def save(
98
105
99
106
resnet = resnet18 ()
100
107
resnet .eval ()
108
+ resnet_dynamo = resnet18_dynamo ()
101
109
resnet_eg = torch .rand (1 , 3 , 224 , 224 )
102
110
resnet_traced = torch .jit .trace (resnet , resnet_eg )
103
- save ("resnet" , resnet , resnet_traced , (resnet_eg ,))
111
+ save ("resnet" , resnet_dynamo , resnet_traced , resnet_dynamo , (resnet_eg ,))
112
+ # save("resnet", resnet, resnet_traced, resnet_dynamo, (resnet_eg,))
104
113
105
114
simple = Simple (10 , 20 )
106
115
save (
@@ -117,6 +126,7 @@ def save(
117
126
"multi_return" ,
118
127
multi_return ,
119
128
torch .jit .script (multi_return ),
129
+ None ,
120
130
(torch .rand (10 , 20 ),),
121
131
multi_return_metadata ,
122
132
)
@@ -149,4 +159,4 @@ def save(
149
159
e .add_dependency ("tensorrt" )
150
160
e .mock ("iopath.**" )
151
161
e .intern ("**" )
152
- e .save_pickle ("make_trt_module" , "model.pkl" , make_trt_module )
162
+ e .save_pickle ("make_trt_module" , "model.pkl" , make_trt_module )
0 commit comments