Skip to content

Commit a242d53

Browse files
authored
Feature Xception encoder (#108)
* Add Xception encoder (#102)
1 parent ff7e02f commit a242d53

File tree

3 files changed

+109
-1
lines changed

3 files changed

+109
-1
lines changed

Diff for: README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The main features of this library are:
77

88
- High level API (just two lines to create neural network)
99
- 4 models architectures for binary and multi class segmentation (including legendary Unet)
10-
- 45 available encoders for each architecture
10+
- 46 available encoders for each architecture
1111
- All encoders have pre-trained weights for faster and better convergence
1212

1313
### Table of content
@@ -111,6 +111,7 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
111111
|efficientnet-b6 |imagenet |40M |
112112
|efficientnet-b7 |imagenet |63M |
113113
|mobilenet_v2 |imagenet |2M |
114+
|xception |imagenet |22M |
114115

115116
### Models API <a name="api"></a>
116117

Diff for: segmentation_models_pytorch/encoders/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .inceptionv4 import inceptionv4_encoders
1111
from .efficientnet import efficient_net_encoders
1212
from .mobilenet import mobilenet_encoders
13+
from .xception import xception_encoders
1314

1415

1516
from ._preprocessing import preprocess_input
@@ -24,6 +25,7 @@
2425
encoders.update(inceptionv4_encoders)
2526
encoders.update(efficient_net_encoders)
2627
encoders.update(mobilenet_encoders)
28+
encoders.update(xception_encoders)
2729

2830

2931
def get_encoder(name, in_channels=3, depth=5, weights=None):

Diff for: segmentation_models_pytorch/encoders/xception.py

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import re
2+
import torch.nn as nn
3+
4+
from pretrainedmodels.models.xception import pretrained_settings
5+
from pretrainedmodels.models.xception import Xception
6+
7+
from ._base import EncoderMixin
8+
9+
class XceptionEncoder(Xception, EncoderMixin):
10+
11+
def __init__(self, out_channels, *args, depth=5, **kwargs):
12+
super().__init__(*args, **kwargs)
13+
14+
self._out_channels = out_channels
15+
self._depth = depth
16+
self._in_channels = 3
17+
18+
#modify padding to maintain output shape
19+
self.conv1.padding = 1
20+
self.conv2.padding = 1
21+
22+
del self.fc
23+
24+
@staticmethod
25+
def _transition(x, transition_block):
26+
for module in transition_block:
27+
x = module(x)
28+
if isinstance(module, nn.ReLU):
29+
skip = x
30+
return x, skip
31+
32+
def forward(self, x):
33+
features = [x]
34+
35+
if self._depth > 0:
36+
x = self.conv1(x)
37+
x = self.bn1(x)
38+
x = self.relu(x)
39+
40+
x = self.conv2(x)
41+
x = self.bn2(x)
42+
x0 = self.relu(x)
43+
features.append(x0)
44+
45+
if self._depth > 1:
46+
x1 = self.block1(x0)
47+
features.append(x1)
48+
49+
if self._depth > 2:
50+
x2 = self.block2(x1)
51+
features.append(x2)
52+
53+
if self._depth > 3:
54+
x = self.block3(x2)
55+
x = self.block4(x)
56+
x = self.block5(x)
57+
x = self.block6(x)
58+
x = self.block7(x)
59+
x = self.block8(x)
60+
x = self.block9(x)
61+
x = self.block10(x)
62+
x3 = self.block11(x)
63+
features.append(x3)
64+
65+
if self._depth > 4:
66+
x = self.block12(x)
67+
68+
x = self.conv3(x)
69+
x = self.bn3(x)
70+
x = self.relu(x)
71+
72+
x = self.conv4(x)
73+
x4 = self.bn4(x)
74+
features.append(x4)
75+
76+
return features
77+
78+
79+
def load_state_dict(self, state_dict):
80+
pattern = re.compile(
81+
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
82+
for key in list(state_dict.keys()):
83+
res = pattern.match(key)
84+
if res:
85+
new_key = res.group(1) + res.group(2)
86+
state_dict[new_key] = state_dict[key]
87+
del state_dict[key]
88+
89+
# remove linear
90+
state_dict.pop('fc.bias')
91+
state_dict.pop('fc.weight')
92+
93+
super().load_state_dict(state_dict)
94+
95+
96+
xception_encoders = {
97+
'xception': {
98+
'encoder': XceptionEncoder,
99+
'pretrained_settings': pretrained_settings['xception'],
100+
'params': {
101+
'out_channels': (3, 64, 128, 256, 728, 2048),
102+
}
103+
},
104+
}
105+

0 commit comments

Comments
 (0)