Commit 0a0dbd6: Initial commit (0 parents)

File tree: 14 files changed, +515 -0 lines

Diff for: .gitignore

+7
# dirs
.idea/*
__pycache__/

# files
*.pyc
*.pyx

Diff for: __init__.py

+1
from segmentation_models_pytorch import *

Diff for: segmentation_models_pytorch/__init__.py

+1
from .unet import Unet

Diff for: segmentation_models_pytorch/base/__init__.py

Whitespace-only changes.

Diff for: segmentation_models_pytorch/base/model.py

+14
import torch.nn as nn


class Model(nn.Module):

    def __init__(self):
        super().__init__()

    def initialize(self):
        # He initialization for convolutions; unit scale and zero shift for batchnorm
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
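
A usage sketch (the subclass here is hypothetical, not part of the commit): derived models build their layers, then call initialize() to apply this weight scheme.

import torch.nn as nn
from segmentation_models_pytorch.base.model import Model

class TinyHead(Model):  # hypothetical subclass for illustration
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.initialize()  # kaiming conv weights, unit/zero batchnorm params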

Diff for: segmentation_models_pytorch/common/__init__.py

Whitespace-only changes.

Diff for: segmentation_models_pytorch/common/blocks.py

+21
import torch.nn as nn


class Conv2dReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 stride=1, use_batchnorm=True, **batchnorm_params):

        super().__init__()
        # the conv bias is redundant when a batchnorm layer follows
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=not use_batchnorm)
        if use_batchnorm:
            self.batchnorm = nn.BatchNorm2d(out_channels, **batchnorm_params)

        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if hasattr(self, 'batchnorm'):
            x = self.batchnorm(x)
        x = self.activation(x)
        return x
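
A quick shape check for the block (illustrative; tensor sizes are arbitrary):

import torch
from segmentation_models_pytorch.common.blocks import Conv2dReLU

block = Conv2dReLU(3, 16, kernel_size=3, padding=1)  # conv -> batchnorm -> relu
y = block(torch.randn(2, 3, 64, 64))                 # NCHW input
print(y.shape)                                       # torch.Size([2, 16, 64, 64])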

Diff for: segmentation_models_pytorch/ecnoders/__init__.py

+21
import torch.utils.model_zoo as model_zoo

from .resnet import resnet_encoders
from .dpn import dpn_encoders
from .vgg import vgg_encoders

encoders = {}
encoders.update(resnet_encoders)
encoders.update(dpn_encoders)
encoders.update(vgg_encoders)


def get_encoder(name, pretrained=True):

    Encoder = encoders[name]['encoder']
    encoder = Encoder(**encoders[name]['params'])
    encoder.out_shapes = encoders[name]['out_shapes']

    if pretrained:
        encoder.load_state_dict(model_zoo.load_url(encoders[name]['url']))
    return encoder
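
Example usage, assuming the keys registered in the encoder files below (pretrained=False skips the weight download):

from segmentation_models_pytorch.ecnoders import get_encoder

encoder = get_encoder('resnet34', pretrained=False)
print(encoder.out_shapes)  # (512, 256, 128, 64, 64), deepest feature map first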

Diff for: segmentation_models_pytorch/ecnoders/dpn.py

+148
import numpy as np

import torch
import torch.nn.functional as F

from pretrainedmodels.models.dpn import DPN


class DPNEncorder(DPN):

    def __init__(self, feature_blocks, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.feature_blocks = np.cumsum(feature_blocks)

    def forward(self, x):

        features = []

        input_block = self.features[0]

        x = input_block.conv(x)
        x = input_block.bn(x)
        x = input_block.act(x)
        features.append(x)

        x = input_block.pool(x)

        for i, module in enumerate(self.features[1:], 1):
            x = module(x)
            # DPN blocks return a (residual, dense) pair; snapshot at stage boundaries
            if i in self.feature_blocks:
                features.append(x)

        # the intermediate pairs are concatenated into single feature maps, deepest first
        out_features = [
            features[4],
            F.relu(torch.cat(features[3], dim=1), inplace=True),
            F.relu(torch.cat(features[2], dim=1), inplace=True),
            F.relu(torch.cat(features[1], dim=1), inplace=True),
            features[0],
        ]

        return out_features


dpn_encoders = {
    'dpn68': {
        'encoder': DPNEncorder,
        'out_shapes': (832, 704, 320, 144, 10),
        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn68-4af7d88d2.pth',
        'params': {
            'feature_blocks': (3, 4, 12, 4),
            'groups': 32,
            'inc_sec': (16, 32, 32, 64),
            'k_r': 128,
            'k_sec': (3, 4, 12, 3),
            'num_classes': 1000,
            'num_init_features': 10,
            'small': True,
            'test_time_pool': True,
        },
    },

    'dpn68b': {
        'encoder': DPNEncorder,
        'out_shapes': (832, 704, 320, 144, 10),
        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-363ab9c19.pth',
        'params': {
            'feature_blocks': (3, 4, 12, 4),
            'b': True,
            'groups': 32,
            'inc_sec': (16, 32, 32, 64),
            'k_r': 128,
            'k_sec': (3, 4, 12, 3),
            'num_classes': 1000,
            'num_init_features': 10,
            'small': True,
            'test_time_pool': True,
        },
    },

    'dpn92': {
        'encoder': DPNEncorder,
        'out_shapes': (2688, 1552, 704, 336, 64),
        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-fda993c95.pth',
        'params': {
            'feature_blocks': (3, 4, 20, 4),
            'groups': 32,
            'inc_sec': (16, 32, 24, 128),
            'k_r': 96,
            'k_sec': (3, 4, 20, 3),
            'num_classes': 1000,
            'num_init_features': 64,
            'test_time_pool': True,
        },
    },

    'dpn98': {
        'encoder': DPNEncorder,
        'out_shapes': (2688, 1728, 768, 336, 96),
        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn98-722954780.pth',
        'params': {
            'feature_blocks': (3, 6, 20, 4),
            'groups': 40,
            'inc_sec': (16, 32, 32, 128),
            'k_r': 160,
            'k_sec': (3, 6, 20, 3),
            'num_classes': 1000,
            'num_init_features': 96,
            'test_time_pool': True,
        },
    },

    'dpn107': {
        'encoder': DPNEncorder,
        'out_shapes': (2688, 2432, 1152, 376, 128),
        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-b7f9f4cc9.pth',
        'params': {
            'feature_blocks': (4, 8, 20, 4),
            'groups': 50,
            'inc_sec': (20, 64, 64, 128),
            'k_r': 200,
            'k_sec': (4, 8, 20, 3),
            'num_classes': 1000,
            'num_init_features': 128,
            'test_time_pool': True,
        },
    },

    'dpn131': {
        'encoder': DPNEncorder,
        'out_shapes': (2688, 1984, 832, 352, 128),
        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn131-7af84be88.pth',
        'params': {
            'feature_blocks': (4, 8, 28, 4),
            'groups': 40,
            'inc_sec': (16, 32, 32, 128),
            'k_r': 160,
            'k_sec': (4, 8, 28, 3),
            'num_classes': 1000,
            'num_init_features': 128,
            'test_time_pool': True,
        },
    },
}
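
A forward-pass sketch (assumes the pretrainedmodels package is available; pretrained=False avoids downloading weights):

import torch
from segmentation_models_pytorch.ecnoders import get_encoder

encoder = get_encoder('dpn68', pretrained=False)
feats = encoder(torch.randn(1, 3, 224, 224))
print([f.shape[1] for f in feats])  # expected to match out_shapes: [832, 704, 320, 144, 10]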

Diff for: segmentation_models_pytorch/ecnoders/resnet.py

+74
from torchvision.models.resnet import ResNet
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import Bottleneck
from torchvision.models.resnet import model_urls


class ResNetEncoder(ResNet):

    def forward(self, x):
        x0 = self.conv1(x)
        x0 = self.bn1(x0)
        x0 = self.relu(x0)

        x1 = self.maxpool(x0)
        x1 = self.layer1(x1)

        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        return [x4, x3, x2, x1, x0]


resnet_encoders = {
    'resnet18': {
        'encoder': ResNetEncoder,
        'url': model_urls['resnet18'],
        'out_shapes': (512, 256, 128, 64, 64),
        'params': {
            'block': BasicBlock,
            'layers': [2, 2, 2, 2],
        },
    },

    'resnet34': {
        'encoder': ResNetEncoder,
        'url': model_urls['resnet34'],
        'out_shapes': (512, 256, 128, 64, 64),
        'params': {
            'block': BasicBlock,
            'layers': [3, 4, 6, 3],
        },
    },

    'resnet50': {
        'encoder': ResNetEncoder,
        'url': model_urls['resnet50'],
        'out_shapes': (2048, 1024, 512, 256, 64),
        'params': {
            'block': Bottleneck,
            'layers': [3, 4, 6, 3],
        },
    },

    'resnet101': {
        'encoder': ResNetEncoder,
        'url': model_urls['resnet101'],
        'out_shapes': (2048, 1024, 512, 256, 64),
        'params': {
            'block': Bottleneck,
            'layers': [3, 4, 23, 3],
        },
    },

    'resnet152': {
        'encoder': ResNetEncoder,
        'url': model_urls['resnet152'],
        'out_shapes': (2048, 1024, 512, 256, 64),
        'params': {
            'block': Bottleneck,
            'layers': [3, 8, 36, 3],
        },
    },
}
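
The overridden forward returns the five-level feature pyramid, deepest first; a minimal check (illustrative only):

import torch
from segmentation_models_pytorch.ecnoders import get_encoder

encoder = get_encoder('resnet18', pretrained=False)
feats = encoder(torch.randn(1, 3, 256, 256))
print([f.shape[1] for f in feats])  # [512, 256, 128, 64, 64], matching out_shapes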

Diff for: segmentation_models_pytorch/ecnoders/vgg.py

+98
import torch.nn as nn
from torchvision.models.vgg import VGG
from torchvision.models.vgg import make_layers
from torchvision.models.vgg import cfg
from torchvision.models.vgg import model_urls


class VGGEncoder(VGG):

    def forward(self, x):

        features = []

        for module in self.features:
            # snapshot the feature map just before each downsampling step
            if isinstance(module, nn.MaxPool2d):
                features.append(x)

            x = module(x)

        features.append(x)

        return features[::-1]  # deepest feature map first


vgg_encoders = {

    'vgg11': {
        'encoder': VGGEncoder,
        'out_shapes': (512, 512, 512, 256, 128),
        'url': model_urls['vgg11'],
        'params': {
            'features': make_layers(cfg['A'], batch_norm=False),
        },
    },

    'vgg11_bn': {
        'encoder': VGGEncoder,
        'out_shapes': (512, 512, 512, 256, 128),
        'url': model_urls['vgg11_bn'],
        'params': {
            'features': make_layers(cfg['A'], batch_norm=True),
        },
    },

    'vgg13': {
        'encoder': VGGEncoder,
        'out_shapes': (512, 512, 512, 256, 128),
        'url': model_urls['vgg13'],
        'params': {
            'features': make_layers(cfg['B'], batch_norm=False),
        },
    },

    'vgg13_bn': {
        'encoder': VGGEncoder,
        'out_shapes': (512, 512, 512, 256, 128),
        'url': model_urls['vgg13_bn'],
        'params': {
            'features': make_layers(cfg['B'], batch_norm=True),
        },
    },

    'vgg16': {
        'encoder': VGGEncoder,
        'out_shapes': (512, 512, 512, 256, 128),
        'url': model_urls['vgg16'],
        'params': {
            'features': make_layers(cfg['D'], batch_norm=False),
        },
    },

    'vgg16_bn': {
        'encoder': VGGEncoder,
        'out_shapes': (512, 512, 512, 256, 128),
        'url': model_urls['vgg16_bn'],
        'params': {
            'features': make_layers(cfg['D'], batch_norm=True),
        },
    },

    'vgg19': {
        'encoder': VGGEncoder,
        'out_shapes': (512, 512, 512, 256, 128),
        'url': model_urls['vgg19'],
        'params': {
            'features': make_layers(cfg['E'], batch_norm=False),
        },
    },

    'vgg19_bn': {
        'encoder': VGGEncoder,
        'out_shapes': (512, 512, 512, 256, 128),
        'url': model_urls['vgg19_bn'],
        'params': {
            'features': make_layers(cfg['E'], batch_norm=True),
        },
    },
}
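
The same check for a VGG encoder (a sketch; note the forward above collects one snapshot per MaxPool2d plus the final map, so it yields one more tensor than out_shapes lists):

import torch
from segmentation_models_pytorch.ecnoders import get_encoder

encoder = get_encoder('vgg11', pretrained=False)
feats = encoder(torch.randn(1, 3, 224, 224))
print([f.shape[1] for f in feats])  # deepest first; the leading entries match out_shapes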
