37 commits
0a14ca6
changed convert.py, added explore-conversion_21k.py script, added log…
arkel23 Dec 18, 2020
fc6722f
restructured directory and made it so that instead of downloading pth…
arkel23 Dec 18, 2020
4d31187
restructured, deleted jax_to_pytorch and moved to utils.py and made s…
arkel23 Dec 18, 2020
547f8c9
deleted jax_to_pytorch and added the py to download the models
arkel23 Dec 18, 2020
0598f15
deleted jax_to_pytorch and combined relevant files into pytorc_pretra…
arkel23 Dec 18, 2020
69c2138
added some inference scripts and some annotations in transformer.py
arkel23 Dec 18, 2020
0f3ab14
added an example for cifar-10 dataset
arkel23 Dec 18, 2020
eb77c3f
added files and example to visualize attention, modified transformer …
arkel23 Dec 20, 2020
3619b09
changes to allow for visualization and compatibility with torchsummar…
arkel23 Dec 22, 2020
ed647de
Update README.md
arkel23 Dec 22, 2020
b1a543d
Update README.md
arkel23 Dec 22, 2020
5c1a017
Update README.md
arkel23 Dec 22, 2020
e0405c1
Update README.md
arkel23 Dec 22, 2020
23160eb
changed downloaded models to only standalone vits pretrained on image…
arkel23 May 19, 2021
5b493f9
changes to load pretrained weights and load from configuration yaml a…
arkel23 May 19, 2021
6d08331
modified readme to describe how to load partial
arkel23 May 19, 2021
0c734e8
added load_fc_layer as argument in case want to retrieve transformer …
arkel23 Jul 21, 2021
9670b61
directly use config so changes in seq len are reflected in original c…
arkel23 Jul 21, 2021
b5a5bcb
readability and removed unused variables
arkel23 Jul 22, 2021
c381af9
add functions to retrieve patchified image before inputting into vit
arkel23 Jul 25, 2021
f51f41d
changes to retrieve intermediate representations
arkel23 Jul 26, 2021
af11a2f
added text modality from vilt
arkel23 Aug 6, 2021
542e480
added vocab size to config
arkel23 Aug 7, 2021
3fb71c9
added default vocab size
arkel23 Aug 11, 2021
4ea0abd
reorganized and separated extract features into a function
arkel23 Aug 22, 2021
b0b5b14
updated structure to have less ifs
arkel23 Aug 22, 2021
b29c2cb
added options to configuration dic
arkel23 Aug 23, 2021
54b4119
put all back into forward mode
arkel23 Aug 23, 2021
7557f42
updated small models
arkel23 Aug 26, 2021
9e731ac
updated configs
arkel23 Aug 31, 2021
9322eee
added pretrained ckpts from google for s16, s32, and ti16
arkel23 Sep 11, 2021
eb739d0
added patch 4 and 8 configs
arkel23 Sep 13, 2021
fd2f4ee
added option to not load cls token
arkel23 Jan 18, 2022
ace1193
typo
arkel23 Jan 23, 2022
a0c9e76
layernorm regardless of fc or not
arkel23 Jan 24, 2022
09a598c
option to turn off layernorm before head, and pass only first token t…
arkel23 Feb 15, 2022
5f08020
updated names and default configs for vits
arkel23 Mar 2, 2022
4 changes: 3 additions & 1 deletion .gitignore
@@ -3,8 +3,10 @@ tmp
*.pkl
.vscode
*.npy
*.npz
*.npz*
*.pth
.nfs*
examples/**/data/

# Byte-compiled / optimized / DLL files
__pycache__/
194 changes: 48 additions & 146 deletions README.md
@@ -1,176 +1,78 @@
# ViT PyTorch

### Quickstart
Forked from [Luke Melas-Kyriazi repository](https://github.com/lukemelas/PyTorch-Pretrained-ViT).

Install with `pip install pytorch_pretrained_vit` and load a pretrained ViT with:
```python
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
```

Or find a Google Colab example [here](https://colab.research.google.com/drive/1muZ4QFgVfwALgqmrfOkp7trAvqDemckO?usp=sharing).

### Overview
This repository contains an op-for-op PyTorch reimplementation of the [Visual Transformer](https://openreview.net/forum?id=YicbFdNTTy) architecture from [Google](https://github.com/google-research/vision_transformer), along with pre-trained models and examples.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects.

At the moment, you can easily:
* Load pretrained ViT models
* Evaluate on ImageNet or your own data
* Finetune ViT on your own dataset

_(Upcoming features)_ Coming soon:
* Train ViT from scratch on ImageNet (1K)
* Export to ONNX for efficient inference

### Table of contents
1. [About ViT](#about-vit)
2. [About ViT-PyTorch](#about-vit-pytorch)
3. [Installation](#installation)
4. [Usage](#usage)
* [Load pretrained models](#loading-pretrained-models)
* [Example: Classify](#example-classification)
<!-- * [Example: Extract features](#example-feature-extraction) -->
<!-- * [Example: Export to ONNX](#example-export) -->
6. [Contributing](#contributing)

### About ViT

Visual Transformers (ViT) are a straightforward application of the [transformer architecture](https://arxiv.org/abs/1706.03762) to image classification. Even in computer vision, it seems, attention is all you need.

The ViT architecture works as follows: (1) it considers an image as a 1-dimensional sequence of patches, (2) it prepends a classification token to the sequence, (3) it passes these patches through a transformer encoder (like [BERT](https://arxiv.org/abs/1810.04805)), (4) it passes the first token of the output of the transformer through a small MLP to obtain the classification logits.
ViT is trained on a large-scale dataset (ImageNet-21k) with a huge amount of compute.

<div style="text-align: center; padding: 10px">
<img src="https://raw.githubusercontent.com/google-research/vision_transformer/master/figure1.png" width="100%" style="max-width: 300px; margin: auto"/>
</div>
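The four steps above can be sketched concretely. The helper below is illustrative only (it is not part of this repo's API; typical implementations perform step (1) with a strided convolution inside the model):

```python
import numpy as np

def patchify(img: np.ndarray, patch_size: int) -> np.ndarray:
    """Step (1): split an (H, W, C) image into a sequence of flattened patches.

    Illustrative sketch; in practice the patch projection is usually a
    strided Conv2d whose kernel and stride equal the patch size.
    """
    h, w, c = img.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image must divide evenly into patches"
    patches = img.reshape(h // p, p, w // p, p, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (h//p, w//p, p, p, c)
    return patches.reshape(-1, p * p * c)       # one row per patch

# A 384x384 RGB image with 16x16 patches yields a sequence of 576 tokens;
# step (2) then prepends a learned classification token to this sequence.
```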


### About ViT-PyTorch

ViT-PyTorch is a PyTorch re-implementation of ViT. It is consistent with the [original Jax implementation](https://github.com/google-research/vision_transformer), so that it's easy to load Jax-pretrained weights.

At the same time, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.
### Setup

### Installation

Install with pip:
```bash
pip install pytorch_pretrained_vit
```

Or from source:
```bash
git clone https://github.com/lukemelas/ViT-PyTorch
cd ViT-Pytorch
git clone https://github.com/arkel23/PyTorch-Pretrained-ViT.git
cd PyTorch-Pretrained-ViT
pip install -e .
python download_convert_models.py # can modify to download different models, by default it downloads all 5 ViTs pretrained on ImageNet21k
```

### Usage

#### Loading pretrained models

Loading a pretrained model is easy:
```python
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
```

Details about the models are below:

| *Name* | *Pretrained on* | *Finetuned on* | *Available?* |
|:-----------------:|:---------------:|:------------:|:-----------:|
| `B_16` | ImageNet-21k | - | ✓ |
| `B_32` | ImageNet-21k | - | ✓ |
| `L_16` | ImageNet-21k | - | - |
| `L_32` | ImageNet-21k | - | ✓ |
| `B_16_imagenet1k` | ImageNet-21k | ImageNet-1k | ✓ |
| `B_32_imagenet1k` | ImageNet-21k | ImageNet-1k | ✓ |
| `L_16_imagenet1k` | ImageNet-21k | ImageNet-1k | ✓ |
| `L_32_imagenet1k` | ImageNet-21k | ImageNet-1k | ✓ |

#### Custom ViT

Loading custom configurations is just as easy:
```python
from pytorch_pretrained_vit import ViT
# The following is equivalent to ViT('B_16')
config = dict(hidden_size=512, num_heads=8, num_layers=6)
model = ViT.from_config(config)
```
from pytorch_pretrained_vit import ViT, ViTConfigExtended, PRETRAINED_CONFIGS

#### Example: Classification

Below is a simple, complete example. It may also be found as a Jupyter notebook in `examples/simple` or as a [Colab Notebook]().
<!-- TODO: new Colab -->

```python
import json
from PIL import Image
import torch
from torchvision import transforms

# Load ViT
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
model.eval()

# Load image
# NOTE: Assumes an image `img.jpg` exists in the current directory
img = transforms.Compose([
transforms.Resize((384, 384)),
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5),
])(Image.open('img.jpg')).unsqueeze(0)
print(img.shape) # torch.Size([1, 3, 384, 384])

# Classify
with torch.no_grad():
outputs = model(img)
print(outputs.shape) # (1, 1000)
model_name = 'B_16'
def_config = PRETRAINED_CONFIGS['{}'.format(model_name)]['config']
configuration = ViTConfigExtended(**def_config)
model = ViT(configuration, name=model_name, pretrained=True, load_repr_layer=False, ret_attn_scores=False)
```

<!-- #### Example: Feature Extraction
### Changes compared to original

You can easily extract features with `model.extract_features`:
```python
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')

# ... image preprocessing as in the classification example ...
print(img.shape) # torch.Size([1, 3, 384, 384])

features = model.extract_features(img)
print(features.shape) # torch.Size([1, 1280, 7, 7])
``` -->

<!-- #### Example: Export to ONNX

Exporting to ONNX for deploying to production is now simple:
```python
import torch
from efficientnet_pytorch import EfficientNet
* Added support for the 'H_14' and 'L_16' ViT models.
* Added support for downloading the models directly from Google's cloud storage.
* Corrected the Jax to PyTorch weight conversion. The previous methodology produced .pth state_dict files without the representation layer, so `ViT(load_repr_layer=True)` raised an error. For inference alone the representation layer is unnecessary, as discussed in the original Vision Transformer paper, but it can be useful for other applications and experiments, so `download_convert_models.py` first downloads the required models and converts them with all of their weights; afterwards every parameter can be tuned.
* Added support for visualizing attention, by returning the scores values in the multi-head self-attention layers. The visualizing script was mostly taken from [jeonsworld/ViT-pytorch repository](https://github.com/jeonsworld/ViT-pytorch).
* Added examples for inference (single image), and fine-tuning/training (using CIFAR-10).
* Modified loading of models by using configurations similar to HuggingFace's Transformers.
```
# Change the default configuration by accessing individual attributes
configuration.image_size = 128
configuration.num_classes = 10
configuration.num_hidden_layers = 3
model = ViT(config=configuration, name='B_16', pretrained=True)
# for another example see examples/configurations/load_configs.py
```
* Added support to partially load ViT
```
model = ViT(config=configuration, name='B_16')
pretrained_mode = 'full_tokenizer'
weights_path = "/hdd/edwin/support/torch/hub/checkpoints/B_16.pth"
model.load_partial(weights_path=weights_path, pretrained_image_size=configuration.pretrained_image_size,
pretrained_mode=pretrained_mode, verbose=True)
for pretrained_mode in ['full_tokenizer', 'patchprojection', 'posembeddings', 'clstoken',
'patchandposembeddings', 'patchandclstoken', 'posembeddingsandclstoken']:
model.load_partial(weights_path=weights_path,
pretrained_image_size=configuration.pretrained_image_size, pretrained_mode=pretrained_mode, verbose=True)
```
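Independently of this repo, the core of partial loading is filtering a checkpoint's state_dict down to the desired keys before calling `load_state_dict(strict=False)`. A minimal sketch (the function name and key prefixes below are hypothetical, and the repo's `load_partial` additionally takes the pretrained image size so it can adapt positional embeddings, which is omitted here):

```python
def filter_state_dict(state_dict, keep_prefixes):
    """Keep only entries whose key starts with one of keep_prefixes.

    Sketch of the idea behind partial loading; prefix names depend on the
    model's actual parameter naming.
    """
    return {k: v for k, v in state_dict.items()
            if any(k.startswith(p) for p in keep_prefixes)}

# Usage with PyTorch (prefixes are illustrative examples):
#   state = torch.load('B_16.pth', map_location='cpu')
#   partial = filter_state_dict(state, ['patch_embedding', 'class_token'])
#   model.load_state_dict(partial, strict=False)
```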

model = EfficientNet.from_pretrained('efficientnet-b1')
dummy_input = torch.randn(10, 3, 240, 240)
### About

model.set_swish(memory_efficient=False)
torch.onnx.export(model, dummy_input, "test-b1.onnx", verbose=True)
```
This repository contains an op-for-op PyTorch reimplementation of the [Vision Transformer](https://openreview.net/forum?id=YicbFdNTTy) architecture from [Google](https://github.com/google-research/vision_transformer), along with pre-trained models and examples.

[Here](https://colab.research.google.com/drive/1rOAEXeXHaA8uo3aG2YcFDHItlRJMV0VP) is a Colab example. -->

Visual Transformers (ViT) are a straightforward application of the [transformer architecture](https://arxiv.org/abs/1706.03762) to image classification. Even in computer vision, it seems, attention is all you need.

#### ImageNet
The ViT architecture works as follows: (1) it considers an image as a 1-dimensional sequence of patches, (2) it prepends a classification token to the sequence, (3) it passes these patches through a transformer encoder (like [BERT](https://arxiv.org/abs/1810.04805)), (4) it passes the first token of the output of the transformer through a small MLP to obtain the classification logits.
ViT is trained on a large-scale dataset (ImageNet-21k) with a huge amount of compute.

See `examples/imagenet` for details about evaluating on ImageNet.
<div style="text-align: center; padding: 10px">
<img src="https://raw.githubusercontent.com/google-research/vision_transformer/master/figure1.png" width="100%" style="max-width: 300px; margin: auto"/>
</div>

#### Credit

Other great repositories with this model include:
- [Google Research's repo](https://github.com/google-research/vision_transformer)
- [Ross Wightman's repo](https://github.com/rwightman/pytorch-image-models)
- [Phil Wang's repo](https://github.com/lucidrains/vit-pytorch)
- [Eunkwang Jeon's repo](https://github.com/jeonsworld/ViT-pytorch)
- [Luke Melas-Kyriazi repo](https://github.com/lukemelas/PyTorch-Pretrained-ViT)

### Contributing

8 changes: 8 additions & 0 deletions download_convert_models.py
@@ -0,0 +1,8 @@
from pytorch_pretrained_vit import ViT, ViTConfigExtended, PRETRAINED_CONFIGS

models_list = ['B_16', 'B_16_in1k']
for model_name in models_list:
    def_config = PRETRAINED_CONFIGS['{}'.format(model_name)]['config']
    configuration = ViTConfigExtended(**def_config)
    model = ViT(configuration, name=model_name, pretrained=True, load_repr_layer=True)

Binary file added examples/attention/attention_data/img.jpg
104 changes: 104 additions & 0 deletions examples/attention/visualize_attention.py
@@ -0,0 +1,104 @@
import json
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import cv2
from urllib.request import urlretrieve

import torch
from torchvision import transforms

from pytorch_pretrained_vit import ViT

models_list = ['B_16', 'B_32', 'L_32', 'B_16_imagenet1k', 'B_32_imagenet1k', 'L_16_imagenet1k', 'L_32_imagenet1k']
model_name = models_list[3]
model = ViT(model_name, pretrained=True, visualize=True)

# Test Image
os.makedirs("attention_data", exist_ok=True)
img_url = "https://images.mypetlife.co.kr/content/uploads/2019/04/09192811/welsh-corgi-1581119_960_720.jpg"
urlretrieve(img_url, "attention_data/img.jpg")

transform = transforms.Compose([
    transforms.Resize(model.image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
im = Image.open("attention_data/img.jpg")
x = transform(im)
print(x.shape)

# Load class names
labels_map = json.load(open('labels_map.txt'))
labels_map = [labels_map[str(i)] for i in range(1000)]

# Classify
model.eval()
with torch.no_grad():
    outputs, att_mat = model(x.unsqueeze(0))

outputs = outputs.squeeze(0)
print(outputs.shape)
print(len(att_mat))
print(att_mat[0].shape)
#print(outputs, att_mat)
#print('logits_size and att_mat sizes: ', outputs.shape, att_mat.shape)

att_mat = torch.stack(att_mat).squeeze(1)
print(att_mat.shape)

# Average the attention weights across all heads.
att_mat = torch.mean(att_mat, dim=1)
print(att_mat.shape)

# To account for residual connections, we add an identity matrix to the
# attention matrix and re-normalize the weights.
residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)
print('residual_att and aug_att_mat sizes: ', residual_att.shape, aug_att_mat.shape)

# Recursively multiply the weight matrices
joint_attentions = torch.zeros(aug_att_mat.size())
joint_attentions[0] = aug_att_mat[0]

for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])

# Attention from the output token to the input space.
v = joint_attentions[-1] # last layer output attention map
print('joint_attentions and last layer (v) sizes: ', joint_attentions.shape, v.shape)
grid_size = int(np.sqrt(aug_att_mat.size(-1)))
mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
print(mask.shape)
mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
print(mask.shape)
result = (mask * im).astype("uint8")

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))

ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(im)
_ = ax2.imshow(result)

print('-----')
for idx in torch.topk(outputs, k=3).indices.tolist():
    prob = torch.softmax(outputs, -1)[idx].item()
    print('[{idx}] {label:<75} ({p:.2f}%)'.format(idx=idx, label=labels_map[idx], p=prob*100))

for i, v in enumerate(joint_attentions):
    # Attention from the output token to the input space.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
    result = (mask * im).astype("uint8")

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    title = 'AttentionMap_Layer{}'.format(i+1)
    ax2.set_title(title)
    _ = ax1.imshow(im)
    _ = ax2.imshow(result)
    plt.savefig(os.path.join('attention_data', title))
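The rollout computation in the script (average heads, add the identity for the residual connection, re-normalize, multiply through the layers) can be packaged as a standalone function. A NumPy sketch; the name `attention_rollout` is ours, not part of this repo:

```python
import numpy as np

def attention_rollout(att_mats):
    """Combine per-layer attention maps into one input-attribution map.

    att_mats: list of (num_heads, seq_len, seq_len) arrays, one per layer,
    with rows that sum to 1 (softmax outputs). Returns (seq_len, seq_len).
    """
    rollout = None
    for att in att_mats:
        a = att.mean(axis=0)                    # average over heads
        a = a + np.eye(a.shape[-1])             # account for residual connection
        a = a / a.sum(axis=-1, keepdims=True)   # re-normalize rows
        rollout = a if rollout is None else a @ rollout
    return rollout
```

Because each layer's matrix stays row-stochastic after re-normalization, the final rollout is also row-stochastic, and its first row (minus the class-token entry) is the mask visualized above.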