|
14 | 14 | from torch.autograd import Variable
|
15 | 15 | from torch.utils.serialization import load_lua
|
16 | 16 |
|
17 |
| - |
18 |
class LambdaBase(nn.Sequential):
    """Base container used by the Lua→PyTorch conversion script.

    Holds a function (``fn``) that subclasses apply in ``forward`` to the
    outputs produced by this container's child modules.

    Parameters
    ----------
    fn : callable
        Function applied by subclasses (``Lambda``/``LambdaMap``/``LambdaReduce``).
    *args : nn.Module
        Child modules, forwarded unchanged to ``nn.Sequential``.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Run every child module on *input* and collect the results.

        Returns the list of child outputs, or — when this container has no
        children — the raw *input* itself, so subclasses can apply
        ``lambda_func`` directly to it.
        """
        output = [module(input) for module in self._modules.values()]
        return output if output else input
28 |
| - |
29 |
| - |
30 |
class Lambda(LambdaBase):
    """Apply the stored function directly to the prepared input."""

    def forward(self, input):
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)
33 |
| - |
34 |
| - |
35 |
class LambdaMap(LambdaBase):
    """Apply the stored function element-wise over the prepared outputs."""

    def forward(self, input):
        # Result is a list of Variables: [Variable1, Variable2, ...]
        return [self.lambda_func(item) for item in self.forward_prepare(input)]
39 |
| - |
40 |
| - |
41 |
class LambdaReduce(LambdaBase):
    """Fold the stored binary function across the prepared outputs."""

    def forward(self, input):
        # `reduce` is not a builtin in Python 3; functools.reduce works on
        # both Python 2 and 3, so import it here to stay compatible.
        from functools import reduce
        # Result is a single Variable.
        return reduce(self.lambda_func, self.forward_prepare(input))
| 17 | +from header import LambdaBase, Lambda, LambdaMap, LambdaReduce |
45 | 18 |
|
46 | 19 |
|
def copy_param(m, n):
    """Copy parameters (and batch-norm running statistics) from a loaded
    Lua/Torch module ``m`` into the corresponding PyTorch module ``n``.

    ``broadcast=False`` requests an exact, shape-for-shape copy under the
    legacy ``Tensor.copy_`` API this conversion script targets.
    NOTE(review): the ``broadcast`` kwarg was removed in later PyTorch
    releases — confirm the torch version before reusing this helper.
    """
    # Some Lua layers carry no weight/bias attribute at all, so guard with
    # getattr instead of assuming the attribute exists (the original guarded
    # bias but not weight — unify both).
    if getattr(m, 'weight', None) is not None:
        n.weight.data.copy_(m.weight, broadcast=False)
    if getattr(m, 'bias', None) is not None:
        n.bias.data.copy_(m.bias, broadcast=False)
    # BatchNorm-style layers additionally track running statistics; presence
    # on the PyTorch side implies the Lua module has them too.
    if hasattr(n, 'running_mean'):
        n.running_mean.copy_(m.running_mean, broadcast=False)
    if hasattr(n, 'running_var'):
        n.running_var.copy_(m.running_var, broadcast=False)
52 | 29 |
|
53 | 30 |
|
54 | 31 | def add_submodule(seq, *args):
|
@@ -165,7 +142,6 @@ def lua_recursive_source(module):
|
165 | 142 | s = []
|
166 | 143 | for m in module.modules:
|
167 | 144 | name = type(m).__name__
|
168 |
| - real = m |
169 | 145 | if name == 'TorchObject':
|
170 | 146 | name = m._typename.replace('cudnn.', '')
|
171 | 147 | m = m._obj
|
|
0 commit comments