
Commit a3cc9ac (parent: a48886c)

Add timm efficientnet encoder (qubvel-org#189)

* Add efficientnet from timm

8 files changed: +489 -22 lines

README.md (+10)

@@ -120,6 +120,16 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
 |efficientnet-b7 |imagenet |63M |
 |mobilenet_v2 |imagenet |2M |
 |xception |imagenet |22M |
+|timm-efficientnet-b0 |imagenet<br>advprop<br>noisy-student|4M |
+|timm-efficientnet-b1 |imagenet<br>advprop<br>noisy-student|6M |
+|timm-efficientnet-b2 |imagenet<br>advprop<br>noisy-student|7M |
+|timm-efficientnet-b3 |imagenet<br>advprop<br>noisy-student|10M |
+|timm-efficientnet-b4 |imagenet<br>advprop<br>noisy-student|17M |
+|timm-efficientnet-b5 |imagenet<br>advprop<br>noisy-student|28M |
+|timm-efficientnet-b6 |imagenet<br>advprop<br>noisy-student|40M |
+|timm-efficientnet-b7 |imagenet<br>advprop<br>noisy-student|63M |
+|timm-efficientnet-b8 |imagenet<br>advprop |84M |
+|timm-efficientnet-l2 |noisy-student |474M |
 
 ### Models API <a name="api"></a>
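
The new encoders plug into the existing Models API like any other backbone. A minimal sketch (Unet is just one example architecture; per the table above, "advprop" and "noisy-student" join "imagenet" as weight options for most of the timm models):

import segmentation_models_pytorch as smp

# pass encoder_weights=None instead to skip the pretrained-weight download
model = smp.Unet(
    encoder_name="timm-efficientnet-b0",
    encoder_weights="noisy-student",
    classes=2,
)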

requirements.txt (+1)

@@ -1,3 +1,4 @@
 torchvision>=0.3.0
 pretrainedmodels==0.7.4
 efficientnet-pytorch>=0.6.3
+timm==0.1.20

segmentation_models_pytorch/__init__.py (+1, -1)

@@ -2,7 +2,7 @@
 from .linknet import Linknet
 from .fpn import FPN
 from .pspnet import PSPNet
-from .deeplabv3 import DeepLabV3
+from .deeplabv3 import DeepLabV3, DeepLabV3Plus
 from .pan import PAN
 
 from . import encoders

segmentation_models_pytorch/deeplabv3/__init__.py (+1, -1)

@@ -1 +1 @@
-from .model import DeepLabV3
+from .model import DeepLabV3, DeepLabV3Plus

segmentation_models_pytorch/deeplabv3/decoder.py (+131, -18)

@@ -51,23 +51,99 @@ def forward(self, *features):
         return super().forward(features[-1])
 
 
+class DeepLabV3PlusDecoder(nn.Module):
+    def __init__(
+        self,
+        encoder_channels,
+        out_channels=256,
+        atrous_rates=(12, 24, 36),
+        output_stride=16,
+    ):
+        super().__init__()
+        if output_stride not in {8, 16}:
+            raise ValueError("Output stride should be 8 or 16, got {}.".format(output_stride))
+
+        self.out_channels = out_channels
+        self.output_stride = output_stride
+
+        self.aspp = nn.Sequential(
+            ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
+            SeparableConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(),
+        )
+
+        scale_factor = 2 if output_stride == 8 else 4
+        self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
+
+        highres_in_channels = encoder_channels[-4]
+        highres_out_channels = 48  # proposed by authors of paper
+        self.block1 = nn.Sequential(
+            nn.Conv2d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(highres_out_channels),
+            nn.ReLU(),
+        )
+        self.block2 = nn.Sequential(
+            SeparableConv2d(
+                highres_out_channels + out_channels,
+                out_channels,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(),
+        )
+
+    def forward(self, *features):
+        aspp_features = self.aspp(features[-1])
+        aspp_features = self.up(aspp_features)
+        high_res_features = self.block1(features[-4])
+        concat_features = torch.cat([aspp_features, high_res_features], dim=1)
+        fused_features = self.block2(concat_features)
+        return fused_features
+
+
 class ASPPConv(nn.Sequential):
     def __init__(self, in_channels, out_channels, dilation):
-        modules = [
-            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
+        super().__init__(
+            nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=3,
+                padding=dilation,
+                dilation=dilation,
+                bias=False,
+            ),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(),
+        )
+
+
+class ASPPSeparableConv(nn.Sequential):
+    def __init__(self, in_channels, out_channels, dilation):
+        super().__init__(
+            SeparableConv2d(
+                in_channels,
+                out_channels,
+                kernel_size=3,
+                padding=dilation,
+                dilation=dilation,
+                bias=False,
+            ),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU()
-        ]
-        super(ASPPConv, self).__init__(*modules)
+            nn.ReLU(),
+        )
 
 
 class ASPPPooling(nn.Sequential):
     def __init__(self, in_channels, out_channels):
-        super(ASPPPooling, self).__init__(
+        super().__init__(
             nn.AdaptiveAvgPool2d(1),
-            nn.Conv2d(in_channels, out_channels, 1, bias=False),
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU())
+            nn.ReLU(),
+        )
 
     def forward(self, x):
         size = x.shape[-2:]
@@ -77,31 +153,68 @@ def forward(self, x):
 
 
 class ASPP(nn.Module):
-    def __init__(self, in_channels, out_channels, atrous_rates):
+    def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
         super(ASPP, self).__init__()
         modules = []
-        modules.append(nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, 1, bias=False),
-            nn.BatchNorm2d(out_channels),
-            nn.ReLU()))
+        modules.append(
+            nn.Sequential(
+                nn.Conv2d(in_channels, out_channels, 1, bias=False),
+                nn.BatchNorm2d(out_channels),
+                nn.ReLU(),
+            )
+        )
 
         rate1, rate2, rate3 = tuple(atrous_rates)
-        modules.append(ASPPConv(in_channels, out_channels, rate1))
-        modules.append(ASPPConv(in_channels, out_channels, rate2))
-        modules.append(ASPPConv(in_channels, out_channels, rate3))
+        ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv
+
+        modules.append(ASPPConvModule(in_channels, out_channels, rate1))
+        modules.append(ASPPConvModule(in_channels, out_channels, rate2))
+        modules.append(ASPPConvModule(in_channels, out_channels, rate3))
         modules.append(ASPPPooling(in_channels, out_channels))
 
         self.convs = nn.ModuleList(modules)
 
         self.project = nn.Sequential(
-            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
+            nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
             nn.BatchNorm2d(out_channels),
             nn.ReLU(),
-            nn.Dropout(0.5))
+            nn.Dropout(0.5),
+        )
 
     def forward(self, x):
         res = []
         for conv in self.convs:
             res.append(conv(x))
         res = torch.cat(res, dim=1)
         return self.project(res)
+
+
+class SeparableConv2d(nn.Sequential):
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=True,
+    ):
+        depthwise_conv = nn.Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=in_channels,
+            bias=False,
+        )
+        pointwise_conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            bias=bias,
+        )
+        super().__init__(depthwise_conv, pointwise_conv)
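
The new SeparableConv2d is what keeps the DeepLabV3+ decoder light: a k x k convolution is factorized into a per-channel (depthwise) k x k followed by a 1x1 (pointwise) channel mixer. A self-contained sketch of the parameter savings at the decoder's default width of 256 channels (plain torch.nn, independent of this package):

import torch.nn as nn

in_ch, out_ch, k = 256, 256, 3

# a regular 3x3 convolution
regular = nn.Conv2d(in_ch, out_ch, k, padding=1, bias=False)

# the same receptive field, factorized as in SeparableConv2d above
separable = nn.Sequential(
    nn.Conv2d(in_ch, in_ch, k, padding=1, groups=in_ch, bias=False),  # depthwise
    nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),              # pointwise
)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(regular))    # 589824 = 256 * 256 * 3 * 3
print(n_params(separable))  # 67840  = 256 * 3 * 3 + 256 * 256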

segmentation_models_pytorch/deeplabv3/model.py (+94, -1)

@@ -1,7 +1,7 @@
 import torch.nn as nn
 
 from typing import Optional
-from .decoder import DeepLabV3Decoder
+from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
 from ..base import SegmentationModel, SegmentationHead, ClassificationHead
 from ..encoders import get_encoder
 
@@ -79,3 +79,96 @@ def __init__(
             )
         else:
             self.classification_head = None
+
+
+class DeepLabV3Plus(SegmentationModel):
+    """DeepLabV3Plus_ implementation from "Encoder-Decoder with Atrous Separable
+    Convolution for Semantic Image Segmentation"
+    Args:
+        encoder_name: name of classification model (without last dense layers) used as feature
+            extractor to build segmentation model.
+        encoder_depth: number of stages used in decoder, larger depth - more features are generated.
+            e.g. for depth=3 encoder will generate list of features with following spatial shapes
+            [(H, W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature will have
+            spatial resolution (H/(2^depth), W/(2^depth))
+        encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
+        encoder_output_stride: downsampling factor for deepest encoder features (see original paper for explanation)
+        decoder_atrous_rates: dilation rates for ASPP module (should be a tuple of 3 integer values)
+        decoder_channels: a number of convolution filters in ASPP module (default 256).
+        in_channels: number of input channels for model, default is 3.
+        classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
+        activation (str, callable): activation function used in ``.predict(x)`` method for inference.
+            One of [``sigmoid``, ``softmax2d``, callable, None]
+        upsampling: optional, final upsampling factor
+            (default is 4 to preserve input -> output spatial shape identity)
+        aux_params: if specified model will have additional classification auxiliary output
+            built on top of encoder, supported params:
+                - classes (int): number of classes
+                - pooling (str): one of 'max', 'avg'. Default is 'avg'.
+                - dropout (float): dropout factor in [0, 1)
+                - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits)
+    Returns:
+        ``torch.nn.Module``: **DeepLabV3Plus**
+    .. _DeepLabV3Plus:
+        https://arxiv.org/abs/1802.02611v3
+    """
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        encoder_output_stride: int = 16,
+        decoder_channels: int = 256,
+        decoder_atrous_rates: tuple = (12, 24, 36),
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[str] = None,
+        upsampling: int = 4,
+        aux_params: Optional[dict] = None,
+    ):
+        super().__init__()
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+        )
+
+        if encoder_output_stride == 8:
+            self.encoder.make_dilated(
+                stage_list=[4, 5],
+                dilation_list=[2, 4]
+            )
+
+        elif encoder_output_stride == 16:
+            self.encoder.make_dilated(
+                stage_list=[5],
+                dilation_list=[2]
+            )
+        else:
+            raise ValueError(
+                "Encoder output stride should be 8 or 16, got {}".format(encoder_output_stride)
+            )
+
+        self.decoder = DeepLabV3PlusDecoder(
+            encoder_channels=self.encoder.out_channels,
+            out_channels=decoder_channels,
+            atrous_rates=decoder_atrous_rates,
+            output_stride=encoder_output_stride,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=self.decoder.out_channels,
+            out_channels=classes,
+            activation=activation,
+            kernel_size=1,
+            upsampling=upsampling,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
segmentation_models_pytorch/encoders/__init__.py (+2, -1)

@@ -11,7 +11,7 @@
 from .efficientnet import efficient_net_encoders
 from .mobilenet import mobilenet_encoders
 from .xception import xception_encoders
-
+from .timm_efficientnet import timm_efficientnet_encoders
 
 from ._preprocessing import preprocess_input
 
@@ -26,6 +26,7 @@
 encoders.update(efficient_net_encoders)
 encoders.update(mobilenet_encoders)
 encoders.update(xception_encoders)
+encoders.update(timm_efficientnet_encoders)
 
 
 def get_encoder(name, in_channels=3, depth=5, weights=None):
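
Registering the encoders in this dict is what makes the new names resolvable by every model class; the imported timm_efficientnet module is the new file this commit adds. A sketch of direct construction (weight names follow the README table, and availability depends on the pinned timm version):

from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder("timm-efficientnet-b0", in_channels=3, depth=5, weights=None)
print(encoder.out_channels)  # per-stage channel counts consumed by the decoders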
