Skip to content

Commit 82f5b6d

Browse files
author
Anatoly Baksheev
committed
formatting
1 parent 944a346 commit 82f5b6d

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

convert_torch.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import print_function
22

33
import os
4+
import re
45
import math
56
import torch
67
import argparse
@@ -14,18 +15,18 @@
1415
from torch.autograd import Variable
1516
from torch.utils.serialization import load_lua
1617

17-
from header import LambdaBase, Lambda, LambdaMap, LambdaReduce
18+
from header import LambdaBase, Lambda, LambdaMap, LambdaReduce, StatefulMaxPool2d, StatefulMaxUnpool2d
1819

1920

2021
def copy_param(m, n):
2122
if m.weight is not None:
22-
n.weight.data.copy_(m.weight, broadcast=False)
23+
n.weight.data.copy_(m.weight)
2324
if hasattr(m, 'bias') and m.bias is not None:
24-
n.bias.data.copy_(m.bias, broadcast=False)
25+
n.bias.data.copy_(m.bias)
2526
if hasattr(n, 'running_mean'):
26-
n.running_mean.copy_(m.running_mean, broadcast=False)
27+
n.running_mean.copy_(m.running_mean)
2728
if hasattr(n, 'running_var'):
28-
n.running_var.copy_(m.running_var, broadcast=False)
29+
n.running_var.copy_(m.running_var)
2930

3031

3132
def add_submodule(seq, *args):

header.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def forward(self, input):
8383
return super(Dropout2d, self).forward(input)
8484

8585

86-
class StatefulMaxPool2d(nn.MaxPool2d): # object keeps indices and input sizes
86+
class StatefulMaxPool2d(nn.MaxPool2d): # object keeps indices and input sizes
8787

8888
def __init__(self, *args, **kwargs):
8989
super(StatefulMaxPool2d, self).__init__(*args, **kwargs)

0 commit comments

Comments
 (0)