Skip to content

Commit 38a9d76

Browse files
authored
Merge pull request #38 from ayasyrev/fix_constructor_body
Fix constructor body
2 parents 3759a02 + 7440c9d commit 38a9d76

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

model_constructor/model_constructor.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ def __init__(self, expansion, ni, nh, stride=1,
3636
if div_groups is not None: # check if groups != 1 and div_groups
3737
groups = int(nh / div_groups)
3838
if expansion == 1:
39-
layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
40-
groups=nh if dw else groups)),
41-
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
39+
layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride,
40+
act_fn=act_fn, bn_1st=bn_1st, groups=ni if dw else groups)),
41+
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn,
42+
act=False, bn_1st=bn_1st, groups=nh if dw else groups))
4243
]
4344
else:
4445
layers = [("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
@@ -65,7 +66,8 @@ def _make_stem(self):
6566
bn_layer=(not self.stem_bn_end) if i == (len(self.stem_sizes) - 2) else True,
6667
act_fn=self.act_fn, bn_1st=self.bn_1st))
6768
for i in range(len(self.stem_sizes) - 1)]
68-
stem.append(('stem_pool', self.stem_pool))
69+
if self.stem_pool is not None:
70+
stem.append(('stem_pool', self.stem_pool))
6971
if self.stem_bn_end:
7072
stem.append(('norm', self.norm(self.stem_sizes[-1])))
7173
return nn.Sequential(OrderedDict(stem))
@@ -83,9 +85,10 @@ def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
8385

8486

8587
def _make_body(self):
88+
stride = 2 if self.stem_pool is None else 1 # if there is no pool in the stem, the first block in the body uses stride 2
8689
blocks = [(f"l_{i}", self._make_layer(self, self.expansion,
8790
ni=self.block_sizes[i], nf=self.block_sizes[i + 1],
88-
blocks=l, stride=1 if i == 0 else 2,
91+
blocks=l, stride=stride if i == 0 else 2,
8992
sa=self.sa if i == 0 else False))
9093
for i, l in enumerate(self.layers)]
9194
return nn.Sequential(OrderedDict(blocks))
@@ -113,7 +116,7 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
113116
zero_bn=True,
114117
stem_stride_on=0,
115118
stem_sizes=[32, 32, 64],
116-
stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
119+
stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1), # if stem_pool is None - no pool at stem
117120
stem_bn_end=False,
118121
_init_cnn=init_cnn,
119122
_make_stem=_make_stem,
@@ -127,12 +130,14 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
127130
del params['self']
128131
self.__dict__ = params
129132
self._block_sizes = params['block_sizes']
133+
if type(self.stem_pool) is str: # Hydra passes a string value
134+
self.stem_pool = None
130135
if self.stem_sizes[0] != self.c_in:
131136
self.stem_sizes = [self.c_in] + self.stem_sizes
132137

133138
@property
134139
def block_sizes(self):
135-
return [self.stem_sizes[-1] // self.expansion] + self._block_sizes + [256] * (len(self.layers) - 4)
140+
return [self.stem_sizes[-1] // self.expansion] + self._block_sizes
136141

137142
@property
138143
def stem(self):

model_constructor/yaresnet.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ def __init__(self, expansion, ni, nh, stride=1,
2626
if div_groups is not None: # check if groups != 1 and div_groups
2727
groups = int(nh / div_groups)
2828
self.reduce = noop if stride == 1 else pool
29-
layers = [("conv_0", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
30-
groups=nh if dw else groups)),
31-
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
29+
layers = [("conv_0", conv_layer(ni, nh, 3, stride=1,
30+
act_fn=act_fn, bn_1st=bn_1st, groups=ni if dw else groups)),
31+
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn,
32+
act=False, bn_1st=bn_1st, groups=nh if dw else groups))
3233
] if expansion == 1 else [
3334
("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
3435
("conv_1", conv_layer(nh, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,

0 commit comments

Comments
 (0)