
Commit 572f050

Swin and FocalNet weights on HF hub. Add model deprecation functionality w/ some registry tweaks.

Parent: 2fc5ac3

22 files changed: +546 -435 lines

timm/models/__init__.py  (+5 -5)

@@ -74,12 +74,12 @@
 from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
     register_notrace_module, is_notrace_module, get_notrace_modules, \
     register_notrace_function, is_notrace_function, get_notrace_functions
-from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
+from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
 from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
 from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
     group_modules, group_parameters, checkpoint_seq, adapt_input_conv
-from ._pretrained import PretrainedCfg, DefaultCfg, \
-    filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
+from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
 from ._prune import adapt_model_from_string
-from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
-    is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
+from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
+    register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
+    is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
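
The commit's headline registry change surfaces here: split_model_name_tag, generate_default_cfgs, register_model_deprecations and get_deprecated_models are now imported from ._registry. As a rough sketch of the re-exported query helpers (the wildcard patterns and module name are illustrative, and get_deprecated_models is assumed to take a module name and return an old-to-new name mapping):

import timm  # importing timm populates the model registry
from timm.models import list_models, list_pretrained, get_deprecated_models

# registry queries are re-exported at the timm.models package level after this change
print(list_models('focalnet*')[:3])        # architecture names matching a wildcard
print(list_pretrained('swin_base*')[:3])   # 'model.tag' names that have pretrained weights

# assumed behaviour: mapping of deprecated names registered by a model module
print(get_deprecated_models('swin_transformer'))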

timm/models/_factory.py  (+2 -2)

@@ -3,10 +3,10 @@
 from urllib.parse import urlsplit

 from timm.layers import set_layer_config
-from ._pretrained import PretrainedCfg, split_model_name_tag
 from ._helpers import load_checkpoint
 from ._hub import load_model_config_from_hf
-from ._registry import is_model, model_entrypoint
+from ._pretrained import PretrainedCfg
+from ._registry import is_model, model_entrypoint, split_model_name_tag


 __all__ = ['parse_model_name', 'safe_model_name', 'create_model']
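
For _factory.py the change is only an import shuffle (split_model_name_tag now comes from ._registry), so callers are unaffected. A small sketch of the helpers this module exports, assuming parse_model_name and safe_model_name remain re-exported from timm.models (the hub id below is hypothetical):

from timm.models import parse_model_name, safe_model_name

# parse_model_name splits an optional source prefix from the model identifier
print(parse_model_name('hf-hub:timm/some-model'))  # ('hf-hub', 'timm/some-model')
print(parse_model_name('resnet50'))                # ('timm', 'resnet50')

# safe_model_name returns a filesystem-friendly form of the name
print(safe_model_name('resnet50.a1_in1k'))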

timm/models/_helpers.py  (+38 -14)

@@ -5,6 +5,7 @@
 import logging
 import os
 from collections import OrderedDict
+from typing import Any, Callable, Dict, Optional, Union

 import torch
 try:
@@ -13,30 +14,32 @@
 except ImportError:
     _has_safetensors = False

-import timm.models._builder
-
 _logger = logging.getLogger(__name__)

-__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
+__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']


-def clean_state_dict(state_dict):
+def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
     # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
-    cleaned_state_dict = OrderedDict()
+    cleaned_state_dict = {}
     for k, v in state_dict.items():
         name = k[7:] if k.startswith('module.') else k
         cleaned_state_dict[name] = v
     return cleaned_state_dict


-def load_state_dict(checkpoint_path, use_ema=True):
+def load_state_dict(
+        checkpoint_path: str,
+        use_ema: bool = True,
+        device: Union[str, torch.device] = 'cpu',
+) -> Dict[str, Any]:
     if checkpoint_path and os.path.isfile(checkpoint_path):
         # Check if safetensors or not and load weights accordingly
         if str(checkpoint_path).endswith(".safetensors"):
             assert _has_safetensors, "`pip install safetensors` to use .safetensors"
-            checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
+            checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
         else:
-            checkpoint = torch.load(checkpoint_path, map_location='cpu')
+            checkpoint = torch.load(checkpoint_path, map_location=device)

         state_dict_key = ''
         if isinstance(checkpoint, dict):
@@ -56,22 +59,37 @@ def load_state_dict(checkpoint_path, use_ema=True):
     raise FileNotFoundError()


-def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
+def load_checkpoint(
+        model: torch.nn.Module,
+        checkpoint_path: str,
+        use_ema: bool = True,
+        device: Union[str, torch.device] = 'cpu',
+        strict: bool = True,
+        remap: bool = False,
+        filter_fn: Optional[Callable] = None,
+):
     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
         # numpy checkpoint, try to load via model specific load_pretrained fn
         if hasattr(model, 'load_pretrained'):
-            timm.models._model_builder.load_pretrained(checkpoint_path)
+            model.load_pretrained(checkpoint_path)
         else:
             raise NotImplementedError('Model cannot load numpy checkpoint')
         return
-    state_dict = load_state_dict(checkpoint_path, use_ema)
+
+    state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
     if remap:
-        state_dict = remap_checkpoint(model, state_dict)
+        state_dict = remap_state_dict(state_dict, model)
+    elif filter_fn:
+        state_dict = filter_fn(state_dict, model)
     incompatible_keys = model.load_state_dict(state_dict, strict=strict)
     return incompatible_keys


-def remap_checkpoint(model, state_dict, allow_reshape=True):
+def remap_state_dict(
+        state_dict: Dict[str, Any],
+        model: torch.nn.Module,
+        allow_reshape: bool = True
+):
     """ remap checkpoint by iterating over state dicts in order (ignoring original keys).
     This assumes models (and originating state dict) were created with params registered in same order.
     """
@@ -87,7 +105,13 @@ def remap_checkpoint(model, state_dict, allow_reshape=True):
     return out_dict


-def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
+def resume_checkpoint(
+        model: torch.nn.Module,
+        checkpoint_path: str,
+        optimizer: torch.optim.Optimizer = None,
+        loss_scaler: Any = None,
+        log_info: bool = True,
+):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
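
load_checkpoint gains device and filter_fn arguments here, and remap_checkpoint becomes remap_state_dict with a (state_dict, model) argument order. A minimal sketch of the new signature in use; the checkpoint path and the filter function below are hypothetical, not part of the commit:

import timm
from timm.models import load_checkpoint

model = timm.create_model('resnet50')

def drop_classifier(state_dict, model):
    # hypothetical filter_fn: drop classifier weights so a differently-sized head can be loaded
    return {k: v for k, v in state_dict.items() if not k.startswith('fc.')}

# load onto CPU via the new device arg and apply the filter; filter_fn only runs when remap is False
load_checkpoint(model, 'checkpoint.pth.tar', device='cpu', strict=False, filter_fn=drop_classifier)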

timm/models/_manipulate.py  (+36 -16)

@@ -3,7 +3,7 @@
 import re
 from collections import defaultdict
 from itertools import chain
-from typing import Callable, Union, Dict
+from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union

 import torch
 from torch import nn as nn
@@ -13,15 +13,20 @@
     'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']


-def model_parameters(model, exclude_head=False):
+def model_parameters(model: nn.Module, exclude_head: bool = False):
     if exclude_head:
         # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
         return [p for p in model.parameters()][:-2]
     else:
         return model.parameters()


-def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
+def named_apply(
+        fn: Callable,
+        module: nn.Module, name='',
+        depth_first: bool = True,
+        include_root: bool = False,
+) -> nn.Module:
     if not depth_first and include_root:
         fn(module=module, name=name)
     for child_name, child_module in module.named_children():
@@ -32,7 +37,12 @@ def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, incl
     return module


-def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
+def named_modules(
+        module: nn.Module,
+        name: str = '',
+        depth_first: bool = True,
+        include_root: bool = False,
+):
     if not depth_first and include_root:
         yield name, module
     for child_name, child_module in module.named_children():
@@ -43,7 +53,12 @@ def named_modules(module: nn.Module, name='', depth_first=True, include_root=Fal
         yield name, module


-def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
+def named_modules_with_params(
+        module: nn.Module,
+        name: str = '',
+        depth_first: bool = True,
+        include_root: bool = False,
+):
     if module._parameters and not depth_first and include_root:
         yield name, module
     for child_name, child_module in module.named_children():
@@ -58,9 +73,9 @@ def named_modules_with_params(module: nn.Module, name='', depth_first=True, incl


 def group_with_matcher(
-        named_objects,
+        named_objects: Iterator[Tuple[str, Any]],
         group_matcher: Union[Dict, Callable],
-        output_values: bool = False,
+        return_values: bool = False,
         reverse: bool = False
 ):
     if isinstance(group_matcher, dict):
@@ -96,7 +111,7 @@ def _get_grouping(name):
     # map layers into groups via ordinals (ints or tuples of ints) from matcher
     grouping = defaultdict(list)
     for k, v in named_objects:
-        grouping[_get_grouping(k)].append(v if output_values else k)
+        grouping[_get_grouping(k)].append(v if return_values else k)

     # remap to integers
     layer_id_to_param = defaultdict(list)
@@ -107,7 +122,7 @@ def _get_grouping(name):
         layer_id_to_param[lid].extend(grouping[k])

     if reverse:
-        assert not output_values, "reverse mapping only sensible for name output"
+        assert not return_values, "reverse mapping only sensible for name output"
         # output reverse mapping
         param_to_layer_id = {}
         for lid, lm in layer_id_to_param.items():
@@ -121,24 +136,29 @@ def _get_grouping(name):
 def group_parameters(
         module: nn.Module,
         group_matcher,
-        output_values=False,
-        reverse=False,
+        return_values: bool = False,
+        reverse: bool = False,
 ):
     return group_with_matcher(
-        module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
+        module.named_parameters(), group_matcher, return_values=return_values, reverse=reverse)


 def group_modules(
         module: nn.Module,
         group_matcher,
-        output_values=False,
-        reverse=False,
+        return_values: bool = False,
+        reverse: bool = False,
 ):
     return group_with_matcher(
-        named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
+        named_modules_with_params(module), group_matcher, return_values=return_values, reverse=reverse)


-def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
+def flatten_modules(
+        named_modules: Iterator[Tuple[str, nn.Module]],
+        depth: int = 1,
+        prefix: Union[str, Tuple[str, ...]] = '',
+        module_types: Union[str, Tuple[Type[nn.Module]]] = 'sequential',
+):
     prefix_is_tuple = isinstance(prefix, tuple)
     if isinstance(module_types, str):
         if module_types == 'container':
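
Beyond the added type annotations, the user-visible change in this file is the rename of output_values to return_values in group_with_matcher, group_parameters and group_modules. A short sketch of the renamed argument, assuming the model's existing group_matcher() method supplies the matcher:

import timm
from timm.models import group_parameters

model = timm.create_model('resnet18')
matcher = model.group_matcher(coarse=True)  # dict/callable describing how to bucket layer names

# group parameter names by layer id; return_values=True (renamed from output_values) returns the tensors instead
groups = group_parameters(model, matcher, return_values=False)
for layer_id, names in sorted(groups.items()):
    print(layer_id, len(names))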

timm/models/_pretrained.py  (+1 -39)

@@ -4,7 +4,7 @@
 from typing import Any, Deque, Dict, Tuple, Optional, Union


-__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
+__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']


 @dataclass
@@ -91,41 +91,3 @@ def default(self):
     def default_with_tag(self):
         tag = self.tags[0]
         return tag, self.cfgs[tag]
-
-
-def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
-    model_name, *tag_list = model_name.split('.', 1)
-    tag = tag_list[0] if tag_list else no_tag
-    return model_name, tag
-
-
-def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
-    out = defaultdict(DefaultCfg)
-    default_set = set()  # no tag and tags ending with * are prioritized as default
-
-    for k, v in cfgs.items():
-        if isinstance(v, dict):
-            v = PretrainedCfg(**v)
-        has_weights = v.has_weights
-
-        model, tag = split_model_name_tag(k)
-        is_default_set = model in default_set
-        priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
-        tag = tag.strip('*')
-
-        default_cfg = out[model]
-
-        if priority:
-            default_cfg.tags.appendleft(tag)
-            default_set.add(model)
-        elif has_weights and not default_cfg.is_pretrained:
-            default_cfg.tags.appendleft(tag)
-        else:
-            default_cfg.tags.append(tag)
-
-        if has_weights:
-            default_cfg.is_pretrained = True
-
-        default_cfg.cfgs[tag] = v
-
-    return out
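
split_model_name_tag and generate_default_cfgs are removed here because they moved to _registry (see the __init__.py hunk above); their behaviour is unchanged. Based on the removed implementation, a quick sketch of the name/tag split (the model names are just examples):

from timm.models import split_model_name_tag  # now re-exported via timm.models._registry

# the name is split on the first '.', anything after it becomes the pretrained tag
print(split_model_name_tag('swin_base_patch4_window7_224.ms_in22k'))  # ('swin_base_patch4_window7_224', 'ms_in22k')
print(split_model_name_tag('focalnet_tiny_srf'))                      # ('focalnet_tiny_srf', '')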
