Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
alex authored and alex committed Feb 2, 2024
1 parent 4f43b5f commit 5146a27
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
10 changes: 8 additions & 2 deletions oml/models/vit_clip/extractor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import torch
from pathlib import Path
from typing import Any, Dict, Iterable, Optional

import torch

from oml.interfaces.models import IExtractor
from oml.models.utils import (
TStateDict,
filter_state_dict,
patch_device_and_float,
remove_criterion_in_state_dict,
)
from oml.models.utils import remove_prefix_from_state_dict
from oml.models.vit_clip.external.model import VisionTransformer
from oml.utils.io import download_checkpoint

Expand Down Expand Up @@ -142,6 +142,8 @@ def __init__(
self.normalize = normalise_features
self.visual = self.constructors[arch]()

self.input_size = int(arch.split("_")[-1])

if weights is None:
return
if weights in self.pretrained_models:
Expand All @@ -159,11 +161,15 @@ def __init__(
state_dict = torch.load(Path(weights), map_location="cpu")
state_dict = state_dict.get("state_dict", state_dict)
state_dict = remove_criterion_in_state_dict(state_dict)
state_dict = remove_prefix_from_state_dict(state_dict, trial_key="class_embedding")
state_dict = take_visual_part_of_vit_clip(state_dict, needed_keys=self.visual.state_dict().keys())

self.visual.load_state_dict(state_dict=state_dict, strict=True)

def forward(self, x: torch.Tensor) -> torch.Tensor:
assert (x.shape[-2] == self.input_size) and (x.shape[-1] == self.input_size), \
f"The model expects input images to be resized to {self.input_size}x{self.input_size}"

res = self.visual.forward(x)
if self.normalize:
res = res / torch.linalg.norm(res, 2, dim=1, keepdim=True).detach()
Expand Down
1 change: 1 addition & 0 deletions tests/test_oml/test_models/test_models_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
vit_args = {"normalise_features": False, "use_multi_scale": False, "arch": "vits16"}


# todo: add another test where Lightning saves the model
@pytest.mark.parametrize(
"constructor,args",
[
Expand Down
18 changes: 6 additions & 12 deletions tests/test_runs/test_pipelines/configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,32 @@ cache_size: 10
transforms_train:
name: augs_torch
args:
im_size: 64
im_size: 224

transforms_val:
name: norm_resize_torch
args:
im_size: 64
im_size: 224

criterion:
name: arcface
args:
smoothing_epsilon: 0
m: 0.4
s: 64
in_features: 384
in_features: 512
num_classes: 4

defaults:
- optimizer: sgd
- sampler: balance

extractor:
name: extractor_with_mlp
name: vit_clip
args:
mlp_features: [384]
arch: vitb16_224
weights: null
extractor:
name: vit
args:
normalise_features: False
use_multi_scale: False
weights: null
arch: vits16
normalise_features: False

scheduling:
scheduler_interval: epoch
Expand Down

0 comments on commit 5146a27

Please sign in to comment.