
Commit c3cf3ae

Update 1.0.3 version and add ResNet-ViT
1 parent a60ca20 commit c3cf3ae

8 files changed: +317 −66 lines changed


README.md
+11 −11

@@ -55,17 +55,17 @@ model = VisionTransformer.from_pretrained('ViT-B_16')

 Default hyper parameters:

-| Param\Model       | ViT-B_16 | ViT-B_32 | ViT-L_16 | ViT-L_32 |
-| ----------------- | -------- | -------- | -------- | -------- |
-| image_size        | 384      | 384      | 384      | 384      |
-| patch_size        | 16       | 32       | 16       | 32       |
-| emb_dim           | 768      | 768      | 1024     | 1024     |
-| mlp_dim           | 3072     | 3072     | 4096     | 4096     |
-| num_heads         | 12       | 12       | 16       | 16       |
-| num_layers        | 12       | 12       | 24       | 24       |
-| num_classes       | 1000     | 1000     | 1000     | 1000     |
-| attn_dropout_rate | 0.0      | 0.0      | 0.0      | 0.0      |
-| dropout_rate      | 0.1      | 0.1      | 0.1      | 0.1      |
+| Param\Model       | ViT-B_16 | ViT-B_32 | ViT-L_16 | ViT-L_32 | R50+ViT-B_16 |
+| ----------------- | -------- | -------- | -------- | -------- | ------------ |
+| image_size        | 384      | 384      | 384      | 384      | 384          |
+| patch_size        | 16       | 32       | 16       | 32       | 1            |
+| emb_dim           | 768      | 768      | 1024     | 1024     | 768          |
+| mlp_dim           | 3072     | 3072     | 4096     | 4096     | 3072         |
+| num_heads         | 12       | 12       | 16       | 16       | 12           |
+| num_layers        | 12       | 12       | 24       | 24       | 12           |
+| num_classes       | 1000     | 1000     | 1000     | 1000     | 1000         |
+| attn_dropout_rate | 0.0      | 0.0      | 0.0      | 0.0      | 0.0          |
+| dropout_rate      | 0.1      | 0.1      | 0.1      | 0.1      | 0.1          |

 If you need to modify these hyper parameters, please use:
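For context, a short usage sketch of the newly listed hybrid model (not part of this commit). The model name and the `from_pretrained`/`from_name` entry points come from this diff; the `num_classes` override mirrors the `from_pretrained` signature shown below in `model.py`, and the input resolution is the default `image_size` from the table above:

```python
import torch
from vision_transformer_pytorch import VisionTransformer

# Hybrid ResNet-50 + ViT-B/16 backbone, loaded through the usual entry point.
model = VisionTransformer.from_pretrained('R50+ViT-B_16')
model.eval()

# Default image_size for every column in the table is 384.
x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    logits = model(x)  # (1, 1000) with the default num_classes

# Swapping the classification head for fine-tuning; num_classes is an
# explicit parameter of from_pretrained in this release.
model_10 = VisionTransformer.from_pretrained('R50+ViT-B_16', num_classes=10)
```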

jax_to_pytorch/convert_jax_to_pt/load_jax_weight.py → jax_to_pytorch/convert_jax_to_pt/load_jax_weights.py
+18

@@ -53,6 +53,18 @@ def replace_names(names):
             new_names.append('classifier')
         elif name == 'cls':
             new_names.append('cls_token')
+        elif name == 'block1':
+            new_names.append('resnet.body.block1')
+        elif name == 'block2':
+            new_names.append('resnet.body.block2')
+        elif name == 'block3':
+            new_names.append('resnet.body.block3')
+        elif name == 'conv_root':
+            new_names.append('resnet.root.conv')
+        elif name == 'gn_root':
+            new_names.append('resnet.root.gn')
+        elif name == 'conv_proj':
+            new_names.append('downsample')
         else:
             new_names.append(name)
     return new_names

@@ -81,9 +93,15 @@ def convert_jax_pytorch(keys, values):
             feat_dim, num_heads, head_dim = tensor_value.shape
             # for multi head attention q/k/v weight
             tensor_value = tensor_value
+        elif torch_names[-1] == 'weight' and 'gn' in torch_names[-2]:
+            # for group norm weight
+            tensor_value = tensor_value.reshape(tensor_value.shape[-1])
         elif num_dim == 2 and torch_names[-1] == 'bias' and torch_names[-2] in ['query', 'key', 'value']:
             # for multi head attention q/k/v bias
             tensor_value = tensor_value
+        elif torch_names[-1] == 'bias' and 'gn' in torch_names[-2]:
+            # for group norm bias
+            tensor_value = tensor_value.reshape(tensor_value.shape[-1])
         elif num_dim == 3 and torch_names[-1] == 'weight' and torch_names[-2] == 'out':
             # for multi head attention out weight
             tensor_value = tensor_value
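The two new `gn` branches flatten group-norm scale and bias tensors before copying them into PyTorch. A minimal sketch of that reshape, assuming the JAX checkpoint stores these parameters with leading singleton dimensions (e.g. shape `(1, 1, C)`); the exact checkpoint layout is not shown in this diff, and `flatten_gn_param` is a hypothetical helper name used only for illustration:

```python
import torch

def flatten_gn_param(tensor_value):
    # Same operation as the new branches above: keep only the trailing
    # channel dimension so the tensor matches nn.GroupNorm's 1-D
    # weight/bias parameters.
    return tensor_value.reshape(tensor_value.shape[-1])

scale = torch.ones(1, 1, 256)          # hypothetical checkpoint shape
weight = flatten_gn_param(scale)
print(weight.shape)                    # torch.Size([256])

gn = torch.nn.GroupNorm(num_groups=32, num_channels=256)
with torch.no_grad():
    gn.weight.copy_(weight)            # shapes now line up
```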

setup.py
+1 −1

@@ -20,7 +20,7 @@

 AUTHOR = 'ZHANG Zhi'
 REQUIRES_PYTHON = '>=3.5.0'
-VERSION = '1.0.2'
+VERSION = '1.0.3'

 # What packages are required for this module to be executed?
 REQUIRED = [

tests/test_model.py
+6 −2

@@ -12,7 +12,9 @@
 # -- fixtures -------------------------------------------------------------------------------------


-@pytest.fixture(scope='module', params=['ViT-B_16', 'ViT-B_32', 'ViT-L_16', 'ViT-L_32'])
+@pytest.fixture(
+    scope='module',
+    params=['ViT-B_16', 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'R50+ViT-B_16'])
 def model(request):
     return request.param


@@ -24,7 +26,8 @@ def pretrained(request):

 @pytest.fixture(scope='function')
 def net(model, pretrained):
-    return VisionTransformer.from_pretrained(model) if pretrained else VisionTransformer.from_name(model)
+    return VisionTransformer.from_pretrained(
+        model) if pretrained else VisionTransformer.from_name(model)


 # -- tests ----------------------------------------------------------------------------------------

@@ -36,6 +39,7 @@ def test_forward(net):
     output = net(data)
     assert not torch.isnan(output).any()

+
 @pytest.mark.parametrize('img_size', [224, 256, 512])
 def test_hyper_params(model, img_size):
     """Test `.forward()` doesn't throw an error with different input size"""

vision_transformer_pytorch/__init__.py
+1 −1

@@ -1,4 +1,4 @@
-__version__ = "1.0.2"
+__version__ = "1.0.3"
 from .model import VisionTransformer, VALID_MODELS
 from .utils import (
     Params,

vision_transformer_pytorch/model.py
+72 −24

@@ -7,15 +7,18 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from .utils import (get_width_and_height_from_size, load_pretrained_weights, get_model_params)
+from .resnet import StdConv2d
+from .utils import (get_width_and_height_from_size, load_pretrained_weights,
+                    get_model_params)

-VALID_MODELS = ('ViT-B_16', 'ViT-B_32', 'ViT-L_16', 'ViT-L_32')
+VALID_MODELS = ('ViT-B_16', 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'R50+ViT-B_16')


 class PositionEmbs(nn.Module):
     def __init__(self, num_patches, emb_dim, dropout_rate=0.1):
         super(PositionEmbs, self).__init__()
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
+        self.pos_embedding = nn.Parameter(
+            torch.randn(1, num_patches + 1, emb_dim))
         if dropout_rate > 0:
             self.dropout = nn.Dropout(dropout_rate)
         else:

@@ -109,11 +112,18 @@ def forward(self, x):


 class EncoderBlock(nn.Module):
-    def __init__(self, in_dim, mlp_dim, num_heads, dropout_rate=0.1, attn_dropout_rate=0.1):
+    def __init__(self,
+                 in_dim,
+                 mlp_dim,
+                 num_heads,
+                 dropout_rate=0.1,
+                 attn_dropout_rate=0.1):
         super(EncoderBlock, self).__init__()

         self.norm1 = nn.LayerNorm(in_dim)
-        self.attn = SelfAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate)
+        self.attn = SelfAttention(in_dim,
+                                  heads=num_heads,
+                                  dropout_rate=attn_dropout_rate)
         if dropout_rate > 0:
             self.dropout = nn.Dropout(dropout_rate)
         else:

@@ -154,7 +164,8 @@ def __init__(self,
         in_dim = emb_dim
         self.encoder_layers = nn.ModuleList()
         for i in range(num_layers):
-            layer = EncoderBlock(in_dim, mlp_dim, num_heads, dropout_rate, attn_dropout_rate)
+            layer = EncoderBlock(in_dim, mlp_dim, num_heads, dropout_rate,
+                                 attn_dropout_rate)
             self.encoder_layers.append(layer)
         self.norm = nn.LayerNorm(in_dim)

@@ -190,21 +201,33 @@ def __init__(self, params=None):
         super(VisionTransformer, self).__init__()
         self._params = params

-        self.embedding = nn.Conv2d(3, self._params.emb_dim, kernel_size=self.patch_size, stride=self.patch_size)
+        if self._params.resnet:
+            self.resnet = self._params.resnet()
+            self.embedding = nn.Conv2d(self.resnet.width * 16,
+                                       self._params.emb_dim,
+                                       kernel_size=1,
+                                       stride=1)
+        else:
+            self.embedding = nn.Conv2d(3,
+                                       self._params.emb_dim,
+                                       kernel_size=self.patch_size,
+                                       stride=self.patch_size)
         # class token
         self.cls_token = nn.Parameter(torch.zeros(1, 1, self._params.emb_dim))

         # transformer
-        self.transformer = Encoder(num_patches=self.num_patches,
-                                   emb_dim=self._params.emb_dim,
-                                   mlp_dim=self._params.mlp_dim,
-                                   num_layers=self._params.num_layers,
-                                   num_heads=self._params.num_heads,
-                                   dropout_rate=self._params.dropout_rate,
-                                   attn_dropout_rate=self._params.attn_dropout_rate)
+        self.transformer = Encoder(
+            num_patches=self.num_patches,
+            emb_dim=self._params.emb_dim,
+            mlp_dim=self._params.mlp_dim,
+            num_layers=self._params.num_layers,
+            num_heads=self._params.num_heads,
+            dropout_rate=self._params.dropout_rate,
+            attn_dropout_rate=self._params.attn_dropout_rate)

         # classifier
-        self.classifier = nn.Linear(self._params.emb_dim, self._params.num_classes)
+        self.classifier = nn.Linear(self._params.emb_dim,
+                                    self._params.num_classes)

     @property
     def image_size(self):

@@ -218,10 +241,16 @@ def patch_size(self):
     def num_patches(self):
         h, w = self.image_size
         fh, fw = self.patch_size
-        gh, gw = h // fh, w // fw
+        if hasattr(self, 'resnet'):
+            gh, gw = h // fh // self.resnet.downsample, w // fw // self.resnet.downsample
+        else:
+            gh, gw = h // fh, w // fw
         return gh * gw

     def extract_features(self, x):
+        if hasattr(self, 'resnet'):
+            x = self.resnet(x)
+
         emb = self.embedding(x)  # (n, c, gh, gw)
         emb = emb.permute(0, 2, 3, 1)  # (n, gh, gw, c)
         b, h, w, c = emb.shape

@@ -266,7 +295,12 @@ def from_name(cls, model_name, in_channels=3, **override_params):
         return model

     @classmethod
-    def from_pretrained(cls, model_name, weights_path=None, in_channels=3, num_classes=1000, **override_params):
+    def from_pretrained(cls,
+                        model_name,
+                        weights_path=None,
+                        in_channels=3,
+                        num_classes=1000,
+                        **override_params):
         """create a vision transformer model according to name.
         Args:
             model_name (str): Name for vision transformer.

@@ -288,8 +322,13 @@ def from_pretrained(cls, model_name, weights_path=None, in_channels=3, num_class
         Returns:
             A pretrained vision transformer model.
         """
-        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
-        load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000))
+        model = cls.from_name(model_name,
+                              num_classes=num_classes,
+                              **override_params)
+        load_pretrained_weights(model,
+                                model_name,
+                                weights_path=weights_path,
+                                load_fc=(num_classes == 1000))
         model._change_in_channels(in_channels)
         return model

@@ -302,15 +341,24 @@ def _check_model_name_is_valid(cls, model_name):
             bool: Is a valid name or not.
         """
         if model_name not in VALID_MODELS:
-            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))
+            raise ValueError('model_name should be one of: ' +
+                             ', '.join(VALID_MODELS))

     def _change_in_channels(self, in_channels):
         """Adjust model's first convolution layer to in_channels, if in_channels does not equal 3.
         Args:
             in_channels (int): Input data's channel number.
         """
         if in_channels != 3:
-            self.embedding = nn.Conv2d(in_channels,
-                                       self._params.emb_dim,
-                                       kernel_size=self.patch_size,
-                                       stride=self.patch_size)
+            if hasattr(self, 'resnet'):
+                self.resnet.root['conv'] = StdConv2d(in_channels,
+                                                     self.resnet.width,
+                                                     kernel_size=7,
+                                                     stride=2,
+                                                     bias=False,
+                                                     padding=3)
+            else:
+                self.embedding = nn.Conv2d(in_channels,
+                                           self._params.emb_dim,
+                                           kernel_size=self.patch_size,
+                                           stride=self.patch_size)
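To make the new hybrid path concrete, here is a shape walk-through mirroring `extract_features` for `R50+ViT-B_16` at the default 384×384 input. This is a standalone sketch, not code from the repository; the concrete numbers (`resnet.width == 64`, so `width * 16 == 1024` feature channels, and a total `resnet.downsample` factor of 16) are assumptions consistent with the README table (`patch_size == 1`, 768-dim embedding):

```python
import torch
import torch.nn as nn

# Assumed output of self.resnet(x) for a (1, 3, 384, 384) input:
# 1024 channels (width * 16 with width == 64) at stride 16 -> 24x24 map.
feat = torch.randn(1, 1024, 24, 24)

# The 1x1 "patch" embedding created in __init__ for the resnet branch.
embedding = nn.Conv2d(1024, 768, kernel_size=1, stride=1)

emb = embedding(feat)              # (1, 768, 24, 24)  -> (n, c, gh, gw)
emb = emb.permute(0, 2, 3, 1)      # (1, 24, 24, 768)  -> (n, gh, gw, c)
b, h, w, c = emb.shape
tokens = emb.reshape(b, h * w, c)  # (1, 576, 768); 576 == num_patches

# Prepend the class token before the Encoder, as in the ViT-only path.
cls_token = torch.zeros(1, 1, 768).expand(b, -1, -1)
tokens = torch.cat([cls_token, tokens], dim=1)  # (1, 577, 768)
print(tokens.shape)
```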
