14
14
from torch .autograd import Variable
15
15
from torch .utils .serialization import load_lua
16
16
17
+
17
18
class LambdaBase (nn .Sequential ):
18
19
def __init__ (self , fn , * args ):
19
20
super (LambdaBase , self ).__init__ (* args )
@@ -25,134 +26,139 @@ def forward_prepare(self, input):
25
26
output .append (module (input ))
26
27
return output if output else input
27
28
29
+
28
30
class Lambda(LambdaBase):
    """Single-output wrapper: feed the prepared input through lambda_func."""

    def forward(self, input):
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)
31
33
34
+
32
35
class LambdaMap(LambdaBase):
    """Fan-out wrapper: apply lambda_func to every prepared child output."""

    def forward(self, input):
        # Result is a list of Variables [Variable1, Variable2, ...]
        return [self.lambda_func(item) for item in self.forward_prepare(input)]
39
+
36
40
37
41
class LambdaReduce(LambdaBase):
    """Fold wrapper: pairwise-reduce the prepared child outputs with lambda_func."""

    def forward(self, input):
        # `reduce` is a builtin only on Python 2; functools.reduce is the
        # same function there and the only spelling on Python 3, so import
        # it explicitly for compatibility with both.
        from functools import reduce
        # result is a Variable
        return reduce(self.lambda_func, self.forward_prepare(input))
41
45
42
46
43
def copy_param(m, n):
    """Copy weights/bias (and batch-norm running stats, when the target has
    them) from a loaded Lua/Torch layer ``m`` into its PyTorch twin ``n``."""
    if m.weight is not None:
        n.weight.data.copy_(m.weight)
    if m.bias is not None:
        n.bias.data.copy_(m.bias)
    # Only batch-norm style targets expose running statistics.
    for stat in ('running_mean', 'running_var'):
        if hasattr(n, stat):
            getattr(n, stat).copy_(getattr(m, stat))
52
+
48
53
49
54
def add_submodule(seq, *mods):
    """Append each module in *mods to ``seq``, naming it by its position."""
    for mod in mods:
        # Current child count doubles as the next free numeric name.
        seq.add_module(str(len(seq._modules)), mod)
57
+
52
58
53
def lua_recursive_model(module, seq):
    """Recursively translate a loaded Lua/Torch container into PyTorch modules.

    Walks ``module.modules`` (the child list of a container loaded with
    ``load_lua``) and appends an equivalent ``torch.nn`` module to ``seq``
    for every child, copying weights via ``copy_param`` for parameterized
    layers.  Containers (``Sequential``, ``ConcatTable``, ``Concat``) are
    handled by recursing into a fresh sub-container.  Unrecognized layer
    types are reported via ``print`` and skipped.
    """
    for m in module.modules:
        name = type(m).__name__
        real = m  # keep the wrapper so its _typename can be reported below
        if name == 'TorchObject':
            # cudnn-backed layers arrive wrapped; strip the prefix and unwrap
            # to the underlying plain module.
            name = m._typename.replace('cudnn.', '')
            m = m._obj

        if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
            # Older t7 files may lack a groups field; default to ungrouped.
            if not hasattr(m, 'groups') or m.groups is None: m.groups = 1
            # Positional args: in, out, kernel, stride, padding, dilation=1, groups.
            n = nn.Conv2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, bias=(m.bias is not None))
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'SpatialBatchNormalization':
            # running_mean's length gives the feature count.
            n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'VolumetricBatchNormalization':
            n = nn.BatchNorm3d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'ReLU':
            n = nn.ReLU()
            add_submodule(seq, n)
        elif name == 'Sigmoid':
            n = nn.Sigmoid()
            add_submodule(seq, n)
        elif name == 'SpatialMaxPooling':
            n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
            add_submodule(seq, n)
        elif name == 'SpatialAveragePooling':
            n = nn.AvgPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
            add_submodule(seq, n)
        elif name == 'SpatialUpSamplingNearest':
            n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor)
            add_submodule(seq, n)
        elif name == 'View':
            # Flatten all but the batch dimension.
            n = Lambda(lambda x: x.view(x.size(0), -1))
            add_submodule(seq, n)
        elif name == 'Reshape':
            # Treated identically to View: flatten per sample.
            n = Lambda(lambda x: x.view(x.size(0), -1))
            add_submodule(seq, n)
        elif name == 'Linear':
            # Linear in pytorch only accept 2D input
            n1 = Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x)
            n2 = nn.Linear(m.weight.size(1), m.weight.size(0), bias=(m.bias is not None))
            copy_param(m, n2)
            n = nn.Sequential(n1, n2)
            add_submodule(seq, n)
        elif name == 'Dropout':
            # In-place dropout would mutate shared buffers; disable it.
            m.inplace = False
            n = nn.Dropout(m.p)
            add_submodule(seq, n)
        elif name == 'SoftMax':
            n = nn.Softmax()
            add_submodule(seq, n)
        elif name == 'Identity':
            n = Lambda(lambda x: x)  # do nothing
            add_submodule(seq, n)
        elif name == 'SpatialFullConvolution':
            n = nn.ConvTranspose2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), (m.adjW, m.adjH))
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'VolumetricFullConvolution':
            n = nn.ConvTranspose3d(m.nInputPlane, m.nOutputPlane, (m.kT, m.kW, m.kH), (m.dT, m.dW, m.dH), (m.padT, m.padW, m.padH), (m.adjT, m.adjW, m.adjH), m.groups)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'SpatialReplicationPadding':
            n = nn.ReplicationPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
            add_submodule(seq, n)
        elif name == 'SpatialReflectionPadding':
            n = nn.ReflectionPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
            add_submodule(seq, n)
        elif name == 'Copy':
            n = Lambda(lambda x: x)  # do nothing
            add_submodule(seq, n)
        elif name == 'Narrow':
            # Bind the narrow arguments now via a default value so the lambda
            # does not late-bind to the loop variable `m`.
            n = Lambda(lambda x, a=(m.dimension, m.index, m.length): x.narrow(*a))
            add_submodule(seq, n)
        elif name == 'SpatialCrossMapLRN':
            # Wrap the legacy-nn LRN; it operates on raw tensors, hence the
            # .data / Variable round trip.
            lrn = lnn.SpatialCrossMapLRN(m.size, m.alpha, m.beta, m.k)
            n = Lambda(lambda x, lrn=lrn: Variable(lrn.forward(x.data)))
            add_submodule(seq, n)
        elif name == 'Sequential':
            n = nn.Sequential()
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'ConcatTable':  # output is list
            n = LambdaMap(lambda x: x)
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'CAddTable':  # input is list
            n = LambdaReduce(lambda x, y: x + y)
            add_submodule(seq, n)
        elif name == 'Concat':
            dim = m.dimension
            n = LambdaReduce(lambda x, y, dim=dim: torch.cat((x, y), dim))
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'TorchObject':
            # Unwrapped TorchObject whose typename we did not recognize.
            print('Not Implement', name, real._typename)
        else:
            print('Not Implement', name)
156
162
157
163
158
164
def lua_recursive_source (module ):
@@ -161,13 +167,13 @@ def lua_recursive_source(module):
161
167
name = type (m ).__name__
162
168
real = m
163
169
if name == 'TorchObject' :
164
- name = m ._typename .replace ('cudnn.' ,'' )
170
+ name = m ._typename .replace ('cudnn.' , '' )
165
171
m = m ._obj
166
172
167
173
if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM' :
168
- if not hasattr (m ,'groups' ) or m .groups is None : m .groups = 1
174
+ if not hasattr (m , 'groups' ) or m .groups is None : m .groups = 1
169
175
s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d' .format (m .nInputPlane ,
170
- m .nOutputPlane ,(m .kW ,m .kH ),(m .dW ,m .dH ),(m .padW ,m .padH ),1 , m .groups ,m .bias is not None )]
176
+ m .nOutputPlane , (m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), 1 , m .groups , m .bias is not None )]
171
177
elif name == 'SpatialBatchNormalization' :
172
178
s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d' .format (m .running_mean .size (0 ), m .eps , m .momentum , m .affine )]
173
179
elif name == 'VolumetricBatchNormalization' :
@@ -177,9 +183,9 @@ def lua_recursive_source(module):
177
183
elif name == 'Sigmoid' :
178
184
s += ['nn.Sigmoid()' ]
179
185
elif name == 'SpatialMaxPooling' :
180
- s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d' .format ((m .kW ,m .kH ),(m .dW ,m .dH ),(m .padW ,m .padH ),m .ceil_mode )]
186
+ s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
181
187
elif name == 'SpatialAveragePooling' :
182
- s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d' .format ((m .kW ,m .kH ),(m .dW ,m .dH ),(m .padW ,m .padH ),m .ceil_mode )]
188
+ s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
183
189
elif name == 'SpatialUpSamplingNearest' :
184
190
s += ['nn.UpsamplingNearest2d(scale_factor={})' .format (m .scale_factor )]
185
191
elif name == 'View' :
@@ -188,8 +194,8 @@ def lua_recursive_source(module):
188
194
s += ['Lambda(lambda x: x.view(x.size(0),-1)), # Reshape' ]
189
195
elif name == 'Linear' :
190
196
s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
191
- s2 = 'nn.Linear({},{},bias={})' .format (m .weight .size (1 ),m .weight .size (0 ),(m .bias is not None ))
192
- s += ['nn.Sequential({},{}),#Linear' .format (s1 ,s2 )]
197
+ s2 = 'nn.Linear({},{},bias={})' .format (m .weight .size (1 ), m .weight .size (0 ), (m .bias is not None ))
198
+ s += ['nn.Sequential({},{}),#Linear' .format (s1 , s2 )]
193
199
elif name == 'Dropout' :
194
200
s += ['nn.Dropout({})' .format (m .p )]
195
201
elif name == 'SoftMax' :
@@ -198,20 +204,21 @@ def lua_recursive_source(module):
198
204
s += ['Lambda(lambda x: x), # Identity' ]
199
205
elif name == 'SpatialFullConvolution' :
200
206
s += ['nn.ConvTranspose2d({},{},{},{},{},{})' .format (m .nInputPlane ,
201
- m .nOutputPlane ,(m .kW ,m .kH ),(m .dW ,m .dH ),(m .padW ,m .padH ),(m .adjW ,m .adjH ))]
207
+ m .nOutputPlane , (m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), (m .adjW , m .adjH ))]
202
208
elif name == 'VolumetricFullConvolution' :
203
209
s += ['nn.ConvTranspose3d({},{},{},{},{},{},{})' .format (m .nInputPlane ,
204
- m .nOutputPlane ,(m .kT ,m .kW ,m .kH ),(m .dT ,m .dW ,m .dH ),(m .padT ,m .padW ,m .padH ),(m .adjT ,m .adjW ,m .adjH ),m .groups )]
210
+ m .nOutputPlane , (m .kT , m .kW , m .kH ), (m .dT , m .dW , m .dH ), (m .padT , m .padW , m .padH ), (m .adjT , m .adjW , m .adjH ),
211
+ m .groups )]
205
212
elif name == 'SpatialReplicationPadding' :
206
- s += ['nn.ReplicationPad2d({})' .format ((m .pad_l ,m .pad_r ,m .pad_t ,m .pad_b ))]
213
+ s += ['nn.ReplicationPad2d({})' .format ((m .pad_l , m .pad_r , m .pad_t , m .pad_b ))]
207
214
elif name == 'SpatialReflectionPadding' :
208
- s += ['nn.ReflectionPad2d({})' .format ((m .pad_l ,m .pad_r ,m .pad_t ,m .pad_b ))]
215
+ s += ['nn.ReflectionPad2d({})' .format ((m .pad_l , m .pad_r , m .pad_t , m .pad_b ))]
209
216
elif name == 'Copy' :
210
217
s += ['Lambda(lambda x: x), # Copy' ]
211
218
elif name == 'Narrow' :
212
- s += ['Lambda(lambda x,a={}: x.narrow(*a))' .format ((m .dimension ,m .index ,m .length ))]
219
+ s += ['Lambda(lambda x,a={}: x.narrow(*a))' .format ((m .dimension , m .index , m .length ))]
213
220
elif name == 'SpatialCrossMapLRN' :
214
- lrn = 'lnn.SpatialCrossMapLRN(*{})' .format ((m .size ,m .alpha ,m .beta ,m .k ))
221
+ lrn = 'lnn.SpatialCrossMapLRN(*{})' .format ((m .size , m .alpha , m .beta , m .k ))
215
222
s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))' .format (lrn )]
216
223
217
224
elif name == 'Sequential' :
@@ -231,33 +238,35 @@ def lua_recursive_source(module):
231
238
s += [')' ]
232
239
else :
233
240
s += '# ' + name + ' Not Implement,\n '
234
- s = map (lambda x : '\t {}' .format (x ),s )
241
+ s = map (lambda x : '\t {}' .format (x ), s )
235
242
return s
236
243
244
+
237
245
def simplify_source(s):
    """Collapse default-argument noise in generated layer source lines.

    ``s`` is the list of '\\t'-prefixed source lines produced by
    ``lua_recursive_source``.  Each line is stripped of redundant default
    arguments (identified by the ``#Conv2d`` / ``#BatchNorm2d`` / ... marker
    comments appended by the generator), then the leading tab is dropped and
    a trailing ``,\\n`` appended.  Returns the joined result as one string.

    Replaces the original chain of fifteen ``map(lambda ...)`` passes plus a
    quadratic ``reduce``-based concatenation with a single loop and
    ``str.join``; output is byte-identical, and an empty input now yields
    ``''`` instead of raising.
    """
    # (marker-suffix, replacement) pairs; order matters — longer, more
    # specific suffixes must be tried before their shorter tails.
    replacements = (
        (',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'),
        (',(0, 0),1,1,bias=True),#Conv2d', ')'),
        (',1,1,bias=True),#Conv2d', ')'),
        (',bias=True),#Conv2d', ')'),
        ('),#Conv2d', ')'),
        (',1e-05,0.1,True),#BatchNorm2d', ')'),
        ('),#BatchNorm2d', ')'),
        (',(0, 0),ceil_mode=False),#MaxPool2d', ')'),
        (',ceil_mode=False),#MaxPool2d', ')'),
        ('),#MaxPool2d', ')'),
        (',(0, 0),ceil_mode=False),#AvgPool2d', ')'),
        (',ceil_mode=False),#AvgPool2d', ')'),
        (',bias=True)),#Linear', ')), # Linear'),
        (')),#Linear', ')), # Linear'),
    )
    out = []
    for line in s:
        for old, new in replacements:
            line = line.replace(old, new)
        # '{},\n'.format(line)[1:] in the original: drop the leading tab
        # added by lua_recursive_source, append a trailing comma + newline.
        out.append('{},\n'.format(line)[1:])
    return ''.join(out)
257
265
258
- def torch_to_pytorch (t7_filename ,outputname = None ):
259
- model = load_lua (t7_filename ,unknown_classes = True )
260
- if type (model ).__name__ == 'hashable_uniq_dict' : model = model .model
266
+
267
+ def torch_to_pytorch (t7_filename , outputname = None ):
268
+ model = load_lua (t7_filename , unknown_classes = True )
269
+ if type (model ).__name__ == 'hashable_uniq_dict' : model = model .model
261
270
model .gradInput = None
262
271
slist = lua_recursive_source (lnn .Sequential ().add (model ))
263
272
s = simplify_source (slist )
@@ -292,23 +301,22 @@ class LambdaReduce(LambdaBase):
292
301
def forward(self, input):
293
302
return reduce(self.lambda_func,self.forward_prepare(input))
294
303
'''
295
- varname = t7_filename .replace ('.t7' ,'' ).replace ('.' ,'_' ).replace ('-' ,'_' )
296
- s = '{}\n \n {} = {}' .format (header ,varname ,s [:- 2 ])
304
+ varname = t7_filename .replace ('.t7' , '' ).replace ('.' , '_' ).replace ('-' , '_' )
305
+ s = '{}\n \n {} = {}' .format (header , varname , s [:- 2 ])
297
306
298
- if outputname is None : outputname = varname
299
- with open (outputname + '.py' , "w" ) as pyfile :
307
+ if outputname is None : outputname = varname
308
+ with open (outputname + '.py' , "w" ) as pyfile :
300
309
pyfile .write (s )
301
310
302
311
n = nn .Sequential ()
303
- lua_recursive_model (model ,n )
304
- torch .save (n .state_dict (),outputname + '.pth' )
312
+ lua_recursive_model (model , n )
313
+ torch .save (n .state_dict (), outputname + '.pth' )
305
314
306
315
307
if __name__ == '__main__':
    # Command-line entry point: convert a Lua/Torch .t7 checkpoint into a
    # generated .py model definition plus a .pth state dict
    # (see torch_to_pytorch for the actual conversion).
    parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch')
    parser.add_argument('--model', '-m', type=str, required=True, help='torch model file in t7 format')
    parser.add_argument('--output', '-o', type=str, default=None, help='output file name prefix, xxx.py xxx.pth')
    args = parser.parse_args()

    torch_to_pytorch(args.model, args.output)
0 commit comments