Skip to content

Commit b0ee72c

Browse files
author
Anatoly Baksheev
committed
formatting
1 parent 1741ed0 commit b0ee72c

File tree

2 files changed

+9
-33
lines changed

2 files changed

+9
-33
lines changed

convert_torch.py

+9-33
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,18 @@
1414
from torch.autograd import Variable
1515
from torch.utils.serialization import load_lua
1616

17-
18-
class LambdaBase(nn.Sequential):
19-
def __init__(self, fn, *args):
20-
super(LambdaBase, self).__init__(*args)
21-
self.lambda_func = fn
22-
23-
def forward_prepare(self, input):
24-
output = []
25-
for module in self._modules.values():
26-
output.append(module(input))
27-
return output if output else input
28-
29-
30-
class Lambda(LambdaBase):
31-
def forward(self, input):
32-
return self.lambda_func(self.forward_prepare(input))
33-
34-
35-
class LambdaMap(LambdaBase):
36-
def forward(self, input):
37-
# result is Variables list [Variable1, Variable2, ...]
38-
return list(map(self.lambda_func, self.forward_prepare(input)))
39-
40-
41-
class LambdaReduce(LambdaBase):
42-
def forward(self, input):
43-
# result is a Variable
44-
return reduce(self.lambda_func, self.forward_prepare(input))
17+
from header import LambdaBase, Lambda, LambdaMap, LambdaReduce
4518

4619

4720
def copy_param(m, n):
48-
if m.weight is not None: n.weight.data.copy_(m.weight)
49-
if m.bias is not None: n.bias.data.copy_(m.bias)
50-
if hasattr(n, 'running_mean'): n.running_mean.copy_(m.running_mean)
51-
if hasattr(n, 'running_var'): n.running_var.copy_(m.running_var)
21+
if m.weight is not None:
22+
n.weight.data.copy_(m.weight, broadcast=False)
23+
if hasattr(m, 'bias') and m.bias is not None:
24+
n.bias.data.copy_(m.bias, broadcast=False)
25+
if hasattr(n, 'running_mean'):
26+
n.running_mean.copy_(m.running_mean, broadcast=False)
27+
if hasattr(n, 'running_var'):
28+
n.running_var.copy_(m.running_var, broadcast=False)
5229

5330

5431
def add_submodule(seq, *args):
@@ -165,7 +142,6 @@ def lua_recursive_source(module):
165142
s = []
166143
for m in module.modules:
167144
name = type(m).__name__
168-
real = m
169145
if name == 'TorchObject':
170146
name = m._typename.replace('cudnn.', '')
171147
m = m._obj

header.py

Whitespace-only changes.

0 commit comments

Comments
 (0)