Skip to content

Commit 1741ed0

Browse files
author
Anatoly Baksheev
committed
formatting, if name == main
1 parent 58daa35 commit 1741ed0

File tree

1 file changed

+119
-111
lines changed

1 file changed

+119
-111
lines changed

convert_torch.py

+119-111
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from torch.autograd import Variable
1515
from torch.utils.serialization import load_lua
1616

17+
1718
class LambdaBase(nn.Sequential):
1819
def __init__(self, fn, *args):
1920
super(LambdaBase, self).__init__(*args)
@@ -25,134 +26,139 @@ def forward_prepare(self, input):
2526
output.append(module(input))
2627
return output if output else input
2728

29+
2830
class Lambda(LambdaBase):
2931
def forward(self, input):
3032
return self.lambda_func(self.forward_prepare(input))
3133

34+
3235
class LambdaMap(LambdaBase):
3336
def forward(self, input):
3437
# result is Variables list [Variable1, Variable2, ...]
35-
return list(map(self.lambda_func,self.forward_prepare(input)))
38+
return list(map(self.lambda_func, self.forward_prepare(input)))
39+
3640

3741
class LambdaReduce(LambdaBase):
3842
def forward(self, input):
3943
# result is a Variable
40-
return reduce(self.lambda_func,self.forward_prepare(input))
44+
return reduce(self.lambda_func, self.forward_prepare(input))
4145

4246

43-
def copy_param(m,n):
47+
def copy_param(m, n):
4448
if m.weight is not None: n.weight.data.copy_(m.weight)
4549
if m.bias is not None: n.bias.data.copy_(m.bias)
46-
if hasattr(n,'running_mean'): n.running_mean.copy_(m.running_mean)
47-
if hasattr(n,'running_var'): n.running_var.copy_(m.running_var)
50+
if hasattr(n, 'running_mean'): n.running_mean.copy_(m.running_mean)
51+
if hasattr(n, 'running_var'): n.running_var.copy_(m.running_var)
52+
4853

4954
def add_submodule(seq, *args):
5055
for n in args:
51-
seq.add_module(str(len(seq._modules)),n)
56+
seq.add_module(str(len(seq._modules)), n)
57+
5258

53-
def lua_recursive_model(module,seq):
59+
def lua_recursive_model(module, seq):
5460
for m in module.modules:
5561
name = type(m).__name__
5662
real = m
5763
if name == 'TorchObject':
58-
name = m._typename.replace('cudnn.','')
64+
name = m._typename.replace('cudnn.', '')
5965
m = m._obj
6066

6167
if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
62-
if not hasattr(m,'groups') or m.groups is None: m.groups=1
63-
n = nn.Conv2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,bias=(m.bias is not None))
64-
copy_param(m,n)
65-
add_submodule(seq,n)
68+
if not hasattr(m, 'groups') or m.groups is None: m.groups = 1
69+
n = nn.Conv2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, bias=(m.bias is not None))
70+
copy_param(m, n)
71+
add_submodule(seq, n)
6672
elif name == 'SpatialBatchNormalization':
6773
n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
68-
copy_param(m,n)
69-
add_submodule(seq,n)
74+
copy_param(m, n)
75+
add_submodule(seq, n)
7076
elif name == 'VolumetricBatchNormalization':
7177
n = nn.BatchNorm3d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
7278
copy_param(m, n)
7379
add_submodule(seq, n)
7480
elif name == 'ReLU':
7581
n = nn.ReLU()
76-
add_submodule(seq,n)
82+
add_submodule(seq, n)
7783
elif name == 'Sigmoid':
7884
n = nn.Sigmoid()
79-
add_submodule(seq,n)
85+
add_submodule(seq, n)
8086
elif name == 'SpatialMaxPooling':
81-
n = nn.MaxPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode)
82-
add_submodule(seq,n)
87+
n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
88+
add_submodule(seq, n)
8389
elif name == 'SpatialAveragePooling':
84-
n = nn.AvgPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode)
85-
add_submodule(seq,n)
90+
n = nn.AvgPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
91+
add_submodule(seq, n)
8692
elif name == 'SpatialUpSamplingNearest':
8793
n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor)
88-
add_submodule(seq,n)
94+
add_submodule(seq, n)
8995
elif name == 'View':
90-
n = Lambda(lambda x: x.view(x.size(0),-1))
91-
add_submodule(seq,n)
96+
n = Lambda(lambda x: x.view(x.size(0), -1))
97+
add_submodule(seq, n)
9298
elif name == 'Reshape':
93-
n = Lambda(lambda x: x.view(x.size(0),-1))
94-
add_submodule(seq,n)
99+
n = Lambda(lambda x: x.view(x.size(0), -1))
100+
add_submodule(seq, n)
95101
elif name == 'Linear':
96102
# Linear in pytorch only accept 2D input
97-
n1 = Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )
98-
n2 = nn.Linear(m.weight.size(1),m.weight.size(0),bias=(m.bias is not None))
99-
copy_param(m,n2)
100-
n = nn.Sequential(n1,n2)
101-
add_submodule(seq,n)
103+
n1 = Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x)
104+
n2 = nn.Linear(m.weight.size(1), m.weight.size(0), bias=(m.bias is not None))
105+
copy_param(m, n2)
106+
n = nn.Sequential(n1, n2)
107+
add_submodule(seq, n)
102108
elif name == 'Dropout':
103109
m.inplace = False
104110
n = nn.Dropout(m.p)
105-
add_submodule(seq,n)
111+
add_submodule(seq, n)
106112
elif name == 'SoftMax':
107113
n = nn.Softmax()
108-
add_submodule(seq,n)
114+
add_submodule(seq, n)
109115
elif name == 'Identity':
110-
n = Lambda(lambda x: x) # do nothing
111-
add_submodule(seq,n)
116+
n = Lambda(lambda x: x) # do nothing
117+
add_submodule(seq, n)
112118
elif name == 'SpatialFullConvolution':
113-
n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),(m.adjW,m.adjH))
114-
copy_param(m,n)
115-
add_submodule(seq,n)
119+
n = nn.ConvTranspose2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), (m.adjW, m.adjH))
120+
copy_param(m, n)
121+
add_submodule(seq, n)
116122
elif name == 'VolumetricFullConvolution':
117-
n = nn.ConvTranspose3d(m.nInputPlane,m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH),(m.adjT,m.adjW,m.adjH),m.groups)
118-
copy_param(m,n)
123+
n = nn.ConvTranspose3d(m.nInputPlane, m.nOutputPlane, (m.kT, m.kW, m.kH), (m.dT, m.dW, m.dH), (m.padT, m.padW, m.padH), (m.adjT, m.adjW, m.adjH), m.groups)
124+
copy_param(m, n)
119125
add_submodule(seq, n)
120126
elif name == 'SpatialReplicationPadding':
121-
n = nn.ReplicationPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
122-
add_submodule(seq,n)
127+
n = nn.ReplicationPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
128+
add_submodule(seq, n)
123129
elif name == 'SpatialReflectionPadding':
124-
n = nn.ReflectionPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
125-
add_submodule(seq,n)
130+
n = nn.ReflectionPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
131+
add_submodule(seq, n)
126132
elif name == 'Copy':
127-
n = Lambda(lambda x: x) # do nothing
128-
add_submodule(seq,n)
133+
n = Lambda(lambda x: x) # do nothing
134+
add_submodule(seq, n)
129135
elif name == 'Narrow':
130-
n = Lambda(lambda x,a=(m.dimension,m.index,m.length): x.narrow(*a))
131-
add_submodule(seq,n)
136+
n = Lambda(lambda x, a=(m.dimension, m.index, m.length): x.narrow(*a))
137+
add_submodule(seq, n)
132138
elif name == 'SpatialCrossMapLRN':
133-
lrn = lnn.SpatialCrossMapLRN(m.size,m.alpha,m.beta,m.k)
134-
n = Lambda(lambda x,lrn=lrn: Variable(lrn.forward(x.data)))
135-
add_submodule(seq,n)
139+
lrn = lnn.SpatialCrossMapLRN(m.size, m.alpha, m.beta, m.k)
140+
n = Lambda(lambda x, lrn=lrn: Variable(lrn.forward(x.data)))
141+
add_submodule(seq, n)
136142
elif name == 'Sequential':
137143
n = nn.Sequential()
138-
lua_recursive_model(m,n)
139-
add_submodule(seq,n)
140-
elif name == 'ConcatTable': # output is list
144+
lua_recursive_model(m, n)
145+
add_submodule(seq, n)
146+
elif name == 'ConcatTable': # output is list
141147
n = LambdaMap(lambda x: x)
142-
lua_recursive_model(m,n)
143-
add_submodule(seq,n)
144-
elif name == 'CAddTable': # input is list
145-
n = LambdaReduce(lambda x,y: x+y)
146-
add_submodule(seq,n)
148+
lua_recursive_model(m, n)
149+
add_submodule(seq, n)
150+
elif name == 'CAddTable': # input is list
151+
n = LambdaReduce(lambda x, y: x + y)
152+
add_submodule(seq, n)
147153
elif name == 'Concat':
148154
dim = m.dimension
149-
n = LambdaReduce(lambda x,y,dim=dim: torch.cat((x,y),dim))
150-
lua_recursive_model(m,n)
151-
add_submodule(seq,n)
155+
n = LambdaReduce(lambda x, y, dim=dim: torch.cat((x, y), dim))
156+
lua_recursive_model(m, n)
157+
add_submodule(seq, n)
152158
elif name == 'TorchObject':
153-
print('Not Implement',name,real._typename)
159+
print('Not Implement', name, real._typename)
154160
else:
155-
print('Not Implement',name)
161+
print('Not Implement', name)
156162

157163

158164
def lua_recursive_source(module):
@@ -161,13 +167,13 @@ def lua_recursive_source(module):
161167
name = type(m).__name__
162168
real = m
163169
if name == 'TorchObject':
164-
name = m._typename.replace('cudnn.','')
170+
name = m._typename.replace('cudnn.', '')
165171
m = m._obj
166172

167173
if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
168-
if not hasattr(m,'groups') or m.groups is None: m.groups=1
174+
if not hasattr(m, 'groups') or m.groups is None: m.groups = 1
169175
s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane,
170-
m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,m.bias is not None)]
176+
m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, m.bias is not None)]
171177
elif name == 'SpatialBatchNormalization':
172178
s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
173179
elif name == 'VolumetricBatchNormalization':
@@ -177,9 +183,9 @@ def lua_recursive_source(module):
177183
elif name == 'Sigmoid':
178184
s += ['nn.Sigmoid()']
179185
elif name == 'SpatialMaxPooling':
180-
s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)]
186+
s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
181187
elif name == 'SpatialAveragePooling':
182-
s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)]
188+
s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
183189
elif name == 'SpatialUpSamplingNearest':
184190
s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)]
185191
elif name == 'View':
@@ -188,8 +194,8 @@ def lua_recursive_source(module):
188194
s += ['Lambda(lambda x: x.view(x.size(0),-1)), # Reshape']
189195
elif name == 'Linear':
190196
s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
191-
s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1),m.weight.size(0),(m.bias is not None))
192-
s += ['nn.Sequential({},{}),#Linear'.format(s1,s2)]
197+
s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1), m.weight.size(0), (m.bias is not None))
198+
s += ['nn.Sequential({},{}),#Linear'.format(s1, s2)]
193199
elif name == 'Dropout':
194200
s += ['nn.Dropout({})'.format(m.p)]
195201
elif name == 'SoftMax':
@@ -198,20 +204,21 @@ def lua_recursive_source(module):
198204
s += ['Lambda(lambda x: x), # Identity']
199205
elif name == 'SpatialFullConvolution':
200206
s += ['nn.ConvTranspose2d({},{},{},{},{},{})'.format(m.nInputPlane,
201-
m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),(m.adjW,m.adjH))]
207+
m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), (m.adjW, m.adjH))]
202208
elif name == 'VolumetricFullConvolution':
203209
s += ['nn.ConvTranspose3d({},{},{},{},{},{},{})'.format(m.nInputPlane,
204-
m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH),(m.adjT,m.adjW,m.adjH),m.groups)]
210+
m.nOutputPlane, (m.kT, m.kW, m.kH), (m.dT, m.dW, m.dH), (m.padT, m.padW, m.padH), (m.adjT, m.adjW, m.adjH),
211+
m.groups)]
205212
elif name == 'SpatialReplicationPadding':
206-
s += ['nn.ReplicationPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
213+
s += ['nn.ReplicationPad2d({})'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
207214
elif name == 'SpatialReflectionPadding':
208-
s += ['nn.ReflectionPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
215+
s += ['nn.ReflectionPad2d({})'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
209216
elif name == 'Copy':
210217
s += ['Lambda(lambda x: x), # Copy']
211218
elif name == 'Narrow':
212-
s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension,m.index,m.length))]
219+
s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension, m.index, m.length))]
213220
elif name == 'SpatialCrossMapLRN':
214-
lrn = 'lnn.SpatialCrossMapLRN(*{})'.format((m.size,m.alpha,m.beta,m.k))
221+
lrn = 'lnn.SpatialCrossMapLRN(*{})'.format((m.size, m.alpha, m.beta, m.k))
215222
s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)]
216223

217224
elif name == 'Sequential':
@@ -231,33 +238,35 @@ def lua_recursive_source(module):
231238
s += [')']
232239
else:
233240
s += '# ' + name + ' Not Implement,\n'
234-
s = map(lambda x: '\t{}'.format(x),s)
241+
s = map(lambda x: '\t{}'.format(x), s)
235242
return s
236243

244+
237245
def simplify_source(s):
238-
s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d',')'),s)
239-
s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d',')'),s)
240-
s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d',')'),s)
241-
s = map(lambda x: x.replace(',bias=True),#Conv2d',')'),s)
242-
s = map(lambda x: x.replace('),#Conv2d',')'),s)
243-
s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d',')'),s)
244-
s = map(lambda x: x.replace('),#BatchNorm2d',')'),s)
245-
s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s)
246-
s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d',')'),s)
247-
s = map(lambda x: x.replace('),#MaxPool2d',')'),s)
248-
s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d',')'),s)
249-
s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d',')'),s)
250-
s = map(lambda x: x.replace(',bias=True)),#Linear',')), # Linear'),s)
251-
s = map(lambda x: x.replace(')),#Linear',')), # Linear'),s)
252-
253-
s = map(lambda x: '{},\n'.format(x),s)
254-
s = map(lambda x: x[1:],s)
255-
s = reduce(lambda x,y: x+y, s)
246+
s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'), s)
247+
s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d', ')'), s)
248+
s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d', ')'), s)
249+
s = map(lambda x: x.replace(',bias=True),#Conv2d', ')'), s)
250+
s = map(lambda x: x.replace('),#Conv2d', ')'), s)
251+
s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d', ')'), s)
252+
s = map(lambda x: x.replace('),#BatchNorm2d', ')'), s)
253+
s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d', ')'), s)
254+
s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d', ')'), s)
255+
s = map(lambda x: x.replace('),#MaxPool2d', ')'), s)
256+
s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d', ')'), s)
257+
s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d', ')'), s)
258+
s = map(lambda x: x.replace(',bias=True)),#Linear', ')), # Linear'), s)
259+
s = map(lambda x: x.replace(')),#Linear', ')), # Linear'), s)
260+
261+
s = map(lambda x: '{},\n'.format(x), s)
262+
s = map(lambda x: x[1:], s)
263+
s = reduce(lambda x, y: x + y, s)
256264
return s
257265

258-
def torch_to_pytorch(t7_filename,outputname=None):
259-
model = load_lua(t7_filename,unknown_classes=True)
260-
if type(model).__name__=='hashable_uniq_dict': model=model.model
266+
267+
def torch_to_pytorch(t7_filename, outputname=None):
268+
model = load_lua(t7_filename, unknown_classes=True)
269+
if type(model).__name__ == 'hashable_uniq_dict': model = model.model
261270
model.gradInput = None
262271
slist = lua_recursive_source(lnn.Sequential().add(model))
263272
s = simplify_source(slist)
@@ -292,23 +301,22 @@ class LambdaReduce(LambdaBase):
292301
def forward(self, input):
293302
return reduce(self.lambda_func,self.forward_prepare(input))
294303
'''
295-
varname = t7_filename.replace('.t7','').replace('.','_').replace('-','_')
296-
s = '{}\n\n{} = {}'.format(header,varname,s[:-2])
304+
varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
305+
s = '{}\n\n{} = {}'.format(header, varname, s[:-2])
297306

298-
if outputname is None: outputname=varname
299-
with open(outputname+'.py', "w") as pyfile:
307+
if outputname is None: outputname = varname
308+
with open(outputname + '.py', "w") as pyfile:
300309
pyfile.write(s)
301310

302311
n = nn.Sequential()
303-
lua_recursive_model(model,n)
304-
torch.save(n.state_dict(),outputname+'.pth')
312+
lua_recursive_model(model, n)
313+
torch.save(n.state_dict(), outputname + '.pth')
305314

306315

307-
parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch')
308-
parser.add_argument('--model','-m', type=str, required=True,
309-
help='torch model file in t7 format')
310-
parser.add_argument('--output', '-o', type=str, default=None,
311-
help='output file name prefix, xxx.py xxx.pth')
312-
args = parser.parse_args()
316+
if __name__ == '__main__':
317+
parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch')
318+
parser.add_argument('--model', '-m', type=str, required=True, help='torch model file in t7 format')
319+
parser.add_argument('--output', '-o', type=str, default=None, help='output file name prefix, xxx.py xxx.pth')
320+
args = parser.parse_args()
313321

314-
torch_to_pytorch(args.model,args.output)
322+
torch_to_pytorch(args.model, args.output)

0 commit comments

Comments
 (0)