Skip to content

Commit 7d80e80

Browse files
authored
Model hot fix
Model hot fix
1 parent 2e30995 commit 7d80e80

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

oml/models/utils.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ def remove_criterion_in_state_dict(state_dict: TStateDict) -> TStateDict:
1313

1414

1515
def find_prefix_in_state_dict(state_dict: TStateDict, trial_key: str) -> str:
16-
k0 = [k for k in state_dict.keys() if trial_key in k][0]
16+
keys_starting_with_trial_key = [k for k in state_dict.keys() if trial_key in k]
17+
assert keys_starting_with_trial_key, (
18+
f"There are no keys starting from {trial_key}.\n" f"The existing keys are: {list(state_dict.keys())}"
19+
)
20+
21+
k0 = keys_starting_with_trial_key[0]
1722
prefix = k0[: k0.index(trial_key)]
1823

1924
keys_not_starting_with_prefix = list(filter(lambda x: not x.startswith(prefix), state_dict.keys()))

oml/models/vit_clip/extractor.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,13 @@ def __init__(
151151
visual = torch.jit.load(Path(weights), map_location="cpu").visual
152152
patch_device_and_float(visual, device="cpu")
153153
state_dict = visual.state_dict()
154+
154155
else:
155156
state_dict = torch.load(Path(weights), map_location="cpu")
156157
state_dict = state_dict.get("state_dict", state_dict)
157-
158-
state_dict = remove_criterion_in_state_dict(state_dict)
159-
state_dict = remove_prefix_from_state_dict(state_dict, trial_key="conv1.weight")
160-
state_dict = take_visual_part_of_vit_clip(state_dict, needed_keys=self.visual.state_dict().keys())
158+
state_dict = remove_criterion_in_state_dict(state_dict)
159+
state_dict = take_visual_part_of_vit_clip(state_dict, needed_keys=self.visual.state_dict().keys())
160+
state_dict = remove_prefix_from_state_dict(state_dict, trial_key="conv1.weight")
161161

162162
self.visual.load_state_dict(state_dict=state_dict, strict=True)
163163

@@ -178,8 +178,9 @@ def feat_dim(self) -> int:
178178

179179
def take_visual_part_of_vit_clip(state_dict: TStateDict, needed_keys: Iterable[str]) -> TStateDict:
180180
for k in list(state_dict):
181-
if k.startswith("visual."):
182-
state_dict[k.lstrip("visual")[1:]] = state_dict.pop(k)
181+
if "visual" in k:
182+
new_key = k[k.find("visual") + len("visual") + 1 :]
183+
state_dict[new_key] = state_dict.pop(k)
183184
state_dict = filter_state_dict(state_dict, needed_keys=needed_keys)
184185
return state_dict
185186

0 commit comments

Comments
 (0)