Commit 860d055 1 parent 9462553 commit 860d055 Copy full SHA for 860d055
File tree 1 file changed +4
-3
lines changed
1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -161,9 +161,10 @@ def __init__(
161
161
else :
162
162
state_dict = torch .load (Path (weights ), map_location = "cpu" )
163
163
state_dict = state_dict .get ("state_dict" , state_dict )
164
- state_dict = remove_criterion_in_state_dict (state_dict )
165
- state_dict = remove_prefix_from_state_dict (state_dict , trial_key = "class_embedding" )
166
- state_dict = take_visual_part_of_vit_clip (state_dict , needed_keys = self .visual .state_dict ().keys ())
164
+
165
+ state_dict = take_visual_part_of_vit_clip (state_dict , needed_keys = self .visual .state_dict ().keys ())
166
+ state_dict = remove_prefix_from_state_dict (state_dict , trial_key = "class_embedding" )
167
+ state_dict = remove_criterion_in_state_dict (state_dict )
167
168
168
169
self .visual .load_state_dict (state_dict = state_dict , strict = True )
169
170
You can’t perform that action at this time.
0 commit comments