
Commit 9e6533b

Add parse and remove large files
1 parent e8c63e4

16 files changed: +574 −91 lines

Diff for: .gitignore (+1)

@@ -1,4 +1,5 @@
 # Custom
+*.ckpt
 *.nii.gz
 *.csv

Diff for: CHANGELOG (+5)

@@ -1,3 +1,8 @@
+v2.5
+* Added Parse vessel segmentation output
+* Deprecated some old code
+* Moved large files to a separate release
+
 v2.4
 
 * Better input/output flow, option to use less memory by not showing activations

Diff for: README.md (+10 −2)

@@ -1,5 +1,5 @@
 # Modified EfficientDet Segmentation (MEDSeg)
-Official repository for reproducing lung, COVID-19 and airway automated segmentation using our MEDSeg model.
+Official repository for reproducing lung, COVID-19, airway and pulmonary artery automated segmentation using our MEDSeg model.
 
 The original publication for this method, **Multitasking segmentation of lung and COVID-19 findings in CT scans using modified EfficientDet, UNet and MobileNetV3 models**, was published at the 17th International Symposium on Medical Information Processing and Analysis (SIPAIM 2021) and won the "SIPAIM Society Award".
 http://dx.doi.org/10.1117/12.2606118
@@ -10,8 +10,10 @@ https://www.youtube.com/watch?v=PlhNUD0Y4hg
 We have also applied this model in the ATM22 Challenge (https://atm22.grand-challenge.org/). Airway segmentation is included, with a CLI argument (--atm_mode) to segment only the airway, using less memory. A short paper about this is published on arXiv, **Open-source tool for Airway Segmentation in Computed Tomography using 2.5D Modified EfficientDet: Contribution to the ATM22 Challenge**: https://arxiv.org/pdf/2209.15094.pdf
 
+We have also trained this model for the PARSE Challenge (https://parse2022.grand-challenge.org/) on pulmonary artery segmentation. Pulmonary artery labels are included in the outputs. The model achieved around 0.7 Dice in testing. A paper detailing this application will be published in the future.
+
 ## Citation
-* **COVID-19 segmentation**: Carmo, Diedre, et al. "Multitasking segmentation of lung and COVID-19 findings in CT scans using modified EfficientDet, UNet and MobileNetV3 models." 17th International Symposium on Medical Information Processing and Analysis. Vol. 12088. SPIE, 2021.
+* **COVID-19 segmentation and method in general**: Carmo, Diedre, et al. "Multitasking segmentation of lung and COVID-19 findings in CT scans using modified EfficientDet, UNet and MobileNetV3 models." 17th International Symposium on Medical Information Processing and Analysis. Vol. 12088. SPIE, 2021.
 
 * @inproceedings{carmo2021multitasking,\
     title={Multitasking segmentation of lung and COVID-19 findings in CT scans using modified EfficientDet, UNet and MobileNetV3 models},\
@@ -68,6 +70,8 @@ All additional required libraries and the tool itself will be installed with the
 
 If you use virtual environments, it is safer to install in a new virtual environment to avoid conflicts.
 
+Finally, due to the large size of the network weights, you need to go to the Release in this repository, download the data.zip file, and extract the .ckpt files into the medseg folder. The .ckpt files should be at the same directory level as the run.py file.
+
 ## Running
 
 To run, just call it in a terminal.
@@ -91,3 +95,7 @@ If you have any problems, make sure your pip is the same from your miniconda ins
 by checking if pip --version points to the miniconda directory.
 
 If you have any issues, feel free to create an issue on this repository.
+
+### Known Issue
+
+"Long prediction" mode is not working due to recent changes in the architecture. However, not using it should be enough for most cases; Long Prediction only uses more models in the final ensemble.

Diff for: medseg/__init__.py (+1 −1)

@@ -1 +1 @@
-__version__ = "2.4.1"
+__version__ = "2.5.0"

Diff for: medseg/airway.ckpt (−72.4 MB)

Binary file not shown.

Diff for: medseg/architecture.py (+78 −5)

@@ -1,25 +1,62 @@
+'''
+If you use please cite:
+CARMO, Diedre et al. Multitasking segmentation of lung and COVID-19 findings in CT scans using modified EfficientDet, UNet and MobileNetV3 models. In: 17th International Symposium on Medical Information Processing and Analysis. SPIE, 2021. p. 65-74.
+'''
 import torch
 from torch import nn
-
+from efficientnet_pytorch.utils import round_filters
 from medseg.edet.modeling_efficientdet import EfficientDetForSemanticSegmentation
 
 
 class MEDSeg(nn.Module):
-    def __init__(self, nin=3, nout=3, apply_sigmoid=False, dropout=None, backbone="effnet", pretrained=True, expand_bifpn="conv"):
+    def __init__(self, nin=3, nout=3, apply_sigmoid=False, dropout=None, backbone="effnet", pretrained=True, expand_bifpn="upsample", imnet_norm=False,
+                 num_classes_atm=None,
+                 num_classes_rec=None,
+                 num_classes_vessel=None,
+                 stem_replacement=False,
+                 new_latent_space=False,
+                 compound_coef=4):  # compound_coef has always been 4 by default
         super().__init__()
+        print("WARNING: default expand_bifpn changed to upsample!")
         self.model = EfficientDetForSemanticSegmentation(num_classes=nout,
                                                          load_weights=pretrained,
                                                          apply_sigmoid=apply_sigmoid,
                                                          expand_bifpn=expand_bifpn,
                                                          dropout=dropout,
-                                                         backbone=backbone)
+                                                         backbone=backbone,
+                                                         compound_coef=compound_coef,
+                                                         num_classes_atm=num_classes_atm,
+                                                         num_classes_rec=num_classes_rec,
+                                                         num_classes_vessel=num_classes_vessel,
+                                                         new_latent_space=new_latent_space)
+
+        self.feature_adapters = self.model.feature_adapters
+
+        if imnet_norm:
+            print("Performing imnet normalization internally, assuming inputs between 0 and 1")
+            self.imnet_norm = ImNetNorm()
+        else:
+            self.imnet_norm = nn.Identity()
 
         self.nin = nin
         if self.nin not in [1, 3]:
             self.in_conv = nn.Conv2d(in_channels=self.nin, out_channels=3, kernel_size=1, stride=1, padding=0, bias=False)
 
-        print(f"MEDSeg initialized. nin: {nin}, nout: {nout}, apply_sigmoid: {apply_sigmoid}, dropout: {dropout}, backbone: {backbone}, pretrained: {pretrained}, expand_bifpn: {expand_bifpn}, align DISABLED")
+        if stem_replacement:
+            assert backbone == "effnet", "Stem replacement only valid for efficientnet"
+            print("Performing stem replacement on EfficientNet backbone (this runs after initialization)")
+            self.model.backbone_net.model._conv_stem = EffNet3DStemReplacement(self.model.backbone_net.model)
 
+        print(f"MEDSeg initialized. nin: {nin}, nout: {nout}, apply_sigmoid: {apply_sigmoid}, dropout: {dropout}, "
+              f"backbone: {backbone}, pretrained: {pretrained}, expand_bifpn: {expand_bifpn}, pad align DISABLED, stem_replacement {stem_replacement}, "
+              f"new latent space extraction {new_latent_space}")
+
+    def extract_backbone_features(self, inputs):
+        return self.model.extract_backbone_features(inputs)
+
+    def extract_bifpn_features(self, features):
+        return self.model.extract_bifpn_features(features)
+
     def forward(self, x):
         if self.nin == 1:
             x_in = torch.zeros(size=(x.shape[0], 3) + x.shape[2:], device=x.device, dtype=x.dtype)
@@ -32,5 +69,41 @@ def forward(self, x):
         else:
             x = self.in_conv(x)
 
+        x = self.imnet_norm(x)
+
         return self.model(x)
-
+
+
+class EffNet3DStemReplacement(nn.Module):
+    def __init__(self, effnet_pytorch_instance):
+        super().__init__()
+        out_channels = round_filters(32, effnet_pytorch_instance._global_params)
+        self.conv = nn.Conv3d(1, out_channels, kernel_size=3, stride=1, padding="valid", bias=False)
+        self.pad = nn.ZeroPad2d(1)
+        self.conv_pool = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2, padding=0, bias=False)
+
+    def forward(self, x):
+        '''
+        x is a 4D batch but will be treated as 5D
+        '''
+        x = self.conv(x.unsqueeze(1)).squeeze(2)  # [B, 3, X, Y] -> [B, 1, 3, X, Y]
+        # -> [B, OUT_CH, 1, X, Y] -> [B, OUT_CH, X, Y]
+        x = self.pad(x)
+        x = self.conv_pool(x)
+        return x
+
+
+class ImNetNorm():
+    '''
+    Assumes input between 0 and 1
+    '''
+    def __init__(self):
+        self.mean = [0.485, 0.456, 0.406]
+        self.std = [0.229, 0.224, 0.225]
+
+    def __call__(self, xim):
+        with torch.no_grad():
+            for i in range(3):
+                xim[:, i] = (xim[:, i] - self.mean[i]) / self.std[i]
+
+        return xim
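
For reference, a minimal usage sketch of the new constructor arguments follows. It is not from the repository: the vessel head size and input resolution are illustrative, pretrained=False merely avoids downloading backbone weights, and the structure of the forward output depends on EfficientDetForSemanticSegmentation.

import torch
from medseg.architecture import MEDSeg

# Illustrative configuration; num_classes_vessel=2 is an assumed head size
# for the new PARSE pulmonary artery output, not a value from this commit.
model = MEDSeg(nin=3, nout=3,
               expand_bifpn="upsample",  # new default in this commit
               imnet_norm=True,          # normalize inputs internally with ImNetNorm
               num_classes_vessel=2,
               compound_coef=4,
               pretrained=False).eval()

with torch.no_grad():
    out = model(torch.rand(1, 3, 128, 128))  # inputs in [0, 1], as ImNetNorm assumes

With stem_replacement=True, the EfficientNet stem is instead swapped for EffNet3DStemReplacement, which lifts the 3-channel 2D batch to a single-channel 3D volume ([B, 3, X, Y] -> [B, 1, 3, X, Y]) and collapses it back to 2D after the 3D convolution.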

Diff for: medseg/best_coedet.ckpt (−73.1 MB)

Binary file not shown.

Diff for: medseg/convnext.py (+211, new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model


class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x


class ConvNeXt(nn.Module):
    r""" ConvNeXt
    A PyTorch impl of : `A ConvNet for the 2020s` -
    https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.,
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward_seg_features(self, x, convnext_expansion_scale, range_limit=3):
        outs = []
        for i in range(range_limit):
            x = self.downsample_layers[i](x)
            if convnext_expansion_scale <= 0:
                outs.append(self.stages[i](x))
            else:
                outs.append(F.upsample_bilinear(self.stages[i](x), scale_factor=convnext_expansion_scale))
        return outs

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


model_urls = {
    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
    "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
    "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}


@register_model
def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_small(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_base(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    if pretrained:
        url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_large(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if pretrained:
        url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    if pretrained:
        assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
        url = model_urls['convnext_xlarge_22k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
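
For context, a minimal sketch of exercising the new backbone's multiscale feature hook, forward_seg_features. It assumes timm is installed (the module registers its models with it); pretrained=False keeps the sketch offline with random weights.

import torch
from medseg.convnext import convnext_tiny

backbone = convnext_tiny(pretrained=False).eval()

x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    # convnext_expansion_scale <= 0 skips the bilinear upsampling branch.
    feats = backbone.forward_seg_features(x, convnext_expansion_scale=0)
print([tuple(f.shape) for f in feats])
# Tiny config stage shapes: (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14)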
