@@ -151,13 +151,13 @@ def __init__(
151
151
visual = torch .jit .load (Path (weights ), map_location = "cpu" ).visual
152
152
patch_device_and_float (visual , device = "cpu" )
153
153
state_dict = visual .state_dict ()
154
+
154
155
else :
155
156
state_dict = torch .load (Path (weights ), map_location = "cpu" )
156
157
state_dict = state_dict .get ("state_dict" , state_dict )
157
-
158
- state_dict = remove_criterion_in_state_dict (state_dict )
159
- state_dict = remove_prefix_from_state_dict (state_dict , trial_key = "conv1.weight" )
160
- state_dict = take_visual_part_of_vit_clip (state_dict , needed_keys = self .visual .state_dict ().keys ())
158
+ state_dict = remove_criterion_in_state_dict (state_dict )
159
+ state_dict = take_visual_part_of_vit_clip (state_dict , needed_keys = self .visual .state_dict ().keys ())
160
+ state_dict = remove_prefix_from_state_dict (state_dict , trial_key = "conv1.weight" )
161
161
162
162
self .visual .load_state_dict (state_dict = state_dict , strict = True )
163
163
@@ -178,8 +178,9 @@ def feat_dim(self) -> int:
178
178
179
179
def take_visual_part_of_vit_clip(state_dict: TStateDict, needed_keys: Iterable[str]) -> TStateDict:
    """Rename ``...visual.<name>`` keys to ``<name>`` and filter the result.

    Every key that contains ``visual`` as a whole dot-separated component
    followed by at least one more component (for example
    ``visual.conv1.weight`` or ``model.visual.conv1.weight``) is renamed to the
    part after that component (``conv1.weight``). Keys that merely contain the
    substring (``audiovisual_proj.weight``, ``visual_projection.weight``) or
    that end with ``visual`` are left untouched, so they are never turned into
    truncated or empty names.

    Args:
        state_dict: Mapping from parameter names to values. Keys are renamed
            in place before filtering.
        needed_keys: Keys forwarded to ``filter_state_dict``.

    Returns:
        The value returned by ``filter_state_dict`` for the renamed mapping.
    """
    marker = "visual"
    # Iterate over a snapshot of the keys: the dict is mutated inside the loop.
    for k in list(state_dict):
        parts = k.split(".")
        # Look only at non-final components so that something remains after
        # stripping the prefix.
        if marker in parts[:-1]:
            # The first occurrence wins, mirroring a left-to-right prefix strip.
            new_key = ".".join(parts[parts.index(marker) + 1:])
            state_dict[new_key] = state_dict.pop(k)

    state_dict = filter_state_dict(state_dict, needed_keys=needed_keys)
    return state_dict
185
186
0 commit comments