-
-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathload_model.py
71 lines (54 loc) · 1.99 KB
/
load_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
""" Load a model from its checkpoint directory """
import glob
import os
import hydra
import torch
from pyaml_env import parse_config
from pvnet.models.ensemble import Ensemble
from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel
def _resolve_checkpoint_path(path: str, val_best: bool) -> str:
    """Return the path of the checkpoint file to load from directory ``path``.

    Args:
        path: Checkpoint directory of a single model.
        val_best: If True, expect exactly one ``epoch*.ckpt`` file in ``path``;
            otherwise use ``last.ckpt``.

    Returns:
        Path of the checkpoint file (existence of ``last.ckpt`` is not checked here).

    Raises:
        ValueError: if ``val_best`` is True and ``path`` does not contain exactly
            one ``epoch*.ckpt`` file.
    """
    if not val_best:
        return os.path.join(path, "last.ckpt")

    # Only one epoch (best) saved per model. Escape the directory part so glob
    # metacharacters in ``path`` (e.g. ``run[1]``) are matched literally.
    files = glob.glob(os.path.join(glob.escape(path), "epoch*.ckpt"))
    if len(files) != 1:
        pattern = os.path.join(path, "epoch*.ckpt")
        raise ValueError(
            f"Found {len(files)} checkpoints @ {pattern}. Expected one."
        )
    return files[0]


def get_model_from_checkpoints(
    checkpoint_dir_paths: list[str],
    val_best: bool = True,
):
    """Load a model from its checkpoint directory, or an ensemble from several.

    Each directory must contain ``model_config.yaml`` (instantiated via Hydra) and
    a checkpoint: a single ``epoch*.ckpt`` file when ``val_best`` is True, or
    ``last.ckpt`` otherwise. A ``data_config.yaml`` file is optional.

    Args:
        checkpoint_dir_paths: One or more checkpoint directories. When more than
            one is given, the loaded models are wrapped in an ``Ensemble``.
        val_best: If True, load the single ``epoch*.ckpt`` checkpoint in each
            directory; otherwise load ``last.ckpt``.

    Returns:
        Tuple ``(model, model_config, data_config)``:
        - ``model``: the loaded model, or an ``Ensemble`` of all loaded models.
        - ``model_config``: the model's config, or for an ensemble a dict with
          ``_target_`` set to the Ensemble class and ``model_list`` holding the
          per-member configs.
        - ``data_config``: path to the first directory's ``data_config.yaml``,
          or None if that file does not exist.

    Raises:
        ValueError: if ``checkpoint_dir_paths`` is empty, or if ``val_best`` is
            True and a directory does not contain exactly one ``epoch*.ckpt``.
    """
    if not checkpoint_dir_paths:
        raise ValueError("checkpoint_dir_paths must contain at least one directory")

    model_configs = []
    models = []
    data_configs = []

    for path in checkpoint_dir_paths:
        # Resolve/validate the checkpoint file before building the model so a
        # missing or ambiguous checkpoint fails without instantiating anything.
        checkpoint_path = _resolve_checkpoint_path(path, val_best)

        # Build the (untrained) model from its config
        model_config = parse_config(os.path.join(path, "model_config.yaml"))
        model = hydra.utils.instantiate(model_config)

        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict=checkpoint["state_dict"])

        if isinstance(model, UMTModel):
            # Unimodal-teacher models are swapped for the multimodal model (and
            # config) returned by their conversion method.
            model, model_config = model.convert_to_multimodal_model(model_config)

        # The data config is optional; record None when it is absent
        data_config = os.path.join(path, "data_config.yaml")
        data_configs.append(data_config if os.path.isfile(data_config) else None)

        model_configs.append(model_config)
        models.append(model)

    if len(models) > 1:
        model_config = {
            "_target_": "pvnet.models.ensemble.Ensemble",
            "model_list": model_configs,
        }
        model = Ensemble(model_list=models)
    else:
        model_config = model_configs[0]
        model = models[0]

    # Only the first directory's data config is returned, also for ensembles.
    data_config = data_configs[0]

    return model, model_config, data_config