Skip to content

Commit 7440c9d

Browse files
committed
stem pool fix
1 parent 8f9f68b commit 7440c9d

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

model_constructor/model_constructor.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def _make_stem(self):
6666
bn_layer=(not self.stem_bn_end) if i == (len(self.stem_sizes) - 2) else True,
6767
act_fn=self.act_fn, bn_1st=self.bn_1st))
6868
for i in range(len(self.stem_sizes) - 1)]
69-
stem.append(('stem_pool', self.stem_pool))
69+
if self.stem_pool is not None:
70+
stem.append(('stem_pool', self.stem_pool))
7071
if self.stem_bn_end:
7172
stem.append(('norm', self.norm(self.stem_sizes[-1])))
7273
return nn.Sequential(OrderedDict(stem))
@@ -84,9 +85,10 @@ def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
8485

8586

8687
def _make_body(self):
88+
stride = 2 if self.stem_pool is None else 1 # if no pool on stem - stride = 2 for first block in body
8789
blocks = [(f"l_{i}", self._make_layer(self, self.expansion,
8890
ni=self.block_sizes[i], nf=self.block_sizes[i + 1],
89-
blocks=l, stride=1 if i == 0 else 2,
91+
blocks=l, stride=stride if i == 0 else 2,
9092
sa=self.sa if i == 0 else False))
9193
for i, l in enumerate(self.layers)]
9294
return nn.Sequential(OrderedDict(blocks))
@@ -114,7 +116,7 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
114116
zero_bn=True,
115117
stem_stride_on=0,
116118
stem_sizes=[32, 32, 64],
117-
stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
119+
stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1), # if stem_pool is None - no pool at stem
118120
stem_bn_end=False,
119121
_init_cnn=init_cnn,
120122
_make_stem=_make_stem,
@@ -128,6 +130,8 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
128130
del params['self']
129131
self.__dict__ = params
130132
self._block_sizes = params['block_sizes']
133+
if type(self.stem_pool) is str: # Hydra pass string value
134+
self.stem_pool = None
131135
if self.stem_sizes[0] != self.c_in:
132136
self.stem_sizes = [self.c_in] + self.stem_sizes
133137

0 commit comments

Comments
 (0)