
Commit 3508c69

Merge pull request #29 from ayasyrev/issue28_seblock
SEBlocks refactor.
2 parents: 9294b0d + 5b34550

4 files changed, +76 -16 lines changed

model_constructor/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from model_constructor.model_constructor import ModelConstructor # noqa F401
+from model_constructor.model_constructor import ModelConstructor, ResBlock # noqa F401


 __version__ = "0.1.8"
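
With this change ResBlock is re-exported at the package root next to ModelConstructor, so both can be imported in one line. A minimal sketch, assuming the package is installed as model_constructor (the constructor arguments below are illustrative, not from the commit):

    from model_constructor import ModelConstructor, ResBlock

    mc = ModelConstructor()       # default constructor still works as before
    block = ResBlock(1, 64, 64)   # ResBlock is now reachable without the full module path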

model_constructor/layers.py

Lines changed: 59 additions & 2 deletions
@@ -102,7 +102,7 @@ def forward(self, x):
         return o.view(*size).contiguous()


-class SEBlock(nn.Module):
+class SEBlock(nn.Module): # todo: deprecation worning.
     "se block"
     se_layer = nn.Linear
     act_fn = nn.ReLU(inplace=True)
@@ -126,7 +126,7 @@ def forward(self, x):
         return x * y.expand_as(x)


-class SEBlockConv(nn.Module):
+class SEBlockConv(nn.Module): # todo: deprecation worning.
     "se block with conv on excitation"
     se_layer = nn.Conv2d
     act_fn = nn.ReLU(inplace=True)
@@ -149,3 +149,60 @@ def forward(self, x):
         y = self.squeeze(x)
         y = self.excitation(y)
         return x * y.expand_as(x)
+
+
+class SEModule(nn.Module):
+    "se block"
+
+    def __init__(self,
+                 channels,
+                 reduction=16,
+                 se_layer=nn.Linear,
+                 act_fn=nn.ReLU(inplace=True),  # ? obj or class?
+                 use_bias=True,
+                 gate=nn.Sigmoid
+                 ):
+        super().__init__()
+        rd_channels = channels // reduction
+        self.squeeze = nn.AdaptiveAvgPool2d(1)
+        self.excitation = nn.Sequential(
+            OrderedDict([('fc_reduce', se_layer(channels, rd_channels, bias=use_bias)),
+                         ('se_act', act_fn),
+                         ('fc_expand', se_layer(rd_channels, channels, bias=use_bias)),
+                         ('se_gate', gate())
+                         ]))
+
+    def forward(self, x):
+        bs, c, _, _ = x.shape
+        y = self.squeeze(x).view(bs, c)
+        y = self.excitation(y).view(bs, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+class SEModuleConv(nn.Module):
+    "se block with conv on excitation"
+
+    def __init__(self,
+                 channels,
+                 reduction=16,
+                 se_layer=nn.Conv2d,
+                 act_fn=nn.ReLU(inplace=True),
+                 use_bias=True,
+                 gate=nn.Sigmoid
+                 ):
+        super().__init__()
+        # rd_channels = math.ceil(channels//reduction/8)*8
+        rd_channels = channels // reduction
+        self.squeeze = nn.AdaptiveAvgPool2d(1)
+        self.excitation = nn.Sequential(
+            OrderedDict([
+                ('conv_reduce', se_layer(channels, rd_channels, 1, bias=use_bias)),
+                ('se_act', act_fn),
+                ('conv_expand', se_layer(rd_channels, channels, 1, bias=use_bias)),
+                ('gate', gate())
+            ]))
+
+    def forward(self, x):
+        y = self.squeeze(x)
+        y = self.excitation(y)
+        return x * y.expand_as(x)
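
Both new modules keep the channel count and spatial size of their input, so they can be dropped in after the last conv of a residual block. A minimal usage sketch, assuming only what the diff above defines (the tensor shapes are illustrative):

    import torch
    from model_constructor.layers import SEModule, SEModuleConv

    se = SEModule(64, reduction=16)           # Linear-based squeeze-and-excitation
    se_conv = SEModuleConv(64, reduction=16)  # 1x1-conv variant of the same block
    x = torch.randn(2, 64, 8, 8)
    print(se(x).shape)       # torch.Size([2, 64, 8, 8]) - input shape is preserved
    print(se_conv(x).shape)  # torch.Size([2, 64, 8, 8])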

model_constructor/model_constructor.py

Lines changed: 10 additions & 8 deletions
@@ -3,7 +3,7 @@

 import torch.nn as nn

-from .layers import ConvLayer, Flatten, SEBlock, SimpleSelfAttention, noop
+from .layers import ConvLayer, Flatten, SEModule, SimpleSelfAttention, noop


 __all__ = ['init_cnn', 'act_fn', 'ResBlock', 'ModelConstructor', 'xresnet34', 'xresnet50']
@@ -24,12 +24,13 @@ def init_cnn(module: nn.Module):

 class ResBlock(nn.Module):
     '''Resnet block'''
-    se_block = SEBlock

     def __init__(self, expansion, ni, nh, stride=1,
                  conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
-                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, se=False, se_reduction=16,
-                 groups=1, dw=False, div_groups=None):
+                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False,
+                 groups=1, dw=False, div_groups=None,
+                 se_module=SEModule, se=False, se_reduction=16
+                 ):
         super().__init__()
         nf, ni = nh * expansion, ni * expansion
         if div_groups is not None:  # check if grops != 1 and div_groups
@@ -46,7 +47,7 @@ def __init__(self, expansion, ni, nh, stride=1,
             ("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
         ]
         if se:
-            layers.append(('se', self.se_block(nf, se_reduction)))
+            layers.append(('se', se_module(nf, se_reduction)))
         if sa:
             layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
         self.convs = nn.Sequential(OrderedDict(layers))
@@ -76,7 +77,7 @@ def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
                   conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
                   zero_bn=self.zero_bn, bn_1st=self.bn_1st,
                   groups=self.groups, div_groups=self.div_groups,
-                  dw=self.dw, se=self.se))
+                  dw=self.dw, se_module=self.se_module, se=self.se, se_reduction=self.se_reduction))
                   for i in range(blocks)]
         return nn.Sequential(OrderedDict(layers))

@@ -106,7 +107,8 @@ def __init__(self, name='MC', c_in=3, c_out=1000,
                  act_fn=nn.ReLU(inplace=True),
                  pool=nn.AvgPool2d(2, ceil_mode=True),
                  expansion=1, groups=1, dw=False, div_groups=None,
-                 sa=False, se=False, se_reduction=16,
+                 sa=False,
+                 se=False, se_module=SEModule, se_reduction=16,
                  bn_1st=True,
                  zero_bn=True,
                  stem_stride_on=0,
@@ -150,7 +152,7 @@ def __call__(self):
             ('body', self.body),
             ('head', self.head)]))
         self._init_cnn(model)
-        model.extra_repr = lambda: f"model {self.name}"
+        model.extra_repr = lambda: f"{self.name}"
         return model

     def __repr__(self):
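
With the se_block class attribute replaced by the se_module argument, the SE implementation is now chosen per instance rather than by subclassing. A minimal sketch of how the new arguments could be used, based only on the signatures in this diff (the channel counts and the name are illustrative):

    from model_constructor import ModelConstructor, ResBlock
    from model_constructor.layers import SEModuleConv

    # single block with the default Linear-based SEModule
    block = ResBlock(1, 64, 64, se=True)

    # constructor-wide override: every block gets the conv-based SE variant
    mc = ModelConstructor(name='mc_se', se=True, se_module=SEModuleConv, se_reduction=16)
    model = mc()  # __call__ assembles stem, body and head into an nn.Sequential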

model_constructor/yaresnet.py

Lines changed: 6 additions & 5 deletions
@@ -4,7 +4,7 @@
 import torch.nn as nn
 from functools import partial
 from collections import OrderedDict
-from .layers import SEBlock, ConvLayer, act_fn, noop, SimpleSelfAttention
+from .layers import SEModule, ConvLayer, act_fn, noop, SimpleSelfAttention
 from .net import Net
 from torch.nn import Mish

@@ -14,12 +14,13 @@

 class YaResBlock(nn.Module):
     '''YaResBlock. Reduce by pool instead of stride 2'''
-    se_block = SEBlock

     def __init__(self, expansion, ni, nh, stride=1,
                  conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
-                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, se=False,
-                 groups=1, dw=False, div_groups=None):
+                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False,
+                 groups=1, dw=False, div_groups=None,
+                 se_module=SEModule, se=False, se_reduction=16
+                 ):
         super().__init__()
         nf, ni = nh * expansion, ni * expansion
         if div_groups is not None:  # check if grops != 1 and div_groups
@@ -35,7 +36,7 @@ def __init__(self, expansion, ni, nh, stride=1,
             ("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
         ]
         if se:
-            layers.append(('se', self.se_block(nf)))
+            layers.append(('se', se_module(nf, se_reduction)))
         if sa:
             layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
         self.convs = nn.Sequential(OrderedDict(layers))
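
YaResBlock picks up the same interface, so a custom SE module can be passed to it directly as well. A short sketch under the same assumptions as above (argument values are illustrative):

    from model_constructor.yaresnet import YaResBlock
    from model_constructor.layers import SEModuleConv

    blk = YaResBlock(1, 64, 64, se=True, se_module=SEModuleConv, se_reduction=16)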
