Skip to content

Commit 8f9f68b

Browse files
committed
groups args at blocks at expansion == 1
1 parent 3759a02 commit 8f9f68b

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

model_constructor/model_constructor.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ def __init__(self, expansion, ni, nh, stride=1,
3636
if div_groups is not None: # check if grops != 1 and div_groups
3737
groups = int(nh / div_groups)
3838
if expansion == 1:
39-
layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
40-
groups=nh if dw else groups)),
41-
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
39+
layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride,
40+
act_fn=act_fn, bn_1st=bn_1st, groups=ni if dw else groups)),
41+
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn,
42+
act=False, bn_1st=bn_1st, groups=nh if dw else groups))
4243
]
4344
else:
4445
layers = [("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
@@ -132,7 +133,7 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
132133

133134
@property
134135
def block_sizes(self):
135-
return [self.stem_sizes[-1] // self.expansion] + self._block_sizes + [256] * (len(self.layers) - 4)
136+
return [self.stem_sizes[-1] // self.expansion] + self._block_sizes
136137

137138
@property
138139
def stem(self):

model_constructor/yaresnet.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ def __init__(self, expansion, ni, nh, stride=1,
2626
if div_groups is not None: # check if grops != 1 and div_groups
2727
groups = int(nh / div_groups)
2828
self.reduce = noop if stride == 1 else pool
29-
layers = [("conv_0", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
30-
groups=nh if dw else groups)),
31-
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
29+
layers = [("conv_0", conv_layer(ni, nh, 3, stride=1,
30+
act_fn=act_fn, bn_1st=bn_1st, groups=ni if dw else groups)),
31+
("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn,
32+
act=False, bn_1st=bn_1st, groups=nh if dw else groups))
3233
] if expansion == 1 else [
3334
("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
3435
("conv_1", conv_layer(nh, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,

0 commit comments

Comments
 (0)