Skip to content

Commit 2f6d890

Browse files
committed
Add DINOv2 model
1 parent 6c98708 commit 2f6d890

File tree

11 files changed

+918
-0
lines changed

11 files changed

+918
-0
lines changed

oml/configs/extractor/vit_v2.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
name: vit_v2
2+
args:
3+
arch: vits14
4+
normalise_features: False
5+
weights: null

oml/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77
from oml.models.resnet.extractor import ResnetExtractor
88
from oml.models.vit_clip.extractor import ViTCLIPExtractor
99
from oml.models.vit_dino.extractor import ViTExtractor
10+
from oml.models.vit_dinov2.extractor import ViTExtractor_v2
1011
from oml.models.vit_unicom.extractor import ViTUnicomExtractor

oml/models/vit_dinov2/__init__.py

Whitespace-only changes.

oml/models/vit_dinov2/external/__init__.py

Whitespace-only changes.
+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# References:
2+
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/config.py
3+
4+
import os
5+
6+
import torch
7+
8+
# use torch.scaled_dot_product_attention where possible
9+
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention")
10+
_USE_FUSED_ATTN = int(os.environ.get("USE_FUSED_ATTN", 0))
11+
12+
13+
def use_fused_attn() -> bool:
14+
# NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
15+
if not _HAS_FUSED_ATTN:
16+
return False
17+
return _USE_FUSED_ATTN > 0
+138
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# type: ignore
2+
# flake8: noqa
3+
4+
5+
from enum import Enum
6+
from typing import Union
7+
8+
import torch
9+
10+
import oml.models.vit_dinov2.external.vision_transformer as vits
11+
from oml.const import CKPT_SAVE_ROOT
12+
13+
# ============== CODE FROM DINOV2 ==============
14+
# https://github.com/facebookresearch/dinov2/blob/main/dinov2/hub/backbones.py
15+
16+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
17+
18+
19+
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
20+
compact_arch_name = arch_name.replace("_", "")[:4]
21+
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
22+
return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
23+
24+
25+
class Weights(Enum):
26+
LVD142M = "LVD142M"
27+
28+
29+
def _make_dinov2_model(
30+
*,
31+
arch_name: str = "vit_large",
32+
img_size: int = 518,
33+
patch_size: int = 14,
34+
init_values: float = 1.0,
35+
ffn_layer: str = "mlp",
36+
block_chunks: int = 0,
37+
num_register_tokens: int = 0,
38+
interpolate_antialias: bool = False,
39+
interpolate_offset: float = 0.1,
40+
pretrained: bool = True,
41+
weights: Union[Weights, str] = Weights.LVD142M,
42+
**kwargs,
43+
):
44+
if isinstance(weights, str):
45+
try:
46+
weights = Weights[weights]
47+
except KeyError as e:
48+
raise AssertionError(f"Unsupported weights: {weights}") from e
49+
50+
model_base_name = _make_dinov2_model_name(arch_name, patch_size)
51+
vit_kwargs = {
52+
"img_size": img_size,
53+
"patch_size": patch_size,
54+
"init_values": init_values,
55+
"ffn_layer": ffn_layer,
56+
"block_chunks": block_chunks,
57+
"num_register_tokens": num_register_tokens,
58+
"interpolate_antialias": interpolate_antialias,
59+
"interpolate_offset": interpolate_offset,
60+
}
61+
vit_kwargs.update(**kwargs)
62+
model = vits.__dict__[arch_name](**vit_kwargs)
63+
64+
if pretrained:
65+
model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
66+
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
67+
state_dict = torch.hub.load_state_dict_from_url(
68+
url, map_location="cpu", model_dir=str(CKPT_SAVE_ROOT.resolve())
69+
)
70+
model.load_state_dict(state_dict, strict=True)
71+
72+
return model
73+
74+
75+
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
76+
"""
77+
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
78+
"""
79+
return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
80+
81+
82+
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
83+
"""
84+
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
85+
"""
86+
return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
87+
88+
89+
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
90+
"""
91+
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
92+
"""
93+
return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
94+
95+
96+
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
97+
"""
98+
DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset
99+
"""
100+
return _make_dinov2_model(
101+
arch_name="vit_small",
102+
pretrained=pretrained,
103+
weights=weights,
104+
num_register_tokens=4,
105+
interpolate_antialias=True,
106+
interpolate_offset=0.0,
107+
**kwargs,
108+
)
109+
110+
111+
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
112+
"""
113+
DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset
114+
"""
115+
return _make_dinov2_model(
116+
arch_name="vit_base",
117+
pretrained=pretrained,
118+
weights=weights,
119+
num_register_tokens=4,
120+
interpolate_antialias=True,
121+
interpolate_offset=0.0,
122+
**kwargs,
123+
)
124+
125+
126+
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
127+
"""
128+
DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset
129+
"""
130+
return _make_dinov2_model(
131+
arch_name="vit_large",
132+
pretrained=pretrained,
133+
weights=weights,
134+
num_register_tokens=4,
135+
interpolate_antialias=True,
136+
interpolate_offset=0.0,
137+
**kwargs,
138+
)

0 commit comments

Comments
 (0)