Skip to content

Commit f929a18

Browse files
authored
Add new pretrained models (#257)
* add pretrained models
1 parent 7a9d4c5 commit f929a18

File tree

2 files changed

+78
-69
lines changed

2 files changed

+78
-69
lines changed

README.md

+8-5
Original file line number | Diff line number | Diff line change
@@ -74,14 +74,15 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
7474

7575
|Encoder |Weights |Params, M |
7676
|--------------------------------|:------------------------------:|:------------------------------:|
77-
|resnet18 |imagenet |11M |
77+
|resnet18 |imagenet<br>ssl*<br>swsl* |11M |
7878
|resnet34 |imagenet |21M |
79-
|resnet50 |imagenet |23M |
79+
|resnet50 |imagenet<br>ssl*<br>swsl* |23M |
8080
|resnet101 |imagenet |42M |
8181
|resnet152 |imagenet |58M |
82-
|resnext50_32x4d |imagenet |22M |
83-
|resnext101_32x8d |imagenet<br>instagram |86M |
84-
|resnext101_32x16d |instagram |191M |
82+
|resnext50_32x4d |imagenet<br>ssl*<br>swsl* |22M |
83+
|resnext101_32x4d |ssl*<br>swsl* |42M |
84+
|resnext101_32x8d |imagenet<br>instagram<br>ssl*<br>swsl*|86M |
85+
|resnext101_32x16d |instagram<br>ssl*<br>swsl* |191M |
8586
|resnext101_32x32d |instagram |466M |
8687
|resnext101_32x48d |instagram |826M |
8788
|dpn68 |imagenet |11M |
@@ -131,6 +132,8 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
131132
|timm-efficientnet-b8 |imagenet<br>advprop |84M |
132133
|timm-efficientnet-l2 |noisy-student |474M |
133134

135+
\* `ssl`, `swsl` - semi-supervised and semi-weakly supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).
136+
134137
### Models API <a name="api"></a>
135138

136139
- `model.encoder` - pretrained backbone to extract features of different spatial resolution

segmentation_models_pytorch/encoders/resnet.py

+70-64
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
2323
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
2424
"""
25+
from copy import deepcopy
2526

2627
import torch.nn as nn
2728

@@ -69,6 +70,59 @@ def load_state_dict(self, state_dict, **kwargs):
6970
super().load_state_dict(state_dict, **kwargs)
7071

7172

73+
new_settings = {
74+
"resnet18": {
75+
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth",
76+
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth"
77+
},
78+
"resnet50": {
79+
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth",
80+
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth"
81+
},
82+
"resnext50_32x4d": {
83+
"imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
84+
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth",
85+
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth",
86+
},
87+
"resnext101_32x4d": {
88+
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth",
89+
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth"
90+
},
91+
"resnext101_32x8d": {
92+
"imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
93+
"instagram": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth",
94+
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth",
95+
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth",
96+
},
97+
"resnext101_32x16d": {
98+
"instagram": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth",
99+
"ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth",
100+
"swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth",
101+
},
102+
"resnext101_32x32d": {
103+
"instagram": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth",
104+
},
105+
"resnext101_32x48d": {
106+
"instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth",
107+
}
108+
}
109+
110+
pretrained_settings = deepcopy(pretrained_settings)
111+
for model_name, sources in new_settings.items():
112+
if model_name not in pretrained_settings:
113+
pretrained_settings[model_name] = {}
114+
115+
for source_name, source_url in sources.items():
116+
pretrained_settings[model_name][source_name] = {
117+
"url": source_url,
118+
'input_size': [3, 224, 224],
119+
'input_range': [0, 1],
120+
'mean': [0.485, 0.456, 0.406],
121+
'std': [0.229, 0.224, 0.225],
122+
'num_classes': 1000
123+
}
124+
125+
72126
resnet_encoders = {
73127
"resnet18": {
74128
"encoder": ResNetEncoder,
@@ -117,17 +171,7 @@ def load_state_dict(self, state_dict, **kwargs):
117171
},
118172
"resnext50_32x4d": {
119173
"encoder": ResNetEncoder,
120-
"pretrained_settings": {
121-
"imagenet": {
122-
"url": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
123-
"input_space": "RGB",
124-
"input_size": [3, 224, 224],
125-
"input_range": [0, 1],
126-
"mean": [0.485, 0.456, 0.406],
127-
"std": [0.229, 0.224, 0.225],
128-
"num_classes": 1000,
129-
}
130-
},
174+
"pretrained_settings": pretrained_settings["resnext50_32x4d"],
131175
"params": {
132176
"out_channels": (3, 64, 256, 512, 1024, 2048),
133177
"block": Bottleneck,
@@ -136,28 +180,20 @@ def load_state_dict(self, state_dict, **kwargs):
136180
"width_per_group": 4,
137181
},
138182
},
139-
"resnext101_32x8d": {
183+
"resnext101_32x4d": {
140184
"encoder": ResNetEncoder,
141-
"pretrained_settings": {
142-
"imagenet": {
143-
"url": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
144-
"input_space": "RGB",
145-
"input_size": [3, 224, 224],
146-
"input_range": [0, 1],
147-
"mean": [0.485, 0.456, 0.406],
148-
"std": [0.229, 0.224, 0.225],
149-
"num_classes": 1000,
150-
},
151-
"instagram": {
152-
"url": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth",
153-
"input_space": "RGB",
154-
"input_size": [3, 224, 224],
155-
"input_range": [0, 1],
156-
"mean": [0.485, 0.456, 0.406],
157-
"std": [0.229, 0.224, 0.225],
158-
"num_classes": 1000,
159-
},
185+
"pretrained_settings": pretrained_settings["resnext101_32x4d"],
186+
"params": {
187+
"out_channels": (3, 64, 256, 512, 1024, 2048),
188+
"block": Bottleneck,
189+
"layers": [3, 4, 23, 3],
190+
"groups": 32,
191+
"width_per_group": 4,
160192
},
193+
},
194+
"resnext101_32x8d": {
195+
"encoder": ResNetEncoder,
196+
"pretrained_settings": pretrained_settings["resnext101_32x8d"],
161197
"params": {
162198
"out_channels": (3, 64, 256, 512, 1024, 2048),
163199
"block": Bottleneck,
@@ -168,17 +204,7 @@ def load_state_dict(self, state_dict, **kwargs):
168204
},
169205
"resnext101_32x16d": {
170206
"encoder": ResNetEncoder,
171-
"pretrained_settings": {
172-
"instagram": {
173-
"url": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth",
174-
"input_space": "RGB",
175-
"input_size": [3, 224, 224],
176-
"input_range": [0, 1],
177-
"mean": [0.485, 0.456, 0.406],
178-
"std": [0.229, 0.224, 0.225],
179-
"num_classes": 1000,
180-
}
181-
},
207+
"pretrained_settings": pretrained_settings["resnext101_32x16d"],
182208
"params": {
183209
"out_channels": (3, 64, 256, 512, 1024, 2048),
184210
"block": Bottleneck,
@@ -189,17 +215,7 @@ def load_state_dict(self, state_dict, **kwargs):
189215
},
190216
"resnext101_32x32d": {
191217
"encoder": ResNetEncoder,
192-
"pretrained_settings": {
193-
"instagram": {
194-
"url": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth",
195-
"input_space": "RGB",
196-
"input_size": [3, 224, 224],
197-
"input_range": [0, 1],
198-
"mean": [0.485, 0.456, 0.406],
199-
"std": [0.229, 0.224, 0.225],
200-
"num_classes": 1000,
201-
}
202-
},
218+
"pretrained_settings": pretrained_settings["resnext101_32x32d"],
203219
"params": {
204220
"out_channels": (3, 64, 256, 512, 1024, 2048),
205221
"block": Bottleneck,
@@ -210,17 +226,7 @@ def load_state_dict(self, state_dict, **kwargs):
210226
},
211227
"resnext101_32x48d": {
212228
"encoder": ResNetEncoder,
213-
"pretrained_settings": {
214-
"instagram": {
215-
"url": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth",
216-
"input_space": "RGB",
217-
"input_size": [3, 224, 224],
218-
"input_range": [0, 1],
219-
"mean": [0.485, 0.456, 0.406],
220-
"std": [0.229, 0.224, 0.225],
221-
"num_classes": 1000,
222-
}
223-
},
229+
"pretrained_settings": pretrained_settings["resnext101_32x48d"],
224230
"params": {
225231
"out_channels": (3, 64, 256, 512, 1024, 2048),
226232
"block": Bottleneck,

0 commit comments

Comments
 (0)