Skip to content

Commit 8bfcd08

Browse files
authored
Merge pull request #99 from ayasyrev/refactor_modelcfg
Refactor modelcfg
2 parents 2a97347 + 6dbecde commit 8bfcd08

File tree

9 files changed

+107
-107
lines changed

9 files changed

+107
-107
lines changed

noxfile_conda.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
@nox.session(python=["3.9", "3.10", "3.11"], venv_backend="mamba")
55
def conda_tests(session: nox.Session) -> None:
66
args = session.posargs or ["--cov"]
7-
# session.install("pytest", "pytest-cov")
87
session.conda_install("pytest", "pytest-cov")
98
session.conda_install("pytorch")
109
session.conda_install("pydantic")

src/model_constructor/blocks.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Union
1+
from typing import Callable, Optional
22

33
import torch
44
from torch import nn
@@ -22,10 +22,10 @@ def __init__(
2222
bn_1st: bool = True,
2323
groups: int = 1,
2424
dw: bool = False,
25-
div_groups: Union[None, int] = None,
26-
pool: Union[Callable[[], nn.Module], None] = None,
27-
se: Union[nn.Module, None] = None,
28-
sa: Union[nn.Module, None] = None,
25+
div_groups: Optional[int] = None,
26+
pool: Optional[Callable[[], nn.Module]] = None,
27+
se: Optional[nn.Module] = None,
28+
sa: Optional[nn.Module] = None,
2929
):
3030
super().__init__()
3131
# pool defined at ModelConstructor.
@@ -107,10 +107,10 @@ def __init__(
107107
bn_1st: bool = True,
108108
groups: int = 1,
109109
dw: bool = False,
110-
div_groups: Union[None, int] = None,
111-
pool: Union[Callable[[], nn.Module], None] = None,
112-
se: Union[nn.Module, None] = None,
113-
sa: Union[nn.Module, None] = None,
110+
div_groups: Optional[int] = None,
111+
pool: Optional[Callable[[], nn.Module]] = None,
112+
se: Optional[nn.Module] = None,
113+
sa: Optional[nn.Module] = None,
114114
):
115115
super().__init__()
116116
# pool defined at ModelConstructor.

src/model_constructor/convmixer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
from collections import OrderedDict
66
from typing import Callable, List, Optional, Union
77

8+
import torch
89
import torch.nn as nn
9-
from torch import TensorType
1010

1111

1212
class Residual(nn.Module):
13-
def __init__(self, fn: Callable[[TensorType], TensorType]):
13+
def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
1414
super().__init__()
1515
self.fn = fn
1616

17-
def forward(self, x: TensorType) -> TensorType:
17+
def forward(self, x: torch.Tensor) -> torch.Tensor:
1818
return self.fn(x) + x
1919

2020

src/model_constructor/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Iterable, Optional
3+
from typing import Iterable, Optional, Union
44
from pydantic import BaseModel
55

66
from torch import nn
77

88

99
ListStrMod = list[tuple[str, nn.Module]]
10+
ModSeq = Union[nn.Module, nn.Sequential]
1011

1112

1213
def nn_seq(list_of_tuples: Iterable[tuple[str, nn.Module]]) -> nn.Sequential:

src/model_constructor/layers.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -258,14 +258,14 @@ class SEModule(nn.Module):
258258

259259
def __init__(
260260
self,
261-
channels,
262-
reduction=16,
263-
rd_channels=None,
264-
rd_max=False,
265-
se_layer=nn.Linear,
266-
act_fn=nn.ReLU(inplace=True),
267-
use_bias=True,
268-
gate=nn.Sigmoid,
261+
channels: int,
262+
reduction: int = 16,
263+
rd_channels: Optional[int] = None,
264+
rd_max: bool = False,
265+
se_layer: type[nn.Module] = nn.Linear,
266+
act_fn: nn.Module = nn.ReLU(inplace=True),
267+
use_bias: bool = True,
268+
gate: type[nn.Module] = nn.Sigmoid,
269269
):
270270
super().__init__()
271271
reducted = max(channels // reduction, 1) # preserve zero-element tensors
@@ -286,7 +286,7 @@ def __init__(
286286
)
287287
)
288288

289-
def forward(self, x):
289+
def forward(self, x: torch.Tensor) -> torch.Tensor:
290290
bs, c, _, _ = x.shape
291291
y = self.squeeze(x).view(bs, c)
292292
y = self.excitation(y).view(bs, c, 1, 1)
@@ -298,14 +298,14 @@ class SEModuleConv(nn.Module):
298298

299299
def __init__(
300300
self,
301-
channels,
302-
reduction=16,
303-
rd_channels=None,
304-
rd_max=False,
305-
se_layer=nn.Conv2d,
306-
act_fn=nn.ReLU(inplace=True),
307-
use_bias=True,
308-
gate=nn.Sigmoid,
301+
channels: int,
302+
reduction: int = 16,
303+
rd_channels: Optional[int] = None,
304+
rd_max: bool = False,
305+
se_layer: type[nn.Module] = nn.Conv2d,
306+
act_fn: nn.Module = nn.ReLU(inplace=True),
307+
use_bias: bool = True,
308+
gate: type[nn.Module] = nn.Sigmoid,
309309
):
310310
super().__init__()
311311
# rd_channels = math.ceil(channels//reduction/8)*8
@@ -327,7 +327,7 @@ def __init__(
327327
)
328328
)
329329

330-
def forward(self, x):
330+
def forward(self, x: torch.Tensor) -> torch.Tensor:
331331
y = self.squeeze(x)
332332
y = self.excitation(y)
333333
return x * y.expand_as(x)

src/model_constructor/model_constructor.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch import nn
77

88
from .blocks import BasicBlock, BottleneckBlock
9-
from .helpers import Cfg, ListStrMod, init_cnn, nn_seq
9+
from .helpers import Cfg, ListStrMod, ModSeq, init_cnn, nn_seq
1010
from .layers import ConvBnAct, SEModule, SimpleSelfAttention
1111

1212
__all__ = [
@@ -33,20 +33,45 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
3333
expansion: int = 1
3434
groups: int = 1
3535
dw: bool = False
36-
div_groups: Union[int, None] = None
36+
div_groups: Optional[int] = None
3737
sa: Union[bool, type[nn.Module]] = False
3838
se: Union[bool, type[nn.Module]] = False
39-
se_module: Union[bool, None] = None
40-
se_reduction: Union[int, None] = None
39+
se_module: Optional[bool] = None
40+
se_reduction: Optional[int] = None
4141
bn_1st: bool = True
4242
zero_bn: bool = True
4343
stem_stride_on: int = 0
4444
stem_sizes: list[int] = [64]
45-
stem_pool: Union[Callable[[], nn.Module], None] = partial(
45+
stem_pool: Optional[Callable[[], nn.Module]] = partial(
4646
nn.MaxPool2d, kernel_size=3, stride=2, padding=1
4747
)
4848
stem_bn_end: bool = False
4949

50+
@field_validator("se")
51+
def set_se( # pylint: disable=no-self-argument
52+
cls, value: Union[bool, type[nn.Module]]
53+
) -> Union[bool, type[nn.Module]]:
54+
if value:
55+
if isinstance(value, (int, bool)):
56+
return SEModule
57+
return value
58+
59+
@field_validator("sa")
60+
def set_sa( # pylint: disable=no-self-argument
61+
cls, value: Union[bool, type[nn.Module]]
62+
) -> Union[bool, type[nn.Module]]:
63+
if value:
64+
if isinstance(value, (int, bool)):
65+
return SimpleSelfAttention # default: ks=1, sym=sym
66+
return value
67+
68+
@field_validator("se_module", "se_reduction") # pragma: no cover
69+
def deprecation_warning( # pylint: disable=no-self-argument
70+
cls, value: Union[bool, int, None]
71+
) -> Union[bool, int, None]:
72+
print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.")
73+
return value
74+
5075
def __repr__(self) -> str:
5176
se_repr = self.se.__name__ if self.se else "False" # type: ignore
5277
model_name = self.name or self.__class__.__name__
@@ -61,7 +86,7 @@ def __repr__(self) -> str:
6186
)
6287

6388

64-
def make_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
89+
def make_stem(cfg: ModelCfg) -> nn.Sequential:
6590
"""Create Resnet stem."""
6691
stem: ListStrMod = [
6792
(
@@ -116,15 +141,15 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
116141
)
117142

118143

119-
def make_body(cfg: ModelCfg) -> nn.Sequential: # type: ignore
144+
def make_body(cfg: ModelCfg) -> nn.Sequential:
120145
"""Create model body."""
121146
return nn_seq(
122147
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
123148
for layer_num in range(len(cfg.layers))
124149
)
125150

126151

127-
def make_head(cfg: ModelCfg) -> nn.Sequential: # type: ignore
152+
def make_head(cfg: ModelCfg) -> nn.Sequential:
128153
"""Create head."""
129154
head = [
130155
("pool", nn.AdaptiveAvgPool2d(1)),
@@ -138,35 +163,10 @@ class ModelConstructor(ModelCfg):
138163
"""Model constructor. As default - resnet18"""
139164

140165
init_cnn: Callable[[nn.Module], None] = init_cnn
141-
make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_stem # type: ignore
142-
make_layer: Callable[[ModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer # type: ignore
143-
make_body: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_body # type: ignore
144-
make_head: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_head # type: ignore
145-
146-
@field_validator("se")
147-
def set_se( # pylint: disable=no-self-argument
148-
cls, value: Union[bool, type[nn.Module]]
149-
) -> Union[bool, type[nn.Module]]:
150-
if value:
151-
if isinstance(value, (int, bool)):
152-
return SEModule
153-
return value
154-
155-
@field_validator("sa")
156-
def set_sa( # pylint: disable=no-self-argument
157-
cls, value: Union[bool, type[nn.Module]]
158-
) -> Union[bool, type[nn.Module]]:
159-
if value:
160-
if isinstance(value, (int, bool)):
161-
return SimpleSelfAttention # default: ks=1, sym=sym
162-
return value
163-
164-
@field_validator("se_module", "se_reduction") # pragma: no cover
165-
def deprecation_warning( # pylint: disable=no-self-argument
166-
cls, value: Union[bool, int, None]
167-
) -> Union[bool, int, None]:
168-
print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.")
169-
return value
166+
make_stem: Callable[[ModelCfg], ModSeq] = make_stem
167+
make_layer: Callable[[ModelCfg, int], ModSeq] = make_layer
168+
make_body: Callable[[ModelCfg], ModSeq] = make_body
169+
make_head: Callable[[ModelCfg], ModSeq] = make_head
170170

171171
@property
172172
def stem(self):
@@ -186,7 +186,7 @@ def from_cfg(cls, cfg: ModelCfg):
186186

187187
@classmethod
188188
def create_model(
189-
cls, cfg: Union[ModelCfg, None] = None, **kwargs: dict[str, Any]
189+
cls, cfg: Optional[ModelCfg] = None, **kwargs: dict[str, Any]
190190
) -> nn.Sequential:
191191
if cfg:
192192
return cls(**cfg.model_dump())()

src/model_constructor/universal_blocks.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import Callable, Union
1+
from typing import Callable, Optional
22

33
import torch
44
from torch import nn
55

6-
from .helpers import nn_seq
6+
from .helpers import ModSeq, nn_seq
77
from .layers import ConvBnAct, get_act
88
from .model_constructor import ListStrMod, ModelCfg, ModelConstructor
99

@@ -32,10 +32,10 @@ def __init__(
3232
bn_1st: bool = True,
3333
groups: int = 1,
3434
dw: bool = False,
35-
div_groups: Union[None, int] = None,
36-
pool: Union[Callable[[], nn.Module], None] = None,
37-
se: Union[nn.Module, None] = None,
38-
sa: Union[nn.Module, None] = None,
35+
div_groups: Optional[int] = None,
36+
pool: Optional[Callable[[], nn.Module]] = None,
37+
se: Optional[nn.Module] = None,
38+
sa: Optional[nn.Module] = None,
3939
):
4040
super().__init__()
4141
# pool defined at ModelConstructor.
@@ -134,7 +134,7 @@ def __init__(
134134
self.id_conv = None
135135
self.act_fn = get_act(act_fn)
136136

137-
def forward(self, x: torch.Tensor): # type: ignore
137+
def forward(self, x: torch.Tensor) -> torch.Tensor:
138138
identity = self.id_conv(x) if self.id_conv is not None else x
139139
return self.act_fn(self.convs(x) + identity)
140140

@@ -156,10 +156,10 @@ def __init__(
156156
bn_1st: bool = True,
157157
groups: int = 1,
158158
dw: bool = False,
159-
div_groups: Union[None, int] = None,
160-
pool: Union[Callable[[], nn.Module], None] = None,
161-
se: Union[type[nn.Module], None] = None,
162-
sa: Union[type[nn.Module], None] = None,
159+
div_groups: Optional[int] = None,
160+
pool: Optional[Callable[[], nn.Module]] = None,
161+
se: Optional[type[nn.Module]] = None,
162+
sa: Optional[type[nn.Module]] = None,
163163
):
164164
super().__init__()
165165
# pool defined at ModelConstructor.
@@ -255,14 +255,14 @@ def __init__(
255255
self.id_conv = None
256256
self.merge = get_act(act_fn)
257257

258-
def forward(self, x: torch.Tensor): # type: ignore
258+
def forward(self, x: torch.Tensor) -> torch.Tensor:
259259
if self.reduce:
260260
x = self.reduce(x)
261261
identity = self.id_conv(x) if self.id_conv is not None else x
262262
return self.merge(self.convs(x) + identity)
263263

264264

265-
def make_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
265+
def make_stem(cfg: ModelCfg) -> nn.Sequential:
266266
"""Create xResnet stem -> 3 conv 3*3 instead of 1 conv 7*7"""
267267
len_stem = len(cfg.stem_sizes)
268268
stem: list[tuple[str, nn.Module]] = [
@@ -286,7 +286,7 @@ def make_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
286286
return nn_seq(stem)
287287

288288

289-
def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
289+
def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential:
290290
"""Create layer (stage)"""
291291
# if no pool on stem - stride = 2 for first layer block in body
292292
stride = 1 if cfg.stem_pool and layer_num == 0 else 2
@@ -316,15 +316,15 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
316316
)
317317

318318

319-
def make_body(cfg: ModelCfg) -> nn.Sequential: # type: ignore
319+
def make_body(cfg: ModelCfg) -> nn.Sequential:
320320
"""Create model body."""
321321
return nn_seq(
322322
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
323323
for layer_num in range(len(cfg.layers))
324324
)
325325

326326

327-
def make_head(cfg: ModelCfg) -> nn.Sequential: # type: ignore
327+
def make_head(cfg: ModelCfg) -> nn.Sequential:
328328
"""Create head."""
329329
head = [
330330
("pool", nn.AdaptiveAvgPool2d(1)),
@@ -337,10 +337,10 @@ def make_head(cfg: ModelCfg) -> nn.Sequential: # type: ignore
337337
class XResNet(ModelConstructor):
338338
"""Base Xresnet constructor."""
339339

340-
make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_stem
341-
make_layer: Callable[[ModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer
342-
make_body: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_body
343-
make_head: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_head
340+
make_stem: Callable[[ModelCfg], ModSeq] = make_stem
341+
make_layer: Callable[[ModelCfg, int], ModSeq] = make_layer
342+
make_body: Callable[[ModelCfg], ModSeq] = make_body
343+
make_head: Callable[[ModelCfg], ModSeq] = make_head
344344
block: type[nn.Module] = XResBlock
345345

346346

0 commit comments

Comments
 (0)