Skip to content

Commit 9462553

Browse files
authored
Fixes for CLIP
1 parent 58df7a3 commit 9462553

File tree

6 files changed

+26
-30
lines changed

6 files changed

+26
-30
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ repos:
5050
- id: mypy
5151
additional_dependencies: [types-requests==2.25.9]
5252

53-
# check for unused code
54-
- repo: https://github.com/jendrikseipp/vulture
55-
rev: v2.6
56-
hooks:
57-
- id: vulture
58-
args: [--min-confidence=100, --sort-by-size, .]
53+
# check for unused code (todo: it started failing with some weird recursive error)
54+
#- repo: https://github.com/jendrikseipp/vulture
55+
# rev: v2.6
56+
# hooks:
57+
# - id: vulture
58+
# args: [--min-confidence=100, --sort-by-size, .]

oml/models/vit_clip/extractor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
filter_state_dict,
1010
patch_device_and_float,
1111
remove_criterion_in_state_dict,
12+
remove_prefix_from_state_dict,
1213
)
1314
from oml.models.vit_clip.external.model import VisionTransformer
1415
from oml.utils.io import download_checkpoint
@@ -142,6 +143,8 @@ def __init__(
142143
self.normalize = normalise_features
143144
self.visual = self.constructors[arch]()
144145

146+
self.input_size = int(arch.split("_")[-1])
147+
145148
if weights is None:
146149
return
147150
if weights in self.pretrained_models:
@@ -159,11 +162,16 @@ def __init__(
159162
state_dict = torch.load(Path(weights), map_location="cpu")
160163
state_dict = state_dict.get("state_dict", state_dict)
161164
state_dict = remove_criterion_in_state_dict(state_dict)
165+
state_dict = remove_prefix_from_state_dict(state_dict, trial_key="class_embedding")
162166
state_dict = take_visual_part_of_vit_clip(state_dict, needed_keys=self.visual.state_dict().keys())
163167

164168
self.visual.load_state_dict(state_dict=state_dict, strict=True)
165169

166170
def forward(self, x: torch.Tensor) -> torch.Tensor:
171+
assert (x.shape[-2] == self.input_size) and (
172+
x.shape[-1] == self.input_size
173+
), f"The model expects input images to be resized to {self.input_size}x{self.input_size}"
174+
167175
res = self.visual.forward(x)
168176
if self.normalize:
169177
res = res / torch.linalg.norm(res, 2, dim=1, keepdim=True).detach()

tests/test_oml/test_models/test_models_creation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
vit_args = {"normalise_features": False, "use_multi_scale": False, "arch": "vits16"}
2121

2222

23+
# todo: add another test where Lightning saves the model
2324
@pytest.mark.parametrize(
2425
"constructor,args",
2526
[

tests/test_runs/test_pipelines/configs/train.yaml

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,38 +14,32 @@ cache_size: 10
1414
transforms_train:
1515
name: augs_torch
1616
args:
17-
im_size: 64
17+
im_size: 224
1818

1919
transforms_val:
2020
name: norm_resize_torch
2121
args:
22-
im_size: 64
22+
im_size: 224
2323

2424
criterion:
2525
name: arcface
2626
args:
2727
smoothing_epsilon: 0
2828
m: 0.4
2929
s: 64
30-
in_features: 384
30+
in_features: 512
3131
num_classes: 4
3232

3333
defaults:
3434
- optimizer: sgd
3535
- sampler: balance
3636

3737
extractor:
38-
name: extractor_with_mlp
38+
name: vit_clip
3939
args:
40-
mlp_features: [384]
40+
arch: vitb16_224
4141
weights: null
42-
extractor:
43-
name: vit
44-
args:
45-
normalise_features: False
46-
use_multi_scale: False
47-
weights: null
48-
arch: vits16
42+
normalise_features: False
4943

5044
scheduling:
5145
scheduler_interval: epoch

tests/test_runs/test_pipelines/configs/validate.yaml

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ dataframe_name: df.csv
77
transforms_val:
88
name: norm_resize_torch
99
args:
10-
im_size: 48
10+
im_size: 224
1111

1212
logs_root: logs
1313

@@ -47,18 +47,11 @@ metric_args:
4747
visualize_only_overall_category: True
4848

4949
extractor:
50-
name: extractor_with_mlp
50+
name: vit_clip
5151
args:
52-
mlp_features: [384]
52+
arch: vitb16_224
5353
weights: checkpoints/best.ckpt
54-
extractor:
55-
name: vit
56-
args:
57-
normalise_features: False
58-
use_multi_scale: False
59-
weights: null
60-
arch: vits16
61-
54+
normalise_features: False
6255

6356
hydra:
6457
run:

tests/test_runs/test_pipelines/test_pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def run(file: str, accelerator: str, devices: int, need_rm_logs: bool = True) ->
4545
@pytest.mark.parametrize("accelerator, devices", accelerator_devices_pairs())
4646
def test_train_and_validate(accelerator: str, devices: int) -> None:
4747
run("train.py", accelerator, devices, need_rm_logs=False)
48-
# it takes checpoints from the train stage
48+
# it takes checkpoints from the train stage
4949
run("validate.py", accelerator, devices, need_rm_logs=False)
5050

5151
for file in ["train.py", "validate.py"]:

0 commit comments

Comments
 (0)