number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
"""
+ from copy import deepcopy

import torch.nn as nn

@@ -69,6 +70,59 @@ def load_state_dict(self, state_dict, **kwargs):
        super().load_state_dict(state_dict, **kwargs)


+ new_settings = {
+     "resnet18": {
+         "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth",
+         "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth"
+     },
+     "resnet50": {
+         "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth",
+         "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth"
+     },
+     "resnext50_32x4d": {
+         "imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
+         "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth",
+         "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth",
+     },
+     "resnext101_32x4d": {
+         "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth",
+         "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth"
+     },
+     "resnext101_32x8d": {
+         "imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
+         "instagram": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth",
+         "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth",
+         "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth",
+     },
+     "resnext101_32x16d": {
+         "instagram": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth",
+         "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth",
+         "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth",
+     },
+     "resnext101_32x32d": {
+         "instagram": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth",
+     },
+     "resnext101_32x48d": {
+         "instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth",
+     }
+ }
+
+ pretrained_settings = deepcopy(pretrained_settings)
+ for model_name, sources in new_settings.items():
+     if model_name not in pretrained_settings:
+         pretrained_settings[model_name] = {}
+
+     for source_name, source_url in sources.items():
+         pretrained_settings[model_name][source_name] = {
+             "url": source_url,
+             'input_size': [3, 224, 224],
+             'input_range': [0, 1],
+             'mean': [0.485, 0.456, 0.406],
+             'std': [0.229, 0.224, 0.225],
+             'num_classes': 1000
+         }
+
+
resnet_encoders = {
    "resnet18": {
        "encoder": ResNetEncoder,
@@ -117,17 +171,7 @@ def load_state_dict(self, state_dict, **kwargs):
    },
    "resnext50_32x4d": {
        "encoder": ResNetEncoder,
-         "pretrained_settings": {
-             "imagenet": {
-                 "url": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
-                 "input_space": "RGB",
-                 "input_size": [3, 224, 224],
-                 "input_range": [0, 1],
-                 "mean": [0.485, 0.456, 0.406],
-                 "std": [0.229, 0.224, 0.225],
-                 "num_classes": 1000,
-             }
-         },
+         "pretrained_settings": pretrained_settings["resnext50_32x4d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
@@ -136,28 +180,20 @@ def load_state_dict(self, state_dict, **kwargs):
            "width_per_group": 4,
        },
    },
-     "resnext101_32x8d": {
+     "resnext101_32x4d": {
        "encoder": ResNetEncoder,
-         "pretrained_settings": {
-             "imagenet": {
-                 "url": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
-                 "input_space": "RGB",
-                 "input_size": [3, 224, 224],
-                 "input_range": [0, 1],
-                 "mean": [0.485, 0.456, 0.406],
-                 "std": [0.229, 0.224, 0.225],
-                 "num_classes": 1000,
-             },
-             "instagram": {
-                 "url": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth",
-                 "input_space": "RGB",
-                 "input_size": [3, 224, 224],
-                 "input_range": [0, 1],
-                 "mean": [0.485, 0.456, 0.406],
-                 "std": [0.229, 0.224, 0.225],
-                 "num_classes": 1000,
-             },
+         "pretrained_settings": pretrained_settings["resnext101_32x4d"],
+         "params": {
+             "out_channels": (3, 64, 256, 512, 1024, 2048),
+             "block": Bottleneck,
+             "layers": [3, 4, 23, 3],
+             "groups": 32,
+             "width_per_group": 4,
        },
+     },
+     "resnext101_32x8d": {
+         "encoder": ResNetEncoder,
+         "pretrained_settings": pretrained_settings["resnext101_32x8d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
@@ -168,17 +204,7 @@ def load_state_dict(self, state_dict, **kwargs):
    },
    "resnext101_32x16d": {
        "encoder": ResNetEncoder,
-         "pretrained_settings": {
-             "instagram": {
-                 "url": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth",
-                 "input_space": "RGB",
-                 "input_size": [3, 224, 224],
-                 "input_range": [0, 1],
-                 "mean": [0.485, 0.456, 0.406],
-                 "std": [0.229, 0.224, 0.225],
-                 "num_classes": 1000,
-             }
-         },
+         "pretrained_settings": pretrained_settings["resnext101_32x16d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
@@ -189,17 +215,7 @@ def load_state_dict(self, state_dict, **kwargs):
    },
    "resnext101_32x32d": {
        "encoder": ResNetEncoder,
-         "pretrained_settings": {
-             "instagram": {
-                 "url": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth",
-                 "input_space": "RGB",
-                 "input_size": [3, 224, 224],
-                 "input_range": [0, 1],
-                 "mean": [0.485, 0.456, 0.406],
-                 "std": [0.229, 0.224, 0.225],
-                 "num_classes": 1000,
-             }
-         },
+         "pretrained_settings": pretrained_settings["resnext101_32x32d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,
@@ -210,17 +226,7 @@ def load_state_dict(self, state_dict, **kwargs):
    },
    "resnext101_32x48d": {
        "encoder": ResNetEncoder,
-         "pretrained_settings": {
-             "instagram": {
-                 "url": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth",
-                 "input_space": "RGB",
-                 "input_size": [3, 224, 224],
-                 "input_range": [0, 1],
-                 "mean": [0.485, 0.456, 0.406],
-                 "std": [0.229, 0.224, 0.225],
-                 "num_classes": 1000,
-             }
-         },
+         "pretrained_settings": pretrained_settings["resnext101_32x48d"],
        "params": {
            "out_channels": (3, 64, 256, 512, 1024, 2048),
            "block": Bottleneck,