
Commit 944a346

Author: Anatoly Baksheev
Commit message: moved source to separate header file
Parent: b0ee72c

2 files changed: +117 −31 lines

convert_torch.py (+6 −31)
@@ -242,42 +242,17 @@ def simplify_source(s):
 
 def torch_to_pytorch(t7_filename, outputname=None):
     model = load_lua(t7_filename, unknown_classes=True)
-    if type(model).__name__ == 'hashable_uniq_dict': model = model.model
+    if type(model).__name__ == 'hashable_uniq_dict':
+        model = model.model
     model.gradInput = None
+
     slist = lua_recursive_source(lnn.Sequential().add(model))
     s = simplify_source(slist)
-    header = '''
-import torch
-import torch.nn as nn
-import torch.legacy.nn as lnn
-
-from functools import reduce
-from torch.autograd import Variable
-
-class LambdaBase(nn.Sequential):
-    def __init__(self, fn, *args):
-        super(LambdaBase, self).__init__(*args)
-        self.lambda_func = fn
-
-    def forward_prepare(self, input):
-        output = []
-        for module in self._modules.values():
-            output.append(module(input))
-        return output if output else input
-
-class Lambda(LambdaBase):
-    def forward(self, input):
-        return self.lambda_func(self.forward_prepare(input))
 
-class LambdaMap(LambdaBase):
-    def forward(self, input):
-        return list(map(self.lambda_func,self.forward_prepare(input)))
+    varname = os.path.basename(t7_filename).replace('.t7', '').replace('.', '_').replace('-', '_')
 
-class LambdaReduce(LambdaBase):
-    def forward(self, input):
-        return reduce(self.lambda_func,self.forward_prepare(input))
-'''
-    varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
+    with open("header.py") as f:
+        header = f.read()
     s = '{}\n\n{} = {}'.format(header, varname, s[:-2])
 
     if outputname is None: outputname = varname
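
Two behavioural details in the hunk above are worth noting: the boilerplate header is now read from header.py at conversion time, so header.py must sit next to convert_torch.py, and varname is derived from the basename of the .t7 path rather than from the full path. A minimal sketch of why the basename matters (the input path is hypothetical):

import os

t7_filename = 'models/resnet-18.t7'  # hypothetical input path

# Old derivation: directory components survive, so the resulting variable /
# default output name is not a valid Python identifier.
old = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
# -> 'models/resnet_18'

# New derivation: strip the directory first.
new = os.path.basename(t7_filename).replace('.t7', '').replace('.', '_').replace('-', '_')
# -> 'resnet_18'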

header.py (new file, +111)
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+
+from functools import reduce
+from torch.autograd import Variable
+
+
+class LambdaBase(nn.Sequential):
+    def __init__(self, fn, *args):
+        super(LambdaBase, self).__init__(*args)
+        self.lambda_func = fn
+
+    def forward_prepare(self, input):
+        output = []
+        for module in self._modules.values():
+            output.append(module(input))
+        return output if output else input
+
+
+class Lambda(LambdaBase):
+    def forward(self, input):
+        return self.lambda_func(self.forward_prepare(input))
+
+
+class LambdaMap(LambdaBase):
+    def forward(self, input):
+        # result is Variables list [Variable1, Variable2, ...]
+        return list(map(self.lambda_func, self.forward_prepare(input)))
+
+
+class LambdaReduce(LambdaBase):
+    def forward(self, input):
+        # result is a Variable
+        return reduce(self.lambda_func, self.forward_prepare(input))
+
+
+class Padding(nn.Module):
+    # pad puts in [pad] amount of [value] over dimension [dim], starting at
+    # index [index] in that dimension. If pad<0, index counts from the left.
+    # If pad>0 index counts from the right.
+    # When nInputDim is provided, inputs larger than that value will be considered batches
+    # where the actual dim to be padded will be dimension dim + 1.
+    def __init__(self, dim, pad, value, index, nInputDim):
+        super(Padding, self).__init__()
+        self.value = value
+        # self.index = index
+        self.dim = dim
+        self.pad = pad
+        self.nInputDim = nInputDim
+        if index != 0:
+            raise NotImplementedError("Padding: index != 0 not implemented")
+
+    def forward(self, input):
+        dim = self.dim
+        if self.nInputDim != 0:
+            dim += input.dim() - self.nInputDim
+        pad_size = list(input.size())
+        pad_size[dim] = self.pad
+        padder = Variable(input.data.new(*pad_size).fill_(self.value))
+
+        if self.pad < 0:
+            padded = torch.cat((padder, input), dim)
+        else:
+            padded = torch.cat((input, padder), dim)
+        return padded
+
+
+class Dropout(nn.Dropout):
+    """
+    Cancel out PyTorch rescaling by 1/(1-p)
+    """
+    def forward(self, input):
+        input = input * (1 - self.p)
+        return super(Dropout, self).forward(input)
+
+
+class Dropout2d(nn.Dropout2d):
+    """
+    Cancel out PyTorch rescaling by 1/(1-p)
+    """
+    def forward(self, input):
+        input = input * (1 - self.p)
+        return super(Dropout2d, self).forward(input)
+
+
+class StatefulMaxPool2d(nn.MaxPool2d):  # object keeps indices and input sizes
+
+    def __init__(self, *args, **kwargs):
+        super(StatefulMaxPool2d, self).__init__(*args, **kwargs)
+        self.indices = None
+        self.input_size = None
+
+    def forward(self, x):
+        return_indices, self.return_indices = self.return_indices, True
+        output, indices = super(StatefulMaxPool2d, self).forward(x)
+        self.return_indices = return_indices
+        self.indices = indices
+        self.input_size = x.size()
+        if return_indices:
+            return output, indices
+        return output
+
+
+class StatefulMaxUnpool2d(nn.Module):
+    def __init__(self, pooling):
+        super(StatefulMaxUnpool2d, self).__init__()
+        self.pooling = pooling
+        self.unpooling = nn.MaxUnpool2d(pooling.kernel_size, pooling.stride, pooling.padding)
+
+    def forward(self, x):
+        return self.unpooling.forward(x, self.pooling.indices, self.pooling.input_size)
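
For context, a minimal usage sketch of the Lambda helpers defined above: Lambda applies a function to its input, LambdaMap runs every child module on the same input and maps a function over the results, and LambdaReduce folds those results into a single tensor (the Torch ConcatTable/CAddTable pattern). Importing header.py directly and the layer shapes below are illustrative assumptions, not part of this commit:

import torch
import torch.nn as nn
from torch.autograd import Variable
from header import Lambda, LambdaMap, LambdaReduce  # assumes header.py is importable

# Function-only wrapper, e.g. a flatten layer
flatten = Lambda(lambda x: x.view(x.size(0), -1))

# Two branches applied to the same input, then an elementwise sum of the results
branch = nn.Sequential(
    LambdaMap(lambda x: x,  # identity map over the branch outputs
              nn.Linear(8, 4),
              nn.Linear(8, 4)),
    LambdaReduce(lambda a, b: a + b),
)

x = Variable(torch.randn(2, 8))
print(flatten(Variable(torch.randn(2, 3, 4))).size())  # torch.Size([2, 12])
print(branch(x).size())                                # torch.Size([2, 4])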

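Likewise, a small sketch of how StatefulMaxPool2d and StatefulMaxUnpool2d are meant to pair up: the pool records its switch indices and input size as a side effect, so the unpool can be constructed from the pool instance instead of threading (output, indices) tuples through the network. The shapes and the direct import are again assumptions for illustration:

import torch
from torch.autograd import Variable
from header import StatefulMaxPool2d, StatefulMaxUnpool2d  # assumes header.py is importable

pool = StatefulMaxPool2d(kernel_size=2, stride=2)
unpool = StatefulMaxUnpool2d(pool)  # reuses the pool's recorded state

x = Variable(torch.randn(1, 3, 8, 8))
y = pool(x)      # also stores pool.indices and pool.input_size
z = unpool(y)    # unpools back to the recorded input size
print(y.size())  # torch.Size([1, 3, 4, 4])
print(z.size())  # torch.Size([1, 3, 8, 8])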