@@ -964,11 +964,13 @@ def _convert_openai_clip(
             v = v.unsqueeze(0)
             if v.shape[1] != model.pos_embed.shape[1]:
                 # To resize pos embedding when using model at different size from pretrained weights
-                v = resize_pos_embed(
+                num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
+                    else getattr(model, 'num_prefix_tokens', 1)
+                v = resample_abs_pos_embed(
                     v,
-                    model.pos_embed,
-                    0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
-                    model.patch_embed.grid_size
+                    new_size=model.patch_embed.grid_size,
+                    num_prefix_tokens=num_prefix_tokens,
+                    verbose=True,
                 )
         out_dict[k] = v
     return out_dict
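
The hunk above swaps the positional-argument call to resize_pos_embed for a keyword call to resample_abs_pos_embed from timm.layers. A minimal standalone sketch of that call follows; the 14x14 -> 16x16 grid sizes and the 768-dim embedding are illustrative values, not taken from the diff.

import torch
from timm.layers import resample_abs_pos_embed

# [1, cls_token + H*W, dim] layout, as stored in a CLIP ViT checkpoint (illustrative shape).
pos_embed = torch.randn(1, 1 + 14 * 14, 768)
resized = resample_abs_pos_embed(
    pos_embed,
    new_size=(16, 16),        # target patch grid, e.g. model.patch_embed.grid_size
    num_prefix_tokens=1,      # class token is split off and not interpolated
    verbose=True,             # log the old -> new grid size
)
print(resized.shape)          # torch.Size([1, 257, 768]) == [1, 1 + 16*16, 768]
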
@@ -1735,6 +1737,27 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 384, 384),
         num_classes=0),
 
+    'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+
     'vit_medium_patch16_reg4_256': _cfg(
         input_size=(3, 256, 256)),
     'vit_medium_patch16_reg4_gap_256': _cfg(
@@ -2073,6 +2096,44 @@ def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTrans
     return model
 
 
+@register_model
+def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 8M
+    model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_xsmall_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 40M
+    model_args = dict(
+        patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_medium_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 39M
+    model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_medium_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 61M
+    model_args = dict(
+        patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_betwixt_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
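
For context, a hedged usage sketch of one of the model variants registered above. It assumes the TinyCLIP image-tower weights referenced by the new default_cfgs entries (hf_hub_id='timm/', open_clip_pytorch_model.bin) are published on the Hugging Face Hub under the same pretrained tags.

import timm
import torch

# Pretrained tag added in this diff; weights are pulled via the hf_hub fields in the cfg.
model = timm.create_model('vit_xsmall_patch16_clip_224.tinyclip_yfcc15m', pretrained=True)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)        # num_classes=512 in the cfg -> 512-dim CLIP image embedding
print(out.shape)          # torch.Size([1, 512])
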