# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager
from paddleseg.models.layers.layer_libs import SyncBatchNorm
from paddleseg.cvlibs.param_init import kaiming_normal_init


@manager.MODELS.add_component
class UNet3Plus(nn.Layer):
    """
    The UNet3+ implementation based on PaddlePaddle.

    The original article refers to
    Huang H, Lin L, Tong R, et al. "UNet 3+: A Full-Scale Connected UNet for
    Medical Image Segmentation" (https://arxiv.org/abs/2004.08790).

    Args:
        in_channels (int, optional): The channel number of the input image. Default: 3.
        num_classes (int, optional): The number of target classes. Default: 2.
        is_batchnorm (bool, optional): Whether to apply batch normalization after each convolution. Default: True.
        is_deepsup (bool, optional): Whether to use deep supervision. Default: False.
        is_CGM (bool, optional): Whether to use the classification-guided module.
            If True, is_deepsup is forced to True. Default: False.
    """

    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 is_batchnorm=True,
                 is_deepsup=False,
                 is_CGM=False):
        super(UNet3Plus, self).__init__()
        # parameters
        self.is_deepsup = True if is_CGM else is_deepsup
        self.is_CGM = is_CGM
        # internal definition
        self.filters = [64, 128, 256, 512, 1024]
        self.cat_channels = self.filters[0]
        self.cat_blocks = 5
        self.up_channels = self.cat_channels * self.cat_blocks
        # layers
        self.encoder = Encoder(in_channels, self.filters, is_batchnorm)
        self.decoder = Decoder(self.filters, self.cat_channels,
                               self.up_channels)
        if self.is_deepsup:
            self.deepsup = DeepSup(self.up_channels, self.filters, num_classes)
            if self.is_CGM:
                self.cls = nn.Sequential(
                    nn.Dropout(p=0.5),
                    nn.Conv2D(self.filters[4], 2, 1),
                    nn.AdaptiveMaxPool2D(1),
                    nn.Sigmoid())
        else:
            self.outconv1 = nn.Conv2D(
                self.up_channels, num_classes, 3, padding=1)
        # initialise weights
        for sublayer in self.sublayers():
            if isinstance(sublayer, nn.Conv2D):
                kaiming_normal_init(sublayer.weight)
            elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
                kaiming_normal_init(sublayer.weight)
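
    # With the default settings, every decoder stage fuses cat_blocks = 5
    # feature maps of cat_channels = 64 channels each, so each fused decoder
    # map carries up_channels = 64 * 5 = 320 channels.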

    def dotProduct(self, seg, cls):
        # Broadcast the per-image classification result `cls` (here shape
        # (B, 1)) across all N class channels of `seg` and scale the flattened
        # segmentation map by it. With the CGM, `cls` holds the argmax of the
        # organ/no-organ branch, so predictions on images classified as
        # "no organ" are zeroed out.
        B, N, H, W = seg.shape
        seg = seg.reshape((B, N, H * W))
        clssp = paddle.ones([1, N])
        ecls = (cls * clssp).reshape([B, N, 1])
        final = seg * ecls
        final = final.reshape((B, N, H, W))
        return final

    def forward(self, inputs):
        hs = self.encoder(inputs)
        hds = self.decoder(hs)
        if self.is_deepsup:
            out = self.deepsup(hds)
            if self.is_CGM:
                # classification-guided module
                cls_branch = self.cls(hds[-1]).squeeze(3).squeeze(2)  # (B,N,1,1)->(B,N)
                cls_branch_max = cls_branch.argmax(axis=1)
                cls_branch_max = cls_branch_max.reshape((-1, 1)).astype('float32')
                out = [self.dotProduct(d, cls_branch_max) for d in out]
        else:
            out = [self.outconv1(hds[0])]  # d1->320*320*num_classes
        return out
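
    # Note: forward() always returns a list of logit maps. Without deep
    # supervision the list holds a single full-resolution map; with
    # is_deepsup=True it holds one map per decoder stage (all upsampled to
    # the input resolution by DeepSup), so a loss can be applied at every
    # scale.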


class Encoder(nn.Layer):
    def __init__(self, in_channels, filters, is_batchnorm):
        super(Encoder, self).__init__()
        self.conv1 = UnetConv2D(in_channels, filters[0], is_batchnorm)
        self.poolconv2 = MaxPoolConv2D(filters[0], filters[1], is_batchnorm)
        self.poolconv3 = MaxPoolConv2D(filters[1], filters[2], is_batchnorm)
        self.poolconv4 = MaxPoolConv2D(filters[2], filters[3], is_batchnorm)
        self.poolconv5 = MaxPoolConv2D(filters[3], filters[4], is_batchnorm)

    def forward(self, inputs):
        h1 = self.conv1(inputs)   # h1->320*320*64
        h2 = self.poolconv2(h1)   # h2->160*160*128
        h3 = self.poolconv3(h2)   # h3->80*80*256
        h4 = self.poolconv4(h3)   # h4->40*40*512
        hd5 = self.poolconv5(h4)  # hd5->20*20*1024
        return [h1, h2, h3, h4, hd5]
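
# The spatial sizes in the comments assume a 320*320 input. Because each
# encoder stage halves the resolution and the decoder upsamples by fixed
# factors, input height and width should be divisible by 16 so that the
# concatenated feature maps line up.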


class Decoder(nn.Layer):
    def __init__(self, filters, cat_channels, up_channels):
        super(Decoder, self).__init__()
        '''stage 4d'''
        # h1->320*320, hd4->40*40, Pooling 8 times
        self.h1_PT_hd4 = nn.MaxPool2D(8, 8, ceil_mode=True)
        self.h1_PT_hd4_cbr = ConvBnReLU2D(filters[0], cat_channels)
        # h2->160*160, hd4->40*40, Pooling 4 times
        self.h2_PT_hd4 = nn.MaxPool2D(4, 4, ceil_mode=True)
        self.h2_PT_hd4_cbr = ConvBnReLU2D(filters[1], cat_channels)
        # h3->80*80, hd4->40*40, Pooling 2 times
        self.h3_PT_hd4 = nn.MaxPool2D(2, 2, ceil_mode=True)
        self.h3_PT_hd4_cbr = ConvBnReLU2D(filters[2], cat_channels)
        # h4->40*40, hd4->40*40, Concatenation
        self.h4_Cat_hd4_cbr = ConvBnReLU2D(filters[3], cat_channels)
        # hd5->20*20, hd4->40*40, Upsample 2 times
        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.hd5_UT_hd4_cbr = ConvBnReLU2D(filters[4], cat_channels)
        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
        self.cbr4d_1 = ConvBnReLU2D(up_channels, up_channels)
        '''stage 3d'''
        # h1->320*320, hd3->80*80, Pooling 4 times
        self.h1_PT_hd3 = nn.MaxPool2D(4, 4, ceil_mode=True)
        self.h1_PT_hd3_cbr = ConvBnReLU2D(filters[0], cat_channels)
        # h2->160*160, hd3->80*80, Pooling 2 times
        self.h2_PT_hd3 = nn.MaxPool2D(2, 2, ceil_mode=True)
        self.h2_PT_hd3_cbr = ConvBnReLU2D(filters[1], cat_channels)
        # h3->80*80, hd3->80*80, Concatenation
        self.h3_Cat_hd3_cbr = ConvBnReLU2D(filters[2], cat_channels)
        # hd4->40*40, hd3->80*80, Upsample 2 times
        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.hd4_UT_hd3_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd5->20*20, hd3->80*80, Upsample 4 times
        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.hd5_UT_hd3_cbr = ConvBnReLU2D(filters[4], cat_channels)
        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.cbr3d_1 = ConvBnReLU2D(up_channels, up_channels)
        '''stage 2d'''
        # h1->320*320, hd2->160*160, Pooling 2 times
        self.h1_PT_hd2 = nn.MaxPool2D(2, 2, ceil_mode=True)
        self.h1_PT_hd2_cbr = ConvBnReLU2D(filters[0], cat_channels)
        # h2->160*160, hd2->160*160, Concatenation
        self.h2_Cat_hd2_cbr = ConvBnReLU2D(filters[1], cat_channels)
        # hd3->80*80, hd2->160*160, Upsample 2 times
        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.hd3_UT_hd2_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd4->40*40, hd2->160*160, Upsample 4 times
        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.hd4_UT_hd2_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd5->20*20, hd2->160*160, Upsample 8 times
        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear')
        self.hd5_UT_hd2_cbr = ConvBnReLU2D(filters[4], cat_channels)
        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
        self.cbr2d_1 = ConvBnReLU2D(up_channels, up_channels)
        '''stage 1d'''
        # h1->320*320, hd1->320*320, Concatenation
        self.h1_Cat_hd1_cbr = ConvBnReLU2D(filters[0], cat_channels)
        # hd2->160*160, hd1->320*320, Upsample 2 times
        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.hd2_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd3->80*80, hd1->320*320, Upsample 4 times
        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.hd3_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd4->40*40, hd1->320*320, Upsample 8 times
        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear')
        self.hd4_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
        # hd5->20*20, hd1->320*320, Upsample 16 times
        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear')
        self.hd5_UT_hd1_cbr = ConvBnReLU2D(filters[4], cat_channels)
        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
        self.cbr1d_1 = ConvBnReLU2D(up_channels, up_channels)

    def forward(self, inputs):
        h1, h2, h3, h4, hd5 = inputs
        h1_PT_hd4 = self.h1_PT_hd4_cbr(self.h1_PT_hd4(h1))
        h2_PT_hd4 = self.h2_PT_hd4_cbr(self.h2_PT_hd4(h2))
        h3_PT_hd4 = self.h3_PT_hd4_cbr(self.h3_PT_hd4(h3))
        h4_Cat_hd4 = self.h4_Cat_hd4_cbr(h4)
        hd5_UT_hd4 = self.hd5_UT_hd4_cbr(self.hd5_UT_hd4(hd5))
        # hd4->40*40*up_channels
        hd4 = self.cbr4d_1(
            paddle.concat(
                [h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4], 1))
        h1_PT_hd3 = self.h1_PT_hd3_cbr(self.h1_PT_hd3(h1))
        h2_PT_hd3 = self.h2_PT_hd3_cbr(self.h2_PT_hd3(h2))
        h3_Cat_hd3 = self.h3_Cat_hd3_cbr(h3)
        hd4_UT_hd3 = self.hd4_UT_hd3_cbr(self.hd4_UT_hd3(hd4))
        hd5_UT_hd3 = self.hd5_UT_hd3_cbr(self.hd5_UT_hd3(hd5))
        # hd3->80*80*up_channels
        hd3 = self.cbr3d_1(
            paddle.concat(
                [h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3], 1))
        h1_PT_hd2 = self.h1_PT_hd2_cbr(self.h1_PT_hd2(h1))
        h2_Cat_hd2 = self.h2_Cat_hd2_cbr(h2)
        hd3_UT_hd2 = self.hd3_UT_hd2_cbr(self.hd3_UT_hd2(hd3))
        hd4_UT_hd2 = self.hd4_UT_hd2_cbr(self.hd4_UT_hd2(hd4))
        hd5_UT_hd2 = self.hd5_UT_hd2_cbr(self.hd5_UT_hd2(hd5))
        # hd2->160*160*up_channels
        hd2 = self.cbr2d_1(
            paddle.concat(
                [h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2], 1))
        h1_Cat_hd1 = self.h1_Cat_hd1_cbr(h1)
        hd2_UT_hd1 = self.hd2_UT_hd1_cbr(self.hd2_UT_hd1(hd2))
        hd3_UT_hd1 = self.hd3_UT_hd1_cbr(self.hd3_UT_hd1(hd3))
        hd4_UT_hd1 = self.hd4_UT_hd1_cbr(self.hd4_UT_hd1(hd4))
        hd5_UT_hd1 = self.hd5_UT_hd1_cbr(self.hd5_UT_hd1(hd5))
        # hd1->320*320*up_channels
        hd1 = self.cbr1d_1(
            paddle.concat(
                [h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1], 1))
        return [hd1, hd2, hd3, hd4, hd5]
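
# Each decoder map hd_i is built from exactly five sources: max-pooled copies
# of the shallower encoder maps, the same-scale encoder map, and bilinearly
# upsampled copies of the deeper decoder maps. This is the "full-scale skip
# connection" of UNet 3+.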


class DeepSup(nn.Layer):
    def __init__(self, up_channels, filters, num_classes):
        super(DeepSup, self).__init__()
        self.convup5 = ConvUp2D(filters[4], num_classes, 16)
        self.convup4 = ConvUp2D(up_channels, num_classes, 8)
        self.convup3 = ConvUp2D(up_channels, num_classes, 4)
        self.convup2 = ConvUp2D(up_channels, num_classes, 2)
        self.outconv1 = nn.Conv2D(up_channels, num_classes, 3, padding=1)

    def forward(self, inputs):
        hd1, hd2, hd3, hd4, hd5 = inputs
        d5 = self.convup5(hd5)   # hd5: 20*20 -> 320*320
        d4 = self.convup4(hd4)   # hd4: 40*40 -> 320*320
        d3 = self.convup3(hd3)   # hd3: 80*80 -> 320*320
        d2 = self.convup2(hd2)   # hd2: 160*160 -> 320*320
        d1 = self.outconv1(hd1)  # hd1: already 320*320
        return [d1, d2, d3, d4, d5]
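
# convup5 reads hd5 straight from the encoder, hence its filters[4] (=1024)
# input channels; the other stages consume fused decoder maps with
# up_channels (=320) channels.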


class ConvBnReLU2D(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ConvBnReLU2D, self).__init__(
            nn.Conv2D(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm(out_channels),
            nn.ReLU())


class ConvUp2D(nn.Sequential):
    def __init__(self, in_channels, out_channels, scale_factor):
        super(ConvUp2D, self).__init__(
            nn.Conv2D(in_channels, out_channels, 3, padding=1),
            nn.Upsample(scale_factor=scale_factor, mode='bilinear'))
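
# Design note: ConvUp2D runs its 3*3 convolution at the low resolution and
# only then upsamples bilinearly, which is cheaper than convolving after
# upsampling; the reference UNet 3+ implementation orders the ops the same
# way.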


class MaxPoolConv2D(nn.Sequential):
    def __init__(self, in_channels, out_channels, is_batchnorm):
        super(MaxPoolConv2D, self).__init__(
            nn.MaxPool2D(kernel_size=2),
            UnetConv2D(in_channels, out_channels, is_batchnorm))


class UnetConv2D(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 is_batchnorm,
                 num_conv=2,
                 kernel_size=3,
                 stride=1,
                 padding=1):
        super(UnetConv2D, self).__init__()
        self.num_conv = num_conv
        for i in range(num_conv):
            if is_batchnorm:
                conv = nn.Sequential(
                    nn.Conv2D(in_channels, out_channels, kernel_size, stride,
                              padding),
                    nn.BatchNorm(out_channels),
                    nn.ReLU())
            else:
                conv = nn.Sequential(
                    nn.Conv2D(in_channels, out_channels, kernel_size, stride,
                              padding),
                    nn.ReLU())
            setattr(self, 'conv%d' % (i + 1), conv)
            in_channels = out_channels
        # initialise the blocks
        for children in self.children():
            children.weight_attr = paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.KaimingNormal())
            children.bias_attr = paddle.framework.ParamAttr(
                initializer=paddle.nn.initializer.KaimingNormal())

    def forward(self, inputs):
        x = inputs
        for i in range(self.num_conv):
            conv = getattr(self, 'conv%d' % (i + 1))
            x = conv(x)
        return x
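

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test that
# builds the default model and runs a forward pass on a random 320*320 batch.
if __name__ == '__main__':
    model = UNet3Plus(in_channels=3, num_classes=2)
    model.eval()  # keep BatchNorm statistics fixed for the smoke test
    x = paddle.rand([1, 3, 320, 320])  # (batch, channels, height, width)
    logits = model(x)  # a list with one (1, num_classes, 320, 320) tensor
    print([t.shape for t in logits])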