7
7
import torch .nn as nn
8
8
import torch .nn .functional as F
9
9
10
- from .utils import (get_width_and_height_from_size , load_pretrained_weights , get_model_params )
10
+ from .resnet import StdConv2d
11
+ from .utils import (get_width_and_height_from_size , load_pretrained_weights ,
12
+ get_model_params )
11
13
12
- VALID_MODELS = ('ViT-B_16' , 'ViT-B_32' , 'ViT-L_16' , 'ViT-L_32' )
14
+ VALID_MODELS = ('ViT-B_16' , 'ViT-B_32' , 'ViT-L_16' , 'ViT-L_32' , 'R50+ViT-B_16' )
13
15
14
16
15
17
class PositionEmbs (nn .Module ):
16
18
def __init__ (self , num_patches , emb_dim , dropout_rate = 0.1 ):
17
19
super (PositionEmbs , self ).__init__ ()
18
- self .pos_embedding = nn .Parameter (torch .randn (1 , num_patches + 1 , emb_dim ))
20
+ self .pos_embedding = nn .Parameter (
21
+ torch .randn (1 , num_patches + 1 , emb_dim ))
19
22
if dropout_rate > 0 :
20
23
self .dropout = nn .Dropout (dropout_rate )
21
24
else :
@@ -109,11 +112,18 @@ def forward(self, x):
109
112
110
113
111
114
class EncoderBlock (nn .Module ):
112
- def __init__ (self , in_dim , mlp_dim , num_heads , dropout_rate = 0.1 , attn_dropout_rate = 0.1 ):
115
+ def __init__ (self ,
116
+ in_dim ,
117
+ mlp_dim ,
118
+ num_heads ,
119
+ dropout_rate = 0.1 ,
120
+ attn_dropout_rate = 0.1 ):
113
121
super (EncoderBlock , self ).__init__ ()
114
122
115
123
self .norm1 = nn .LayerNorm (in_dim )
116
- self .attn = SelfAttention (in_dim , heads = num_heads , dropout_rate = attn_dropout_rate )
124
+ self .attn = SelfAttention (in_dim ,
125
+ heads = num_heads ,
126
+ dropout_rate = attn_dropout_rate )
117
127
if dropout_rate > 0 :
118
128
self .dropout = nn .Dropout (dropout_rate )
119
129
else :
@@ -154,7 +164,8 @@ def __init__(self,
154
164
in_dim = emb_dim
155
165
self .encoder_layers = nn .ModuleList ()
156
166
for i in range (num_layers ):
157
- layer = EncoderBlock (in_dim , mlp_dim , num_heads , dropout_rate , attn_dropout_rate )
167
+ layer = EncoderBlock (in_dim , mlp_dim , num_heads , dropout_rate ,
168
+ attn_dropout_rate )
158
169
self .encoder_layers .append (layer )
159
170
self .norm = nn .LayerNorm (in_dim )
160
171
@@ -190,21 +201,33 @@ def __init__(self, params=None):
190
201
super (VisionTransformer , self ).__init__ ()
191
202
self ._params = params
192
203
193
- self .embedding = nn .Conv2d (3 , self ._params .emb_dim , kernel_size = self .patch_size , stride = self .patch_size )
204
+ if self ._params .resnet :
205
+ self .resnet = self ._params .resnet ()
206
+ self .embedding = nn .Conv2d (self .resnet .width * 16 ,
207
+ self ._params .emb_dim ,
208
+ kernel_size = 1 ,
209
+ stride = 1 )
210
+ else :
211
+ self .embedding = nn .Conv2d (3 ,
212
+ self ._params .emb_dim ,
213
+ kernel_size = self .patch_size ,
214
+ stride = self .patch_size )
194
215
# class token
195
216
self .cls_token = nn .Parameter (torch .zeros (1 , 1 , self ._params .emb_dim ))
196
217
197
218
# transformer
198
- self .transformer = Encoder (num_patches = self .num_patches ,
199
- emb_dim = self ._params .emb_dim ,
200
- mlp_dim = self ._params .mlp_dim ,
201
- num_layers = self ._params .num_layers ,
202
- num_heads = self ._params .num_heads ,
203
- dropout_rate = self ._params .dropout_rate ,
204
- attn_dropout_rate = self ._params .attn_dropout_rate )
219
+ self .transformer = Encoder (
220
+ num_patches = self .num_patches ,
221
+ emb_dim = self ._params .emb_dim ,
222
+ mlp_dim = self ._params .mlp_dim ,
223
+ num_layers = self ._params .num_layers ,
224
+ num_heads = self ._params .num_heads ,
225
+ dropout_rate = self ._params .dropout_rate ,
226
+ attn_dropout_rate = self ._params .attn_dropout_rate )
205
227
206
228
# classfier
207
- self .classifier = nn .Linear (self ._params .emb_dim , self ._params .num_classes )
229
+ self .classifier = nn .Linear (self ._params .emb_dim ,
230
+ self ._params .num_classes )
208
231
209
232
@property
210
233
def image_size (self ):
@@ -218,10 +241,16 @@ def patch_size(self):
218
241
def num_patches (self ):
219
242
h , w = self .image_size
220
243
fh , fw = self .patch_size
221
- gh , gw = h // fh , w // fw
244
+ if hasattr (self , 'resnet' ):
245
+ gh , gw = h // fh // self .resnet .downsample , w // fw // self .resnet .downsample
246
+ else :
247
+ gh , gw = h // fh , w // fw
222
248
return gh * gw
223
249
224
250
def extract_features (self , x ):
251
+ if hasattr (self , 'resnet' ):
252
+ x = self .resnet (x )
253
+
225
254
emb = self .embedding (x ) # (n, c, gh, gw)
226
255
emb = emb .permute (0 , 2 , 3 , 1 ) # (n, gh, hw, c)
227
256
b , h , w , c = emb .shape
@@ -266,7 +295,12 @@ def from_name(cls, model_name, in_channels=3, **override_params):
266
295
return model
267
296
268
297
@classmethod
269
- def from_pretrained (cls , model_name , weights_path = None , in_channels = 3 , num_classes = 1000 , ** override_params ):
298
+ def from_pretrained (cls ,
299
+ model_name ,
300
+ weights_path = None ,
301
+ in_channels = 3 ,
302
+ num_classes = 1000 ,
303
+ ** override_params ):
270
304
"""create an vision transformer model according to name.
271
305
Args:
272
306
model_name (str): Name for vision transformer.
@@ -288,8 +322,13 @@ def from_pretrained(cls, model_name, weights_path=None, in_channels=3, num_class
288
322
Returns:
289
323
A pretrained vision transformer model.
290
324
"""
291
- model = cls .from_name (model_name , num_classes = num_classes , ** override_params )
292
- load_pretrained_weights (model , model_name , weights_path = weights_path , load_fc = (num_classes == 1000 ))
325
+ model = cls .from_name (model_name ,
326
+ num_classes = num_classes ,
327
+ ** override_params )
328
+ load_pretrained_weights (model ,
329
+ model_name ,
330
+ weights_path = weights_path ,
331
+ load_fc = (num_classes == 1000 ))
293
332
model ._change_in_channels (in_channels )
294
333
return model
295
334
@@ -302,15 +341,24 @@ def _check_model_name_is_valid(cls, model_name):
302
341
bool: Is a valid name or not.
303
342
"""
304
343
if model_name not in VALID_MODELS :
305
- raise ValueError ('model_name should be one of: ' + ', ' .join (VALID_MODELS ))
344
+ raise ValueError ('model_name should be one of: ' +
345
+ ', ' .join (VALID_MODELS ))
306
346
307
347
def _change_in_channels (self , in_channels ):
308
348
"""Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
309
349
Args:
310
350
in_channels (int): Input data's channel number.
311
351
"""
312
352
if in_channels != 3 :
313
- self .embedding = nn .Conv2d (in_channels ,
314
- self ._params .emb_dim ,
315
- kernel_size = self .patch_size ,
316
- stride = self .patch_size )
353
+ if hasattr (self , 'resnet' ):
354
+ self .resnet .root ['conv' ] = StdConv2d (in_channels ,
355
+ self .resnet .width ,
356
+ kernel_size = 7 ,
357
+ stride = 2 ,
358
+ bias = False ,
359
+ padding = 3 )
360
+ else :
361
+ self .embedding = nn .Conv2d (in_channels ,
362
+ self ._params .emb_dim ,
363
+ kernel_size = self .patch_size ,
364
+ stride = self .patch_size )
0 commit comments