
Commit 2e30995

Some fixes for CLIP
1 parent 9462553 commit 2e30995

3 files changed (+8 −12 lines)

oml/models/utils.py (+4 −1)

@@ -16,7 +16,10 @@ def find_prefix_in_state_dict(state_dict: TStateDict, trial_key: str) -> str:
     k0 = [k for k in state_dict.keys() if trial_key in k][0]
     prefix = k0[: k0.index(trial_key)]
 
-    assert all(k.startswith(prefix) for k in state_dict.keys())
+    keys_not_starting_with_prefix = list(filter(lambda x: not x.startswith(prefix), state_dict.keys()))
+    assert (
+        not keys_not_starting_with_prefix
+    ), f"There are keys not starting from the found prefix {prefix}: {keys_not_starting_with_prefix}"
 
     return prefix
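
Note: the reworked assertion now reports which keys break the prefix check instead of failing silently. A minimal, self-contained sketch of the same logic (plain dicts stand in for TStateDict; the key names are made up for illustration):

def find_prefix(state_dict, trial_key):
    # Same idea as find_prefix_in_state_dict: take the first key containing trial_key
    # and treat everything before it as the shared prefix.
    k0 = [k for k in state_dict.keys() if trial_key in k][0]
    prefix = k0[: k0.index(trial_key)]
    # New behaviour: collect the offending keys and show them in the error message.
    bad_keys = [k for k in state_dict if not k.startswith(prefix)]
    assert not bad_keys, f"There are keys not starting from the found prefix {prefix}: {bad_keys}"
    return prefix

print(find_prefix({"visual.conv1.weight": 0, "visual.ln_post.bias": 1}, trial_key="conv1.weight"))  # "visual."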

oml/models/vit_clip/extractor.py (+4 −10)

@@ -93,13 +93,6 @@ class ViTCLIPExtractor(IExtractor):
             "fname": "openai_vitl14_224.ckpt",
             "init_args": {"arch": "vitl14_224", "normalise_features": False},
         },
-        "openai_vitl14_336": {
-            "url": f"{_OPENAI_URL}/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
-            "hash": "b311058cae50cb10fbfa2a44231c9473",
-            "is_jitted": True,
-            "fname": "openai_vitl14_336.ckpt",
-            "init_args": {"arch": "vitl14_336", "normalise_features": False},
-        },
         # checkpoints pretrained by SberbankAI
         "sber_vitb16_224": {
             "url": f"{_SBER_URL}/ruclip-vit-base-patch16-224/resolve/main/pytorch_model.bin",
@@ -161,9 +154,10 @@ def __init__(
         else:
             state_dict = torch.load(Path(weights), map_location="cpu")
             state_dict = state_dict.get("state_dict", state_dict)
-            state_dict = remove_criterion_in_state_dict(state_dict)
-            state_dict = remove_prefix_from_state_dict(state_dict, trial_key="class_embedding")
-            state_dict = take_visual_part_of_vit_clip(state_dict, needed_keys=self.visual.state_dict().keys())
+
+        state_dict = remove_criterion_in_state_dict(state_dict)
+        state_dict = remove_prefix_from_state_dict(state_dict, trial_key="conv1.weight")
+        state_dict = take_visual_part_of_vit_clip(state_dict, needed_keys=self.visual.state_dict().keys())
 
         self.visual.load_state_dict(state_dict=state_dict, strict=True)
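
Note: besides dropping the openai_vitl14_336 checkpoint entry, the second hunk switches the prefix-detection key passed to remove_prefix_from_state_dict from "class_embedding" to "conv1.weight", and the clean-up calls appear to be hoisted out of the else branch so they run for every loaded state dict. A rough sketch of the prefix-stripping step with the new trial key; the helper name and checkpoint keys below are made up for illustration, the real logic lives in remove_prefix_from_state_dict:

def strip_common_prefix(state_dict, trial_key="conv1.weight"):
    # Detect the prefix from the key that contains trial_key, then strip it everywhere.
    k0 = next(k for k in state_dict if trial_key in k)
    prefix = k0[: k0.index(trial_key)]
    return {k[len(prefix):]: v for k, v in state_dict.items()}

ckpt = {"model.visual.conv1.weight": 0, "model.visual.ln_post.bias": 1}
print(strip_common_prefix(ckpt))  # {'conv1.weight': 0, 'ln_post.bias': 1}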

oml/registry/transforms.py (−1)

@@ -67,7 +67,6 @@ def get_transforms_by_cfg(cfg: TCfg) -> TTransforms:
     "openai_vitb32_224": get_normalisation_resize_albu_clip(im_size=224),
     "openai_vitb16_224": get_normalisation_resize_albu_clip(im_size=224),
     "openai_vitl14_224": get_normalisation_resize_albu_clip(im_size=224),
-    "openai_vitl14_336": get_normalisation_resize_albu_clip(im_size=224),
     "vits16_inshop": get_normalisation_resize_hypvit(im_size=224, crop_size=224),
     "vits16_sop": get_normalisation_resize_hypvit(im_size=224, crop_size=224),
     "vits16_cars": get_normalisation_resize_albu(im_size=224),
