Skip to content

Commit d1b4fac

Browse files
authored
add UNet3P (PaddlePaddle#906)
1 parent 0d683da commit d1b4fac

File tree

7 files changed

+327
-0
lines changed

7 files changed

+327
-0
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Welcome to PaddleSeg! PaddleSeg is an end-to-end image segmentation development
4242
|[U<sup>2</sup>-Net](./configs/u2net)|-|-|-|-|
4343
|[Att U-Net](./configs/attention_unet)|-|-|-|-|
4444
|[U-Net++](./configs/unet_plusplus)|-|-|-|-|
45+
|[U-Net3+](./configs/unet_3plus)|-|-|-|-|
4546
|[DecoupledSegNet](./configs/decoupled_segnet)|||||
4647
|[EMANet](./configs/emanet)|||-|-|
4748
|[ISANet](./configs/isanet)|||-|-|

README_CN.md

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ PaddleSeg是基于飞桨[PaddlePaddle](https://www.paddlepaddle.org.cn)开发的
4242
|[U<sup>2</sup>-Net](./configs/u2net)|-|-|-|-|
4343
|[Att U-Net](./configs/attention_unet)|-|-|-|-|
4444
|[U-Net++](./configs/unet_plusplus)|-|-|-|-|
45+
|[U-Net3+](./configs/unet_3plus)|-|-|-|-|
4546
|[DecoupledSegNet](./configs/decoupled_segnet)|||||
4647
|[EMANet](./configs/emanet)|||-|-|
4748
|[ISANet](./configs/isanet)|||-|-|

configs/unet_3plus/README.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation
2+
3+
## Reference
4+
5+
> Huang H , Lin L , Tong R , et al. UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation[J]. ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2020.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
_base_: '../_base_/cityscapes.yml'
2+
3+
batch_size: 4
4+
iters: 160000
5+
6+
model:
7+
type: UNet3Plus
8+
in_channels: 3
9+
num_classes: 19
10+
is_batchnorm: True
11+
is_deepsup: False
12+
is_CGM: False

docs/apis/models.md

+22
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ The models subpackage contains the following model for image semantic segmentation
1818
- [U<sup>2</sup>Net+](#U2Net-1)
1919
- [AttentionUNet](#AttentionUNet)
2020
- [UNet++](#UNet-1)
21+
- [UNet3+](#UNet-2)
2122
- [DecoupledSegNet](#DecoupledSegNet)
2223
- [ISANet](#ISANet)
2324
- [EMANet](#EMANet)
@@ -408,6 +409,27 @@ The models subpackage contains the following model for image semantic segmentation
408409
> > > - **pretrained** (str, optional): The path or url of pretrained model for fine tuning. Default: None.
409410
> > > - **is_ds** (bool): use deep supervision or not. Default: True
410411
412+
## <span id="UNet-2">[UNet3+](../../paddleseg/models/unet_3plus.py)</span>
413+
> class UNet3Plus(in_channels,
414+
num_classes,
415+
is_batchnorm=True,
416+
is_deepsup=False,
417+
is_CGM=False)
418+
419+
The UNet3+ implementation based on PaddlePaddle.
420+
421+
The original article refers to
422+
Huang H , Lin L , Tong R , et al. "UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation"
423+
(https://arxiv.org/abs/2004.08790).
424+
425+
> > Args
426+
> > > - **in_channels** (int): The channel number of input image.
427+
> > > - **num_classes** (int): The unique number of target classes.
428+
> > > - **is_batchnorm** (bool, optional): Use batchnorm after conv or not. Default: True.
429+
> > > - **is_deepsup** (bool, optional): Use deep supervision or not. Default: False.
430+
> > > - **is_CGM** (bool, optional): Use classification-guided module or not.
431+
If True, is_deepsup must be True. Default: False.
432+
411433
## [DecoupledSegNet](../../paddleseg/models/decoupled_segnet.py)
412434
> class DecoupledSegNet(num_classes,
413435
backbone,

paddleseg/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .u2net import U2Net, U2Netp
3131
from .attention_unet import AttentionUNet
3232
from .unet_plusplus import UNetPlusPlus
33+
from .unet_3plus import UNet3Plus
3334
from .decoupled_segnet import DecoupledSegNet
3435
from .emanet import *
3536
from .isanet import *

paddleseg/models/unet_3plus.py

+285
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
import paddle.nn as nn
17+
import paddle.nn.functional as F
18+
19+
from paddleseg.cvlibs import manager
20+
from paddleseg.models.layers.layer_libs import SyncBatchNorm
21+
from paddleseg.cvlibs.param_init import kaiming_normal_init
22+
23+
24+
@manager.MODELS.add_component
class UNet3Plus(nn.Layer):
    """
    The UNet3+ implementation based on PaddlePaddle.

    The original article refers to
    Huang H , Lin L , Tong R , et al. "UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation"
    (https://arxiv.org/abs/2004.08790).

    Args:
        in_channels (int, optional): The channel number of input image. Default: 3.
        num_classes (int, optional): The unique number of target classes. Default: 2.
        is_batchnorm (bool, optional): Use batchnorm after conv or not. Default: True.
        is_deepsup (bool, optional): Use deep supervision or not. Default: False.
        is_CGM (bool, optional): Use classification-guided module or not.
            If True, is_deepsup must be True. Default: False.
    """

    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 is_batchnorm=True,
                 is_deepsup=False,
                 is_CGM=False):
        super(UNet3Plus, self).__init__()
        # The classification-guided module needs the deep-supervision
        # branches, so CGM forces deep supervision on.
        self.is_deepsup = True if is_CGM else is_deepsup
        self.is_CGM = is_CGM
        # Channel widths of the five encoder stages.
        self.filters = [64, 128, 256, 512, 1024]
        self.cat_channels = self.filters[0]
        self.cat_blocks = 5
        # Every decoder stage concatenates 5 maps of cat_channels each.
        self.up_channels = self.cat_channels * self.cat_blocks
        # Sub-networks.
        self.encoder = Encoder(in_channels, self.filters, is_batchnorm)
        self.decoder = Decoder(self.filters, self.cat_channels,
                               self.up_channels)
        if self.is_deepsup:
            self.deepsup = DeepSup(self.up_channels, self.filters, num_classes)
            if self.is_CGM:
                # 2-way (organ / background) classifier on the deepest features.
                self.cls = nn.Sequential(
                    nn.Dropout(p=0.5), nn.Conv2D(self.filters[4], 2, 1),
                    nn.AdaptiveMaxPool2D(1), nn.Sigmoid())
        else:
            self.outconv1 = nn.Conv2D(
                self.up_channels, num_classes, 3, padding=1)
        # Initialise weights.
        # NOTE(review): BatchNorm weights also receive Kaiming init here,
        # same as the convs — the reference implementation initialises BN
        # differently; confirm this is intentional.
        for sublayer in self.sublayers():
            if isinstance(sublayer,
                          (nn.Conv2D, nn.BatchNorm, nn.SyncBatchNorm)):
                kaiming_normal_init(sublayer.weight)

    def dotProduct(self, seg, cls):
        """Scale each class map in `seg` by its classification score `cls`."""
        batch, channels, height, width = seg.shape
        flat = seg.reshape((batch, channels, height * width))
        # Broadcast the (B, 1) score against a (1, N) row of ones to get a
        # per-class scale of shape (B, N, 1).
        scale = (cls * paddle.ones([1, channels])).reshape(
            [batch, channels, 1])
        return (flat * scale).reshape((batch, channels, height, width))

    def forward(self, inputs):
        feats = self.encoder(inputs)
        decoded = self.decoder(feats)
        if not self.is_deepsup:
            # Single head on the finest decoder map: d1->320*320*num_classes.
            return [self.outconv1(decoded[0])]
        out = self.deepsup(decoded)
        if self.is_CGM:
            # Classification-guided module: (B,N,1,1)->(B,N).
            cls_branch = self.cls(decoded[-1]).squeeze(3).squeeze(2)
            cls_branch_max = cls_branch.argmax(axis=1)
            cls_branch_max = cls_branch_max.reshape((-1, 1)).astype('float')
            out = [self.dotProduct(d, cls_branch_max) for d in out]
        return out
93+
94+
95+
class Encoder(nn.Layer):
    """UNet3+ encoder: one plain conv stage followed by four pool+conv stages.

    Each pooling stage halves the spatial size and doubles the channels:
    h1->320*320*64, h2->160*160*128, h3->80*80*256, h4->40*40*512,
    hd5->20*20*1024. Returns the five stage outputs as a list.
    """

    def __init__(self, in_channels, filters, is_batchnorm):
        super(Encoder, self).__init__()
        self.conv1 = UnetConv2D(in_channels, filters[0], is_batchnorm)
        self.poolconv2 = MaxPoolConv2D(filters[0], filters[1], is_batchnorm)
        self.poolconv3 = MaxPoolConv2D(filters[1], filters[2], is_batchnorm)
        self.poolconv4 = MaxPoolConv2D(filters[2], filters[3], is_batchnorm)
        self.poolconv5 = MaxPoolConv2D(filters[3], filters[4], is_batchnorm)

    def forward(self, inputs):
        feats = [self.conv1(inputs)]
        # Run the remaining stages sequentially, each on the previous output.
        for stage in (self.poolconv2, self.poolconv3, self.poolconv4,
                      self.poolconv5):
            feats.append(stage(feats[-1]))
        return feats
111+
112+
113+
class Decoder(nn.Layer):
    """UNet3+ full-scale decoder.

    Every decoder stage hdN fuses five same-resolution feature maps: finer
    encoder maps max-pooled down (PT), the same-scale encoder map passed
    through (Cat), and coarser decoder/encoder maps bilinearly upsampled
    (UT). Each source is reduced to `cat_channels` channels, concatenated
    (5 * cat_channels == up_channels), and fused by one conv-bn-relu block.
    Returns [hd1, hd2, hd3, hd4, hd5] — finest to coarsest — where hd5 is
    the untouched deepest encoder map.
    """
    def __init__(self, filters, cat_channels, up_channels):
        super(Decoder, self).__init__()
        '''stage 4d'''
        # h1->320*320, hd4->40*40, Pooling 8 times
        self.h1_PT_hd4 = nn.MaxPool2D(8, 8, ceil_mode=True)
        self.h1_PT_hd4_cbr = ConvBnReLU2D(filters[0], cat_channels)
        # h2->160*160, hd4->40*40, Pooling 4 times
        self.h2_PT_hd4 = nn.MaxPool2D(4, 4, ceil_mode=True)
        self.h2_PT_hd4_cbr = ConvBnReLU2D(filters[1], cat_channels)
        # h3->80*80, hd4->40*40, Pooling 2 times
        self.h3_PT_hd4 = nn.MaxPool2D(2, 2, ceil_mode=True)
        self.h3_PT_hd4_cbr = ConvBnReLU2D(filters[2], cat_channels)
        # h4->40*40, hd4->40*40, Concatenation
        self.h4_Cat_hd4_cbr = ConvBnReLU2D(filters[3], cat_channels)
        # hd5->20*20, hd4->40*40, Upsample 2 times
        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.hd5_UT_hd4_cbr = ConvBnReLU2D(filters[4], cat_channels)
        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
        self.cbr4d_1 = ConvBnReLU2D(up_channels, up_channels)
        '''stage 3d'''
        # h1->320*320, hd3->80*80, Pooling 4 times
        self.h1_PT_hd3 = nn.MaxPool2D(4, 4, ceil_mode=True)
        self.h1_PT_hd3_cbr = ConvBnReLU2D(filters[0], cat_channels)
        # h2->160*160, hd3->80*80, Pooling 2 times
        self.h2_PT_hd3 = nn.MaxPool2D(2, 2, ceil_mode=True)
        self.h2_PT_hd3_cbr = ConvBnReLU2D(filters[1], cat_channels)
        # h3->80*80, hd3->80*80, Concatenation
        self.h3_Cat_hd3_cbr = ConvBnReLU2D(filters[2], cat_channels)
        # hd4->40*40, hd3->80*80, Upsample 2 times
        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.hd4_UT_hd3_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd5->20*20, hd3->80*80, Upsample 4 times
        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.hd5_UT_hd3_cbr = ConvBnReLU2D(filters[4], cat_channels)
        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.cbr3d_1 = ConvBnReLU2D(up_channels, up_channels)
        '''stage 2d '''
        # h1->320*320, hd2->160*160, Pooling 2 times
        self.h1_PT_hd2 = nn.MaxPool2D(2, 2, ceil_mode=True)
        self.h1_PT_hd2_cbr = ConvBnReLU2D(filters[0], cat_channels)
        # h2->160*160, hd2->160*160, Concatenation
        self.h2_Cat_hd2_cbr = ConvBnReLU2D(filters[1], cat_channels)
        # hd3->80*80, hd2->160*160, Upsample 2 times
        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.hd3_UT_hd2_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd4->40*40, hd2->160*160, Upsample 4 times
        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.hd4_UT_hd2_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd5->20*20, hd2->160*160, Upsample 8 times
        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')
        self.hd5_UT_hd2_cbr = ConvBnReLU2D(filters[4], cat_channels)
        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
        self.cbr2d_1 = ConvBnReLU2D(up_channels, up_channels)
        '''stage 1d'''
        # h1->320*320, hd1->320*320, Concatenation
        self.h1_Cat_hd1_cbr = ConvBnReLU2D(filters[0], cat_channels)
        # hd2->160*160, hd1->320*320, Upsample 2 times
        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.hd2_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd3->80*80, hd1->320*320, Upsample 4 times
        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.hd3_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd4->40*40, hd1->320*320, Upsample 8 times
        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')
        self.hd4_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd5->20*20, hd1->320*320, Upsample 16 times
        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear')
        self.hd5_UT_hd1_cbr = ConvBnReLU2D(filters[4], cat_channels)
        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
        self.cbr1d_1 = ConvBnReLU2D(up_channels, up_channels)

    def forward(self, inputs):
        # inputs: the five encoder outputs, finest (h1) to coarsest (hd5).
        h1, h2, h3, h4, hd5 = inputs
        h1_PT_hd4 = self.h1_PT_hd4_cbr(self.h1_PT_hd4(h1))
        h2_PT_hd4 = self.h2_PT_hd4_cbr(self.h2_PT_hd4(h2))
        h3_PT_hd4 = self.h3_PT_hd4_cbr(self.h3_PT_hd4(h3))
        h4_Cat_hd4 = self.h4_Cat_hd4_cbr(h4)
        hd5_UT_hd4 = self.hd5_UT_hd4_cbr(self.hd5_UT_hd4(hd5))
        # hd4->40*40*up_channels
        hd4 = self.cbr4d_1(paddle.concat([h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4], 1))
        h1_PT_hd3 = self.h1_PT_hd3_cbr(self.h1_PT_hd3(h1))
        h2_PT_hd3 = self.h2_PT_hd3_cbr(self.h2_PT_hd3(h2))
        h3_Cat_hd3 = self.h3_Cat_hd3_cbr(h3)
        hd4_UT_hd3 = self.hd4_UT_hd3_cbr(self.hd4_UT_hd3(hd4))
        hd5_UT_hd3 = self.hd5_UT_hd3_cbr(self.hd5_UT_hd3(hd5))
        # hd3->80*80*up_channels
        hd3 = self.cbr3d_1(paddle.concat([h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3], 1))
        h1_PT_hd2 = self.h1_PT_hd2_cbr(self.h1_PT_hd2(h1))
        h2_Cat_hd2 = self.h2_Cat_hd2_cbr(h2)
        hd3_UT_hd2 = self.hd3_UT_hd2_cbr(self.hd3_UT_hd2(hd3))
        hd4_UT_hd2 = self.hd4_UT_hd2_cbr(self.hd4_UT_hd2(hd4))
        hd5_UT_hd2 = self.hd5_UT_hd2_cbr(self.hd5_UT_hd2(hd5))
        # hd2->160*160*up_channels
        hd2 = self.cbr2d_1(paddle.concat([h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2], 1))
        h1_Cat_hd1 = self.h1_Cat_hd1_cbr(h1)
        hd2_UT_hd1 = self.hd2_UT_hd1_cbr(self.hd2_UT_hd1(hd2))
        hd3_UT_hd1 = self.hd3_UT_hd1_cbr(self.hd3_UT_hd1(hd3))
        hd4_UT_hd1 = self.hd4_UT_hd1_cbr(self.hd4_UT_hd1(hd4))
        hd5_UT_hd1 = self.hd5_UT_hd1_cbr(self.hd5_UT_hd1(hd5))
        # hd1->320*320*up_channels
        hd1 = self.cbr1d_1(paddle.concat([h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1], 1))
        return [hd1, hd2, hd3, hd4, hd5]
216+
217+
218+
class DeepSup(nn.Layer):
    """Deep-supervision heads for UNet3+.

    Produces one `num_classes`-channel score map per decoder scale; the
    coarser maps are bilinearly upsampled back to the finest resolution.
    Returns [d1, d2, d3, d4, d5], finest scale first.
    """

    def __init__(self, up_channels, filters, num_classes):
        super(DeepSup, self).__init__()
        self.convup5 = ConvUp2D(filters[4], num_classes, 16)
        self.convup4 = ConvUp2D(up_channels, num_classes, 8)
        self.convup3 = ConvUp2D(up_channels, num_classes, 4)
        self.convup2 = ConvUp2D(up_channels, num_classes, 2)
        self.outconv1 = nn.Conv2D(up_channels, num_classes, 3, padding=1)

    def forward(self, inputs):
        hd1, hd2, hd3, hd4, hd5 = inputs
        # Pair each decoder map with its head; hd1 needs no upsampling.
        heads = (self.outconv1, self.convup2, self.convup3, self.convup4,
                 self.convup5)
        return [head(feat)
                for head, feat in zip(heads, (hd1, hd2, hd3, hd4, hd5))]
235+
236+
237+
class ConvBnReLU2D(nn.Sequential):
    """A 3x3 conv (padding 1) followed by batch norm and ReLU."""

    def __init__(self, in_channels, out_channels):
        layers = [
            nn.Conv2D(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm(out_channels),
            nn.ReLU(),
        ]
        super(ConvBnReLU2D, self).__init__(*layers)
244+
245+
246+
class ConvUp2D(nn.Sequential):
    """A 3x3 conv (padding 1) followed by bilinear upsampling."""

    def __init__(self, in_channels, out_channels, scale_factor):
        layers = [
            nn.Conv2D(in_channels, out_channels, 3, padding=1),
            nn.Upsample(scale_factor=scale_factor, mode='bilinear'),
        ]
        super(ConvUp2D, self).__init__(*layers)
252+
253+
254+
class MaxPoolConv2D(nn.Sequential):
    """2x2 max pooling followed by a UnetConv2D block."""

    def __init__(self, in_channels, out_channels, is_batchnorm):
        layers = [
            nn.MaxPool2D(kernel_size=2),
            UnetConv2D(in_channels, out_channels, is_batchnorm),
        ]
        super(MaxPoolConv2D, self).__init__(*layers)
260+
261+
262+
class UnetConv2D(nn.Layer):
    """A stack of `num_conv` conv(+BN)+ReLU blocks used by the encoder.

    The first conv maps `in_channels` to `out_channels`; every later conv
    keeps `out_channels`.

    Args:
        in_channels (int): The channel number of the input feature map.
        out_channels (int): The channel number produced by each conv block.
        is_batchnorm (bool): Insert BatchNorm between conv and ReLU.
        num_conv (int, optional): Number of conv blocks. Default: 2.
        kernel_size (int, optional): Conv kernel size. Default: 3.
        stride (int, optional): Conv stride. Default: 1.
        padding (int, optional): Conv padding. Default: 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 is_batchnorm,
                 num_conv=2,
                 kernel_size=3,
                 stride=1,
                 padding=1):
        super(UnetConv2D, self).__init__()
        self.num_conv = num_conv
        for i in range(num_conv):
            if is_batchnorm:
                conv = nn.Sequential(
                    nn.Conv2D(in_channels, out_channels, kernel_size, stride,
                              padding), nn.BatchNorm(out_channels), nn.ReLU())
            else:
                conv = nn.Sequential(
                    nn.Conv2D(in_channels, out_channels, kernel_size, stride,
                              padding), nn.ReLU())
            setattr(self, 'conv%d' % (i + 1), conv)
            in_channels = out_channels
        # NOTE: the original code assigned `weight_attr`/`bias_attr`
        # ParamAttr objects to the child layers here. In Paddle, ParamAttr
        # only takes effect when passed to a layer's constructor, so those
        # post-construction assignments were no-ops and have been removed;
        # weight initialisation is handled by UNet3Plus.__init__ instead.

    def forward(self, inputs):
        """Apply the conv blocks sequentially to `inputs`."""
        x = inputs
        for i in range(self.num_conv):
            x = getattr(self, 'conv%d' % (i + 1))(x)
        return x

0 commit comments

Comments
 (0)