
Commit 5bd52c3

Initial commit


83 files changed: +11898 -0 lines changed

README.md

+257
@@ -0,0 +1,257 @@
# Segmentation Models Pytorch 3D

Python library with neural networks for volume (3D) segmentation based on PyTorch.

This library is based on the popular [Segmentation Models Pytorch](https://github.com/qubvel/segmentation_models.pytorch) library for 2D images. Most of its documentation applies directly to this package as well.

## Installation

* Option 1: `pip install segmentation_models_pytorch_3d`
* Option 2: copy the `segmentation_models_pytorch_3d` folder from this repository into your project folder.

## Quick start

A segmentation model is just a PyTorch `nn.Module`, which can be created as easily as:

```python
import segmentation_models_pytorch_3d as smp
import torch

model = smp.Unet(
    encoder_name="efficientnet-b0",  # choose encoder, e.g. resnet34
    in_channels=1,                   # model input channels (1 for gray-scale volumes, 3 for RGB, etc.)
    classes=3,                       # model output channels (number of classes in your dataset)
)

# Shape of input: (B, C, H, W, D). B - batch size, C - channels, H - height, W - width, D - depth
res = model(torch.randn(4, 1, 64, 64, 64))
```
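
For the default configuration above, the output keeps the spatial size of the input, so `res` has shape `(4, 3, 64, 64, 64)` - one channel per class.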

## Models

### Architectures

- Unet [[paper](https://arxiv.org/abs/1505.04597)] [[docs](https://smp.readthedocs.io/en/latest/models.html#unet)]
- Unet++ [[paper](https://arxiv.org/pdf/1807.10165.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id2)]
- MAnet [[paper](https://ieeexplore.ieee.org/abstract/document/9201310)] [[docs](https://smp.readthedocs.io/en/latest/models.html#manet)]
- Linknet [[paper](https://arxiv.org/abs/1707.03718)] [[docs](https://smp.readthedocs.io/en/latest/models.html#linknet)]
- FPN [[paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#fpn)]
- PSPNet [[paper](https://arxiv.org/abs/1612.01105)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pspnet)]
- PAN [[paper](https://arxiv.org/abs/1805.10180)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pan)]
- DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)]
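
All of these architectures share the constructor interface shown in the quick start; a minimal sketch (the parameter values here are illustrative only):

```python
import segmentation_models_pytorch_3d as smp

# Only the class name changes; encoder, channel and class arguments stay the same.
fpn = smp.FPN(encoder_name="resnet34", in_channels=1, classes=3)
manet = smp.MAnet(encoder_name="resnet34", in_channels=1, classes=3)
```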

### Encoders

The following is a list of encoders supported by SMP. Select the appropriate encoder family, click to expand the table, and choose a specific encoder and its pre-trained weights (the `encoder_name` and `encoder_weights` parameters).
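
For example, picking `resnet50` with `imagenet` weights from the ResNet table below (a minimal sketch; `encoder_weights=None` starts from random initialization):

```python
import segmentation_models_pytorch_3d as smp

model = smp.Unet(
    encoder_name="resnet50",     # encoder from the tables below
    encoder_weights="imagenet",  # pre-trained weights listed for that encoder
    in_channels=3,
    classes=1,
)
```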

<details>
<summary style="margin-left: 25px;">ResNet</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|resnet18 |imagenet / ssl / swsl |11M |
|resnet34 |imagenet |21M |
|resnet50 |imagenet / ssl / swsl |23M |
|resnet101 |imagenet |42M |
|resnet152 |imagenet |58M |

</div>
</details>

<details>
<summary style="margin-left: 25px;">ResNeXt</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|resnext50_32x4d |imagenet / ssl / swsl |22M |
|resnext101_32x4d |ssl / swsl |42M |
|resnext101_32x8d |imagenet / instagram / ssl / swsl|86M |
|resnext101_32x16d |instagram / ssl / swsl |191M |
|resnext101_32x32d |instagram |466M |
|resnext101_32x48d |instagram |826M |

</div>
</details>

<details>
<summary style="margin-left: 25px;">SE-Net</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|senet154 |imagenet |113M |
|se_resnet50 |imagenet |26M |
|se_resnet101 |imagenet |47M |
|se_resnet152 |imagenet |64M |
|se_resnext50_32x4d |imagenet |25M |
|se_resnext101_32x4d |imagenet |46M |

</div>
</details>

<details>
<summary style="margin-left: 25px;">DenseNet</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|densenet121 |imagenet |6M |
|densenet169 |imagenet |12M |
|densenet201 |imagenet |18M |
|densenet161 |imagenet |26M |

</div>
</details>

<details>
<summary style="margin-left: 25px;">EfficientNet</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|efficientnet-b0 |imagenet |4M |
|efficientnet-b1 |imagenet |6M |
|efficientnet-b2 |imagenet |7M |
|efficientnet-b3 |imagenet |10M |
|efficientnet-b4 |imagenet |17M |
|efficientnet-b5 |imagenet |28M |
|efficientnet-b6 |imagenet |40M |
|efficientnet-b7 |imagenet |63M |

</div>
</details>

<details>
<summary style="margin-left: 25px;">DPN</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|dpn68 |imagenet |11M |
|dpn68b |imagenet+5k |11M |
|dpn92 |imagenet+5k |34M |
|dpn98 |imagenet |58M |
|dpn107 |imagenet+5k |84M |
|dpn131 |imagenet |76M |

</div>
</details>

<details>
<summary style="margin-left: 25px;">VGG</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|vgg11 |imagenet |9M |
|vgg11_bn |imagenet |9M |
|vgg13 |imagenet |9M |
|vgg13_bn |imagenet |9M |
|vgg16 |imagenet |14M |
|vgg16_bn |imagenet |14M |
|vgg19 |imagenet |20M |
|vgg19_bn |imagenet |20M |

</div>
</details>

<details>
<summary style="margin-left: 25px;">Mix Vision Transformer</summary>
<div style="margin-left: 25px;">

The backbone from SegFormer, pretrained on ImageNet. It can be used with other decoders from the package: you can combine Mix Vision Transformer with Unet, FPN and others.

Limitations:

- the encoder is **not** supported by Linknet and Unet++
- the encoder is supported by FPN only for encoder **depth = 5**

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|mit_b0 |imagenet |3M |
|mit_b1 |imagenet |13M |
|mit_b2 |imagenet |24M |
|mit_b3 |imagenet |44M |
|mit_b4 |imagenet |60M |
|mit_b5 |imagenet |81M |

</div>
</details>

<details>
<summary style="margin-left: 25px;">MobileOne</summary>
<div style="margin-left: 25px;">

Apple's "sub-one-ms" backbone, pretrained on ImageNet. It can be used with all decoders.

Note: in the official GitHub repo, the s0 variant has additional `num_conv_branches`, leading to more parameters than s1.

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|mobileone_s0 |imagenet |4.6M |
|mobileone_s1 |imagenet |4.0M |
|mobileone_s2 |imagenet |6.5M |
|mobileone_s3 |imagenet |8.8M |
|mobileone_s4 |imagenet |13.6M |

</div>
</details>

## Notes for the 3D version

### Input size

The recommended input size for backbones can be calculated as `K = pow(N, 2/3)`, where N is the input image size for the same model in its 2D variant (so a K×K×K volume contains roughly as many voxels as an N×N image contains pixels).

For example, for N = 224, K ≈ 32; for N = 512, K = 64.
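
A small sketch of this rule of thumb; the rounding down to a power of two is an assumption used here to reproduce the quoted values of 32 and 64:

```python
# K = N ** (2/3): a K x K x K volume holds roughly as many voxels as an N x N image holds pixels.
for n in (224, 512):
    k = n ** (2 / 3)
    k_pow2 = 2 ** (round(k).bit_length() - 1)  # largest power of two <= round(k) (assumed rounding)
    print(n, round(k, 1), k_pow2)              # 224 -> 36.9 -> 32, 512 -> 64.0 -> 64
```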

### Strides

The typical stride in the 2D case is 2 for both H and W. It is applied `depth` times (in almost all cases 5 times), so the input image is reduced from (224, 224) to (7, 7) at the final layers. In the 3D case, because the input is very large, it is sometimes useful to control the stride of every dimension independently. For this you can use the `strides` argument, whose default value is `strides=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2))`. Example:

Let's say you have input data of size (224, 128, 12). You can use strides such as ((2, 2, 2), (4, 2, 1), (2, 2, 2), (2, 2, 1), (1, 2, 3)); the shape at the deepest encoder level for these strides will be (7, 4, 1).

```python
import segmentation_models_pytorch_3d as smp
import torch

model = smp.Unet(
    encoder_name="resnet50",
    in_channels=1,
    strides=((2, 2, 2), (4, 2, 1), (2, 2, 2), (2, 2, 1), (1, 2, 3)),
    classes=3,
)

res = model(torch.randn(4, 1, 224, 128, 12))
```
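
To see where (7, 4, 1) comes from, multiply the stride factors of each dimension across the five stages (a quick sanity check, not part of the library):

```python
import math

strides = ((2, 2, 2), (4, 2, 1), (2, 2, 2), (2, 2, 1), (1, 2, 3))
input_shape = (224, 128, 12)

# Total downsampling per dimension is the product of that dimension's stride factors: (32, 32, 12)
total = [math.prod(stage[d] for stage in strides) for d in range(3)]
print([size // factor for size, factor in zip(input_shape, total)])  # [7, 4, 1]
```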

**Note**: custom strides are currently supported only by `resnet`-family and `densenet` encoders with the `Unet` decoder.

### Related repositories

* [https://github.com/qubvel/segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch) - the original 2D segmentation repo
* [segmentation_models_3D](https://github.com/ZFTurbo/segmentation_models_3D) - segmentation models in 3D for Keras/TensorFlow
* [volumentations](https://github.com/ZFTurbo/volumentations) - 3D augmentations

## Citation

If you find this code useful, please cite it as:

```
@article{solovyev20223d,
  title={3D convolutional neural networks for stalled brain capillary detection},
  author={Solovyev, Roman and Kalinin, Alexandr A and Gabruseva, Tatiana},
  journal={Computers in Biology and Medicine},
  volume={141},
  pages={105089},
  year={2022},
  publisher={Elsevier},
  doi={10.1016/j.compbiomed.2021.105089}
}
```

## To Do List

* Support for strides for all encoders
* Add `timm` models

requirements.txt

+8
@@ -0,0 +1,8 @@
torch
torchvision>=0.5.0
pretrainedmodels==0.7.4
efficientnet-pytorch==0.7.1
timm==0.9.7
tqdm
pillow
six
+62
@@ -0,0 +1,62 @@
from . import datasets
from . import encoders
from . import decoders
from . import losses
from . import metrics

from .decoders.unet import Unet
from .decoders.unetplusplus import UnetPlusPlus
from .decoders.manet import MAnet
from .decoders.linknet import Linknet
from .decoders.fpn import FPN
from .decoders.pspnet import PSPNet
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
from .decoders.pan import PAN

from .__version__ import __version__

# some private imports for create_model function
from typing import Optional as _Optional
import torch as _torch


def create_model(
    arch: str,
    encoder_name: str = "resnet34",
    encoder_weights: _Optional[str] = "imagenet",
    in_channels: int = 3,
    classes: int = 1,
    **kwargs,
) -> _torch.nn.Module:
    """Models entrypoint: create any supported architecture from its name and
    parameters, without importing its class directly.
    """

    archs = [
        Unet,
        UnetPlusPlus,
        MAnet,
        Linknet,
        FPN,
        PSPNet,
        DeepLabV3,
        DeepLabV3Plus,
        PAN,
    ]
    archs_dict = {a.__name__.lower(): a for a in archs}
    try:
        model_class = archs_dict[arch.lower()]
    except KeyError:
        raise KeyError(
            "Wrong architecture type `{}`. Available options are: {}".format(
                arch,
                list(archs_dict.keys()),
            )
        )
    return model_class(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        **kwargs,
    )
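
A minimal usage sketch for `create_model` (the argument values are illustrative):

```python
import segmentation_models_pytorch_3d as smp

# The architecture is selected by its class name, case-insensitively
# (e.g. "unet", "unetplusplus", "fpn", "deeplabv3plus").
model = smp.create_model("unetplusplus", encoder_name="resnet34", in_channels=1, classes=2)
```
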
@@ -0,0 +1,3 @@
VERSION = (0, 3, 3)

__version__ = ".".join(map(str, VERSION))
@@ -0,0 +1,11 @@
from .model import SegmentationModel

from .modules import (
    Conv3dReLU,
    Attention,
)

from .heads import (
    SegmentationHead,
    ClassificationHead,
)
@@ -0,0 +1,22 @@
import torch.nn as nn
from .modules import Activation


class SegmentationHead(nn.Sequential):
    # 3D convolution -> optional trilinear upsampling -> optional activation
    def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
        conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.Upsample(scale_factor=upsampling, mode='trilinear') if upsampling > 1 else nn.Identity()
        activation = Activation(activation)
        super().__init__(conv3d, upsampling, activation)


class ClassificationHead(nn.Sequential):
    # global 3D pooling -> flatten -> dropout -> linear -> optional activation
    def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
        if pooling not in ("max", "avg"):
            raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))
        pool = nn.AdaptiveAvgPool3d(1) if pooling == "avg" else nn.AdaptiveMaxPool3d(1)
        flatten = nn.Flatten()
        dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
        linear = nn.Linear(in_channels, classes, bias=True)
        activation = Activation(activation)
        super().__init__(pool, flatten, dropout, linear, activation)
@@ -0,0 +1,27 @@
import torch.nn as nn


def initialize_decoder(module):
    # The 3D decoders build Conv3d / BatchNorm3d layers, so the 3D types are
    # checked alongside the original 2D ones.
    for m in module.modules():

        if isinstance(m, (nn.Conv2d, nn.Conv3d)):
            nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


def initialize_head(module):
    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
