Skip to content

Commit 9a8e558

Browse files
committed
SE block add rd_channal arg
1 parent 7440c9d commit 9a8e558

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

model_constructor/layers.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,20 @@ class SEModule(nn.Module):
157157
def __init__(self,
158158
channels,
159159
reduction=16,
160+
rd_channels=None,
161+
rd_max=False,
160162
se_layer=nn.Linear,
161163
act_fn=nn.ReLU(inplace=True), # ? obj or class?
162164
use_bias=True,
163165
gate=nn.Sigmoid
164166
):
165167
super().__init__()
166-
rd_channels = channels // reduction
168+
reducted = channels // reduction
169+
if rd_channels is None:
170+
rd_channels = reducted
171+
else:
172+
if rd_max:
173+
rd_channels = max(rd_channels, reducted)
167174
self.squeeze = nn.AdaptiveAvgPool2d(1)
168175
self.excitation = nn.Sequential(
169176
OrderedDict([('fc_reduce', se_layer(channels, rd_channels, bias=use_bias)),
@@ -185,14 +192,21 @@ class SEModuleConv(nn.Module):
185192
def __init__(self,
186193
channels,
187194
reduction=16,
195+
rd_channels=None,
196+
rd_max=False,
188197
se_layer=nn.Conv2d,
189198
act_fn=nn.ReLU(inplace=True),
190199
use_bias=True,
191200
gate=nn.Sigmoid
192201
):
193202
super().__init__()
194203
# rd_channels = math.ceil(channels//reduction/8)*8
195-
rd_channels = channels // reduction
204+
reducted = channels // reduction
205+
if rd_channels is None:
206+
rd_channels = reducted
207+
else:
208+
if rd_max:
209+
rd_channels = max(rd_channels, reducted)
196210
self.squeeze = nn.AdaptiveAvgPool2d(1)
197211
self.excitation = nn.Sequential(
198212
OrderedDict([

0 commit comments

Comments
 (0)