@@ -93,13 +93,6 @@ class ViTCLIPExtractor(IExtractor):
93
93
"fname" : "openai_vitl14_224.ckpt" ,
94
94
"init_args" : {"arch" : "vitl14_224" , "normalise_features" : False },
95
95
},
96
- "openai_vitl14_336" : {
97
- "url" : f"{ _OPENAI_URL } /3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt" ,
98
- "hash" : "b311058cae50cb10fbfa2a44231c9473" ,
99
- "is_jitted" : True ,
100
- "fname" : "openai_vitl14_336.ckpt" ,
101
- "init_args" : {"arch" : "vitl14_336" , "normalise_features" : False },
102
- },
103
96
# checkpoints pretrained by SberbankAI
104
97
"sber_vitb16_224" : {
105
98
"url" : f"{ _SBER_URL } /ruclip-vit-base-patch16-224/resolve/main/pytorch_model.bin" ,
@@ -161,9 +154,10 @@ def __init__(
161
154
else :
162
155
state_dict = torch .load (Path (weights ), map_location = "cpu" )
163
156
state_dict = state_dict .get ("state_dict" , state_dict )
164
- state_dict = remove_criterion_in_state_dict (state_dict )
165
- state_dict = remove_prefix_from_state_dict (state_dict , trial_key = "class_embedding" )
166
- state_dict = take_visual_part_of_vit_clip (state_dict , needed_keys = self .visual .state_dict ().keys ())
157
+
158
+ state_dict = remove_criterion_in_state_dict (state_dict )
159
+ state_dict = remove_prefix_from_state_dict (state_dict , trial_key = "conv1.weight" )
160
+ state_dict = take_visual_part_of_vit_clip (state_dict , needed_keys = self .visual .state_dict ().keys ())
167
161
168
162
self .visual .load_state_dict (state_dict = state_dict , strict = True )
169
163
0 commit comments