@@ -36,9 +36,10 @@ def __init__(self, expansion, ni, nh, stride=1,
36
36
if div_groups is not None : # check if grops != 1 and div_groups
37
37
groups = int (nh / div_groups )
38
38
if expansion == 1 :
39
- layers = [("conv_0" , conv_layer (ni , nh , 3 , stride = stride , act_fn = act_fn , bn_1st = bn_1st ,
40
- groups = nh if dw else groups )),
41
- ("conv_1" , conv_layer (nh , nf , 3 , zero_bn = zero_bn , act = False , bn_1st = bn_1st ))
39
+ layers = [("conv_0" , conv_layer (ni , nh , 3 , stride = stride ,
40
+ act_fn = act_fn , bn_1st = bn_1st , groups = ni if dw else groups )),
41
+ ("conv_1" , conv_layer (nh , nf , 3 , zero_bn = zero_bn ,
42
+ act = False , bn_1st = bn_1st , groups = nh if dw else groups ))
42
43
]
43
44
else :
44
45
layers = [("conv_0" , conv_layer (ni , nh , 1 , act_fn = act_fn , bn_1st = bn_1st )),
@@ -65,7 +66,8 @@ def _make_stem(self):
65
66
bn_layer = (not self .stem_bn_end ) if i == (len (self .stem_sizes ) - 2 ) else True ,
66
67
act_fn = self .act_fn , bn_1st = self .bn_1st ))
67
68
for i in range (len (self .stem_sizes ) - 1 )]
68
- stem .append (('stem_pool' , self .stem_pool ))
69
+ if self .stem_pool is not None :
70
+ stem .append (('stem_pool' , self .stem_pool ))
69
71
if self .stem_bn_end :
70
72
stem .append (('norm' , self .norm (self .stem_sizes [- 1 ])))
71
73
return nn .Sequential (OrderedDict (stem ))
@@ -83,9 +85,10 @@ def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
83
85
84
86
85
87
def _make_body (self ):
88
+ stride = 2 if self .stem_pool is None else 1 # if no pool on stem - stride = 2 for first block in body
86
89
blocks = [(f"l_{ i } " , self ._make_layer (self , self .expansion ,
87
90
ni = self .block_sizes [i ], nf = self .block_sizes [i + 1 ],
88
- blocks = l , stride = 1 if i == 0 else 2 ,
91
+ blocks = l , stride = stride if i == 0 else 2 ,
89
92
sa = self .sa if i == 0 else False ))
90
93
for i , l in enumerate (self .layers )]
91
94
return nn .Sequential (OrderedDict (blocks ))
@@ -113,7 +116,7 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
113
116
zero_bn = True ,
114
117
stem_stride_on = 0 ,
115
118
stem_sizes = [32 , 32 , 64 ],
116
- stem_pool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ),
119
+ stem_pool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ), # if stem_pool is None - no pool at stem
117
120
stem_bn_end = False ,
118
121
_init_cnn = init_cnn ,
119
122
_make_stem = _make_stem ,
@@ -127,12 +130,14 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
127
130
del params ['self' ]
128
131
self .__dict__ = params
129
132
self ._block_sizes = params ['block_sizes' ]
133
+ if type (self .stem_pool ) is str : # Hydra pass string value
134
+ self .stem_pool = None
130
135
if self .stem_sizes [0 ] != self .c_in :
131
136
self .stem_sizes = [self .c_in ] + self .stem_sizes
132
137
133
138
@property
134
139
def block_sizes (self ):
135
- return [self .stem_sizes [- 1 ] // self .expansion ] + self ._block_sizes + [ 256 ] * ( len ( self . layers ) - 4 )
140
+ return [self .stem_sizes [- 1 ] // self .expansion ] + self ._block_sizes
136
141
137
142
@property
138
143
def stem (self ):
0 commit comments