@@ -51,23 +51,99 @@ def forward(self, *features):
51
51
return super ().forward (features [- 1 ])
52
52
53
53
54
class DeepLabV3PlusDecoder(nn.Module):
    """Decoder half of DeepLabV3+: ASPP over the deepest encoder feature,
    bilinear upsampling, and fusion with a high-resolution skip feature.

    Args:
        encoder_channels: per-stage channel counts of the encoder; only the
            last (deepest) and the fourth-from-last entries are used here.
        out_channels: channel width of the decoder output (default 256).
        atrous_rates: the three dilation rates used inside the ASPP module.
        output_stride: encoder output stride; must be 8 or 16.

    Raises:
        ValueError: if ``output_stride`` is neither 8 nor 16.
    """

    def __init__(
        self,
        encoder_channels,
        out_channels=256,
        atrous_rates=(12, 24, 36),
        output_stride=16,
    ):
        super().__init__()
        if output_stride not in {8, 16}:
            raise ValueError("Output stride should be 8 or 16, got {}.".format(output_stride))

        self.out_channels = out_channels
        self.output_stride = output_stride

        # ASPP over the deepest feature map, refined by a 3x3 separable conv.
        self.aspp = nn.Sequential(
            ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
            SeparableConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

        # Bring the ASPP output up to the spatial size of the skip feature:
        # x2 for stride-8 encoders, x4 for stride-16 encoders.
        self.up = nn.UpsamplingBilinear2d(scale_factor=2 if output_stride == 8 else 4)

        # 1x1 projection of the high-resolution skip feature down to 48
        # channels, as proposed by the authors of the paper.
        highres_in_channels = encoder_channels[-4]
        highres_out_channels = 48
        self.block1 = nn.Sequential(
            nn.Conv2d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(highres_out_channels),
            nn.ReLU(),
        )

        # Fuse the concatenated (ASPP + skip) features with a separable conv.
        self.block2 = nn.Sequential(
            SeparableConv2d(
                highres_out_channels + out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, *features):
        """Fuse the deepest feature map with the high-resolution skip feature."""
        deep = self.up(self.aspp(features[-1]))
        skip = self.block1(features[-4])
        return self.block2(torch.cat([deep, skip], dim=1))
54
107
class ASPPConv(nn.Sequential):
    """One atrous-convolution branch of ASPP: 3x3 dilated conv -> BN -> ReLU.

    ``padding`` equals ``dilation`` so the spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        super().__init__(atrous, nn.BatchNorm2d(out_channels), nn.ReLU())
121
+
122
+
123
class ASPPSeparableConv(nn.Sequential):
    """Depthwise-separable variant of the 3x3 atrous ASPP branch.

    Same contract as ``ASPPConv`` but with ``SeparableConv2d`` in place of a
    dense convolution; ``padding`` equals ``dilation`` to keep spatial size.
    """

    def __init__(self, in_channels, out_channels, dilation):
        sep_conv = SeparableConv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        super().__init__(sep_conv, nn.BatchNorm2d(out_channels), nn.ReLU())
62
137
63
138
64
139
class ASPPPooling (nn .Sequential ):
65
140
def __init__ (self , in_channels , out_channels ):
66
- super (ASPPPooling , self ).__init__ (
141
+ super ().__init__ (
67
142
nn .AdaptiveAvgPool2d (1 ),
68
- nn .Conv2d (in_channels , out_channels , 1 , bias = False ),
143
+ nn .Conv2d (in_channels , out_channels , kernel_size = 1 , bias = False ),
69
144
nn .BatchNorm2d (out_channels ),
70
- nn .ReLU ())
145
+ nn .ReLU (),
146
+ )
71
147
72
148
def forward (self , x ):
73
149
size = x .shape [- 2 :]
@@ -77,31 +153,68 @@ def forward(self, x):
77
153
78
154
79
155
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (DeepLab).

    Runs five parallel branches over the input -- a 1x1 projection, three
    (optionally separable) atrous convolutions at the given rates, and an
    image-level pooling branch -- then projects the channel-concatenation of
    all branches back down to ``out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of every branch and of the final output.
        atrous_rates: exactly three dilation rates for the conv branches.
        separable: if True, use depthwise-separable convs in the atrous
            branches (``ASPPSeparableConv``) instead of dense ones.
    """

    def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
        # Zero-arg super(), consistent with the other modules in this file.
        super().__init__()

        rate1, rate2, rate3 = tuple(atrous_rates)
        ASPPConvModule = ASPPSeparableConv if separable else ASPPConv

        self.convs = nn.ModuleList(
            [
                # 1x1 projection branch.
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(),
                ),
                ASPPConvModule(in_channels, out_channels, rate1),
                ASPPConvModule(in_channels, out_channels, rate2),
                ASPPConvModule(in_channels, out_channels, rate3),
                # Image-level (global average pooling) branch.
                ASPPPooling(in_channels, out_channels),
            ]
        )

        # Fuse the five concatenated branches back to out_channels.
        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        """Apply every branch to ``x`` and project their concatenation."""
        res = [conv(x) for conv in self.convs]
        return self.project(torch.cat(res, dim=1))
190
+
191
+
192
class SeparableConv2d(nn.Sequential):
    """Depthwise-separable convolution.

    A depthwise KxK convolution (``groups == in_channels``, no bias) followed
    by a pointwise 1x1 convolution that mixes channels. Arguments mirror
    ``nn.Conv2d``; ``bias`` applies only to the pointwise convolution.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        # Depthwise: one spatial filter per input channel; any bias is
        # deferred to the pointwise conv.
        # (fixed typo: was "dephtwise_conv")
        depthwise_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        # Pointwise: 1x1 conv mapping in_channels -> out_channels.
        pointwise_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=bias,
        )
        super().__init__(depthwise_conv, pointwise_conv)
0 commit comments