
Commit 492947d

Merge pull request #2113 from gau-nernst/tinyclip
Add TinyCLIP
2 parents 111fad1 + 256cf19 commit 492947d

File tree: 1 file changed


timm/models/vision_transformer.py (+65 −4)
@@ -964,11 +964,13 @@ def _convert_openai_clip(
             v = v.unsqueeze(0)
             if v.shape[1] != model.pos_embed.shape[1]:
                 # To resize pos embedding when using model at different size from pretrained weights
-                v = resize_pos_embed(
+                num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
+                    else getattr(model, 'num_prefix_tokens', 1)
+                v = resample_abs_pos_embed(
                     v,
-                    model.pos_embed,
-                    0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
-                    model.patch_embed.grid_size
+                    new_size=model.patch_embed.grid_size,
+                    num_prefix_tokens=num_prefix_tokens,
+                    verbose=True,
                 )
         out_dict[k] = v
     return out_dict
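
A hedged usage sketch of when the rewritten branch above fires: the resample only happens if the checkpoint's pos_embed grid differs from the model's, e.g. when the model is created at a non-default image size (this assumes the TinyCLIP weights referenced below are published under the 'timm/' Hub id):

import timm

# Hypothetical example: load a TinyCLIP checkpoint at a non-default resolution.
# The pretrained pos_embed covers a 14x14 patch grid (patch16 @ 224); at img_size=160
# the model expects a 10x10 grid, so _convert_openai_clip() takes the
# resample_abs_pos_embed() branch above (verbose=True logs the old/new grid sizes).
model = timm.create_model(
    'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m',  # pretrained tag added later in this diff
    pretrained=True,
    img_size=160,  # any size other than 224 triggers the resize path
)

Passing num_prefix_tokens explicitly (0 when no_embed_class is set) keeps class/register tokens out of the interpolation, which is what the two new lines before the resample_abs_pos_embed() call handle.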
@@ -1735,6 +1737,27 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 384, 384),
         num_classes=0),

+    'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+
     'vit_medium_patch16_reg4_256': _cfg(
         input_size=(3, 256, 256)),
     'vit_medium_patch16_reg4_gap_256': _cfg(
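
A sketch of how these new pretrained tags surface through the public timm API, assuming the 'timm/' Hugging Face Hub uploads of open_clip_pytorch_model.bin referenced above exist:

import timm
from timm.data import resolve_model_data_config, create_transform

# List the TinyCLIP tags; expected to include the four entries added above,
# e.g. 'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m'.
print(timm.list_pretrained('vit_*tinyclip*'))

model = timm.create_model('vit_betwixt_patch32_clip_224.tinyclip_laion400m', pretrained=True)
data_cfg = resolve_model_data_config(model)      # picks up OPENAI_CLIP_MEAN/STD from the cfg
transform = create_transform(**data_cfg, is_training=False)
print(data_cfg['mean'], data_cfg['std'])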
@@ -2073,6 +2096,44 @@ def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 8M
+    model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_xsmall_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 40M
+    model_args = dict(
+        patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_medium_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 39M
+    model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_medium_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 61M
+    model_args = dict(
+        patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_betwixt_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
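A quick sanity-check sketch for the four new entrypoints; it needs no pretrained weights, and the printed parameter counts are only a rough match for the TinyCLIP 8M/40M/39M/61M image towers named in the comments above:

import timm
import torch

# num_classes=0 drops the head, so the output dimension equals embed_dim
# (256 / 512 / 512 / 640 for the four architectures).
for name in (
        'vit_xsmall_patch16_clip_224',
        'vit_medium_patch32_clip_224',
        'vit_medium_patch16_clip_224',
        'vit_betwixt_patch32_clip_224',
):
    m = timm.create_model(name, pretrained=False, num_classes=0)
    n_params = sum(p.numel() for p in m.parameters())
    out = m(torch.randn(1, 3, 224, 224))
    print(f'{name}: {n_params / 1e6:.1f}M params, features {tuple(out.shape)}')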