Commit 74b427e
Merge pull request #110 from ayasyrev/test_typing
Test typing
2 parents e077402 + cad905e

13 files changed: +96 -76 lines
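
This merge swaps PEP 585 built-in generics (list[...], tuple[...], type[...], dict[...]) for their typing-module equivalents (List, Tuple, Type, Dict) across the package, so every annotation evaluates at import time on Python 3.8, and restores 3.8 to the conda test matrix. A minimal sketch of the incompatibility being fixed (hypothetical snippet, not taken from the repository):

    from typing import List, Tuple

    from torch import nn

    # On Python 3.8, subscripting a builtin fails at import time:
    #     ListStrMod = list[tuple[str, nn.Module]]
    #     TypeError: 'type' object is not subscriptable
    # The typing equivalents work on 3.8 and later:
    ListStrMod = List[Tuple[str, nn.Module]]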

noxfile_conda.py (+1 -1)

@@ -1,7 +1,7 @@
 import nox
 
 
-@nox.session(python=["3.9", "3.10", "3.11"], venv_backend="mamba")
+@nox.session(python=["3.8", "3.9", "3.10", "3.11"], venv_backend="mamba")
 def conda_tests(session: nox.Session) -> None:
     args = session.posargs or ["--cov"]
     session.conda_install("pytest", "pytest-cov")
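
Assuming nox and mamba are available locally, the restored interpreter can be exercised on its own with nox -s "conda_tests-3.8" (nox names parametrized sessions "name-version"; the invocation is an illustration, not part of the commit).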

src/model_constructor/blocks.py (+5 -5)

@@ -1,4 +1,4 @@
-from typing import Callable, Optional
+from typing import Callable, Optional, Type
 
 import torch
 from torch import nn
@@ -16,8 +16,8 @@ def __init__(
         in_channels: int,
         out_channels: int,
         stride: int = 1,
-        conv_layer: type[ConvBnAct] = ConvBnAct,
-        act_fn: type[nn.Module] = nn.ReLU,
+        conv_layer: Type[ConvBnAct] = ConvBnAct,
+        act_fn: Type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
         groups: int = 1,
@@ -101,8 +101,8 @@ def __init__(
         out_channels: int,
         stride: int = 1,
         expansion: int = 4,
-        conv_layer: type[ConvBnAct] = ConvBnAct,
-        act_fn: type[nn.Module] = nn.ReLU,
+        conv_layer: Type[ConvBnAct] = ConvBnAct,
+        act_fn: Type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
         groups: int = 1,
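
These parameters receive classes, not instances, which is what Type[...] expresses. A usage sketch under that assumption (import path and argument names read off the diff, not verified against the full file):

    from torch import nn

    from model_constructor.blocks import BasicBlock

    # act_fn takes the activation class; the block instantiates it itself
    block = BasicBlock(in_channels=64, out_channels=64, act_fn=nn.Mish)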

src/model_constructor/convmixer.py (+6 -4)

@@ -3,11 +3,13 @@
 # Adopted from https://github.com/tmp-iclr/convmixer
 # Home for convmixer: https://github.com/locuslab/convmixer
 from collections import OrderedDict
-from typing import Callable, List, Optional, Union
+from typing import Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
+from .helpers import ListStrMod
+
 
 class Residual(nn.Module):
     def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
@@ -59,15 +61,15 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: Union[int, tuple[int, int]],
+        kernel_size: Union[int, Tuple[int, int]],
         stride: int = 1,
         act_fn: nn.Module = nn.GELU(),
         padding: Union[int, str] = 0,
         groups: int = 1,
         bn_1st: bool = False,
         pre_act: bool = False,
     ):
-        conv_layer: List[tuple[str, nn.Module]] = [
+        conv_layer: ListStrMod = [
             (
                 "conv",
                 nn.Conv2d(
@@ -80,7 +82,7 @@ def __init__(
                 ),
             )
         ]
-        act_bn: List[tuple[str, nn.Module]] = [
+        act_bn: ListStrMod = [
            ("act_fn", act_fn),
            ("bn", nn.BatchNorm2d(out_channels)),
        ]

src/model_constructor/helpers.py (+7 -7)

@@ -1,16 +1,16 @@
 import importlib
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Iterable, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 from pydantic import BaseModel
 from torch import nn
 
-ListStrMod = list[tuple[str, nn.Module]]
+ListStrMod = List[Tuple[str, nn.Module]]
 ModSeq = Union[nn.Module, nn.Sequential]
 
 
-def nn_seq(list_of_tuples: Iterable[tuple[str, nn.Module]]) -> nn.Sequential:
+def nn_seq(list_of_tuples: Iterable[Tuple[str, nn.Module]]) -> nn.Sequential:
     """return nn.Sequential from OrderedDict from list of tuples"""
     return nn.Sequential(OrderedDict(list_of_tuples))
 
@@ -86,22 +86,22 @@ def _get_str_value(self, field: str) -> str:
     def __repr__(self) -> str:
         return f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})"
 
-    def __repr_args__(self) -> list[tuple[str, str]]:
+    def __repr_args__(self) -> List[Tuple[str, str]]:
         return [
             (field, str_value)
             for field in self.model_fields
             if (str_value := self._get_str_value(field))
         ]
 
-    def __repr_set_fields__(self) -> list[str]:
+    def __repr_set_fields__(self) -> List[str]:
         """Return list repr for fields set at init"""
         return [
             f"{field}: {self._get_str_value(field)}"
             for field in self.model_fields_set  # pylint: disable=E1133
             if field != "name"
         ]
 
-    def __repr_changed_fields__(self) -> list[str]:
+    def __repr_changed_fields__(self) -> List[str]:
         """Return list repr for changed fields"""
         return [
             f"{field}: {self._get_str_value(field)}"
@@ -110,7 +110,7 @@ def __repr_changed_fields__(self) -> list[str]:
         ]
 
     @property
-    def changed_fields(self) -> dict[str, Any]:
+    def changed_fields(self) -> Dict[str, Any]:
         # return "\n".join(self.__repr_changed_fields__())
         return {
             field: self._get_str_value(field)
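
ListStrMod and nn_seq work as a pair: a list of (name, module) tuples becomes an nn.Sequential with named children. A self-contained sketch (assumes only torch; definitions copied from the diff above):

    from collections import OrderedDict
    from typing import Iterable, List, Tuple

    from torch import nn

    ListStrMod = List[Tuple[str, nn.Module]]

    def nn_seq(list_of_tuples: Iterable[Tuple[str, nn.Module]]) -> nn.Sequential:
        """return nn.Sequential from OrderedDict from list of tuples"""
        return nn.Sequential(OrderedDict(list_of_tuples))

    layers: ListStrMod = [("conv", nn.Conv2d(3, 16, 3)), ("act", nn.ReLU())]
    seq = nn_seq(layers)  # children addressable as seq.conv and seq.act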

src/model_constructor/layers.py (+8 -6)

@@ -1,10 +1,12 @@
 from collections import OrderedDict
-from typing import List, Optional, Type, Union
+from typing import Optional, Type, Union
 
 import torch
 from torch import nn
 from torch.nn.utils.spectral_norm import spectral_norm
 
+from .helpers import ListStrMod
+
 __all__ = [
     "Flatten",
     "noop",
@@ -72,7 +74,7 @@ def __init__(
     ):
         if padding is None:
             padding = kernel_size // 2
-        layers: List[tuple[str, nn.Module]] = [
+        layers: ListStrMod = [
             (
                 "conv",
                 self.convolution_module(
@@ -262,10 +264,10 @@ def __init__(
         reduction: int = 16,
         rd_channels: Optional[int] = None,
         rd_max: bool = False,
-        se_layer: type[nn.Module] = nn.Linear,
+        se_layer: Type[nn.Module] = nn.Linear,
         act_fn: nn.Module = nn.ReLU(inplace=True),
         use_bias: bool = True,
-        gate: type[nn.Module] = nn.Sigmoid,
+        gate: Type[nn.Module] = nn.Sigmoid,
     ):
         super().__init__()
         reducted = max(channels // reduction, 1)  # preserve zero-element tensors
@@ -302,10 +304,10 @@ def __init__(
         reduction: int = 16,
         rd_channels: Optional[int] = None,
         rd_max: bool = False,
-        se_layer: type[nn.Module] = nn.Conv2d,
+        se_layer: Type[nn.Module] = nn.Conv2d,
         act_fn: nn.Module = nn.ReLU(inplace=True),
         use_bias: bool = True,
-        gate: type[nn.Module] = nn.Sigmoid,
+        gate: Type[nn.Module] = nn.Sigmoid,
     ):
         super().__init__()
         # rd_channels = math.ceil(channels//reduction/8)*8

src/model_constructor/model_constructor.py (+25 -15)

@@ -1,14 +1,21 @@
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, Type
 
 from pydantic import field_validator
 from pydantic_core.core_schema import FieldValidationInfo
 from torch import nn
 
 from .blocks import BasicBlock, BottleneckBlock
-from .helpers import (Cfg, ListStrMod, ModSeq, init_cnn, instantiate_module,
-                      is_module, nn_seq)
+from .helpers import (
+    Cfg,
+    ListStrMod,
+    ModSeq,
+    init_cnn,
+    instantiate_module,
+    is_module,
+    nn_seq,
+)
 from .layers import ConvBnAct, SEModule, SimpleSelfAttention
 
 __all__ = [
@@ -25,7 +32,7 @@
 }
 
 
-nnModule = Union[type[nn.Module], Callable[[], nn.Module]]
+nnModule = Union[Type[nn.Module], Callable[[], nn.Module]]
 
 
 class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
@@ -36,8 +43,8 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
     num_classes: int = 1000
     block: Union[nnModule, str] = BasicBlock
     conv_layer: Union[nnModule, str] = ConvBnAct
-    block_sizes: list[int] = [64, 128, 256, 512]
-    layers: list[int] = [2, 2, 2, 2]
+    block_sizes: List[int] = [64, 128, 256, 512]
+    layers: List[int] = [2, 2, 2, 2]
     norm: Union[nnModule, str] = nn.BatchNorm2d
     act_fn: Union[nnModule, str] = nn.ReLU
     pool: Union[nnModule, str, None] = None
@@ -52,15 +59,16 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
     bn_1st: bool = True
     zero_bn: bool = True
     stem_stride_on: int = 0
-    stem_sizes: list[int] = [64]
+    stem_sizes: List[int] = [64]
     stem_pool: Union[nnModule, str, None] = partial(
         nn.MaxPool2d, kernel_size=3, stride=2, padding=1
     )
     stem_bn_end: bool = False
 
     @field_validator("act_fn", "block", "conv_layer", "norm", "pool", "stem_pool")
     def set_modules(  # pylint: disable=no-self-argument
-        cls, value: Union[nnModule, str],
+        cls,
+        value: Union[nnModule, str],
     ) -> nnModule:
         """Check values, if string, convert to nn.Module."""
         if is_module(value):
@@ -69,7 +77,9 @@ def set_modules(  # pylint: disable=no-self-argument
 
     @field_validator("se", "sa")
     def set_se(  # pylint: disable=no-self-argument
-        cls, value: Union[bool, nnModule, str], info: FieldValidationInfo,
+        cls,
+        value: Union[bool, nnModule, str],
+        info: FieldValidationInfo,
     ) -> nnModule:
         if isinstance(value, (int, bool)):
             return DEFAULT_SE_SA[info.field_name]
@@ -154,8 +164,8 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential:  # type: ignore
 
 
 def make_body(
-        cfg: ModelCfg,
-        layer_constructor: Callable[[ModelCfg, int], nn.Sequential] = make_layer,
+    cfg: ModelCfg,
+    layer_constructor: Callable[[ModelCfg, int], nn.Sequential] = make_layer,
 ) -> nn.Sequential:
     """Create model body."""
     if hasattr(cfg, "make_layer"):
@@ -203,7 +213,7 @@ def from_cfg(cls, cfg: ModelCfg):
 
     @classmethod
     def create_model(
-        cls, cfg: Optional[ModelCfg] = None, **kwargs: dict[str, Any]
+        cls, cfg: Optional[ModelCfg] = None, **kwargs: Dict[str, Any]
     ) -> nn.Sequential:
         if cfg:
             return cls(**cfg.model_dump(exclude_none=True))()
@@ -226,9 +236,9 @@ def __call__(self) -> nn.Sequential:
 
 
 class ResNet34(ModelConstructor):
-    layers: list[int] = [3, 4, 6, 3]
+    layers: List[int] = [3, 4, 6, 3]
 
 
 class ResNet50(ResNet34):
-    block: type[nn.Module] = BottleneckBlock
-    block_sizes: list[int] = [256, 512, 1024, 2048]
+    block: Type[nn.Module] = BottleneckBlock
+    block_sizes: List[int] = [256, 512, 1024, 2048]
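
For context on the API these annotations describe, a hedged usage sketch (class and method names as they appear in this diff; behavior assumed, not verified against the full package):

    from model_constructor.model_constructor import ModelConstructor, ResNet50

    mc = ModelConstructor(num_classes=10)  # pydantic-validated config fields
    model = mc()  # __call__ builds an nn.Sequential

    resnet50 = ResNet50.create_model()  # classmethod route, also nn.Sequential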

src/model_constructor/mxresnet.py (+5 -4)

@@ -1,17 +1,18 @@
+from typing import List, Type
 from torch import nn
 
 from .xresnet import XResNet
 
 
 class MxResNet(XResNet):
-    stem_sizes: list[int] = [3, 32, 64, 64]
-    act_fn: type[nn.Module] = nn.Mish
+    stem_sizes: List[int] = [3, 32, 64, 64]
+    act_fn: Type[nn.Module] = nn.Mish
 
 
 class MxResNet34(MxResNet):
-    layers: list[int] = [3, 4, 6, 3]
+    layers: List[int] = [3, 4, 6, 3]
 
 
 class MxResNet50(MxResNet34):
     expansion: int = 4
-    block_sizes: list[int] = [256, 512, 1024, 2048]
+    block_sizes: List[int] = [256, 512, 1024, 2048]

src/model_constructor/universal_blocks.py (+14 -14)

@@ -1,4 +1,4 @@
-from typing import Callable, Optional
+from typing import Callable, List, Optional, Type
 
 import torch
 from torch import nn
@@ -26,8 +26,8 @@ def __init__(
         in_channels: int,
         mid_channels: int,
         stride: int = 1,
-        conv_layer: type[ConvBnAct] = ConvBnAct,
-        act_fn: type[nn.Module] = nn.ReLU,
+        conv_layer: Type[ConvBnAct] = ConvBnAct,
+        act_fn: Type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
         groups: int = 1,
@@ -150,16 +150,16 @@ def __init__(
         in_channels: int,
         mid_channels: int,
         stride: int = 1,
-        conv_layer: type[ConvBnAct] = ConvBnAct,
-        act_fn: type[nn.Module] = nn.ReLU,
+        conv_layer: Type[ConvBnAct] = ConvBnAct,
+        act_fn: Type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
         groups: int = 1,
         dw: bool = False,
         div_groups: Optional[int] = None,
         pool: Optional[Callable[[], nn.Module]] = None,
-        se: Optional[type[nn.Module]] = None,
-        sa: Optional[type[nn.Module]] = None,
+        se: Optional[Type[nn.Module]] = None,
+        sa: Optional[Type[nn.Module]] = None,
     ):
         super().__init__()
         # pool defined at ModelConstructor.
@@ -265,7 +265,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def make_stem(cfg: ModelCfg) -> nn.Sequential:
     """Create xResnet stem -> 3 conv 3*3 instead of 1 conv 7*7"""
     len_stem = len(cfg.stem_sizes)
-    stem: list[tuple[str, nn.Module]] = [
+    stem: ListStrMod = [
         (
             f"conv_{i}",
             cfg.conv_layer(
@@ -341,11 +341,11 @@ class XResNet(ModelConstructor):
     make_layer: Callable[[ModelCfg, int], ModSeq] = make_layer
     make_body: Callable[[ModelCfg], ModSeq] = make_body
     make_head: Callable[[ModelCfg], ModSeq] = make_head
-    block: type[nn.Module] = XResBlock
+    block: Type[nn.Module] = XResBlock
 
 
 class XResNet34(XResNet):
-    layers: list[int] = [3, 4, 6, 3]
+    layers: List[int] = [3, 4, 6, 3]
 
 
 class XResNet50(XResNet34):
@@ -357,13 +357,13 @@ class YaResNet(XResNet):
     YaResBlock, Mish activation, custom stem.
     """
 
-    block: type[nn.Module] = YaResBlock
-    stem_sizes: list[int] = [3, 32, 64, 64]
-    act_fn: type[nn.Module] = nn.Mish
+    block: Type[nn.Module] = YaResBlock
+    stem_sizes: List[int] = [3, 32, 64, 64]
+    act_fn: Type[nn.Module] = nn.Mish
 
 
 class YaResNet34(YaResNet):
-    layers: list[int] = [3, 4, 6, 3]
+    layers: List[int] = [3, 4, 6, 3]
 
 
 class YaResNet50(YaResNet34):
