1
1
from collections import OrderedDict
2
2
from functools import partial
3
- from typing import Any , Callable , Optional , Union
3
+ from typing import Any , Callable , Dict , List , Optional , Union , Type
4
4
5
5
from pydantic import field_validator
6
6
from pydantic_core .core_schema import FieldValidationInfo
7
7
from torch import nn
8
8
9
9
from .blocks import BasicBlock , BottleneckBlock
10
- from .helpers import (Cfg , ListStrMod , ModSeq , init_cnn , instantiate_module ,
11
- is_module , nn_seq )
10
+ from .helpers import (
11
+ Cfg ,
12
+ ListStrMod ,
13
+ ModSeq ,
14
+ init_cnn ,
15
+ instantiate_module ,
16
+ is_module ,
17
+ nn_seq ,
18
+ )
12
19
from .layers import ConvBnAct , SEModule , SimpleSelfAttention
13
20
14
21
__all__ = [
25
32
}
26
33
27
34
28
- nnModule = Union [type [nn .Module ], Callable [[], nn .Module ]]
35
+ nnModule = Union [Type [nn .Module ], Callable [[], nn .Module ]]
29
36
30
37
31
38
class ModelCfg (Cfg , arbitrary_types_allowed = True , extra = "forbid" ):
@@ -36,8 +43,8 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
36
43
num_classes : int = 1000
37
44
block : Union [nnModule , str ] = BasicBlock
38
45
conv_layer : Union [nnModule , str ] = ConvBnAct
39
- block_sizes : list [int ] = [64 , 128 , 256 , 512 ]
40
- layers : list [int ] = [2 , 2 , 2 , 2 ]
46
+ block_sizes : List [int ] = [64 , 128 , 256 , 512 ]
47
+ layers : List [int ] = [2 , 2 , 2 , 2 ]
41
48
norm : Union [nnModule , str ] = nn .BatchNorm2d
42
49
act_fn : Union [nnModule , str ] = nn .ReLU
43
50
pool : Union [nnModule , str , None ] = None
@@ -52,15 +59,16 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
52
59
bn_1st : bool = True
53
60
zero_bn : bool = True
54
61
stem_stride_on : int = 0
55
- stem_sizes : list [int ] = [64 ]
62
+ stem_sizes : List [int ] = [64 ]
56
63
stem_pool : Union [nnModule , str , None ] = partial (
57
64
nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
58
65
)
59
66
stem_bn_end : bool = False
60
67
61
68
@field_validator ("act_fn" , "block" , "conv_layer" , "norm" , "pool" , "stem_pool" )
62
69
def set_modules ( # pylint: disable=no-self-argument
63
- cls , value : Union [nnModule , str ],
70
+ cls ,
71
+ value : Union [nnModule , str ],
64
72
) -> nnModule :
65
73
"""Check values, if string, convert to nn.Module."""
66
74
if is_module (value ):
@@ -69,7 +77,9 @@ def set_modules( # pylint: disable=no-self-argument
69
77
70
78
@field_validator ("se" , "sa" )
71
79
def set_se ( # pylint: disable=no-self-argument
72
- cls , value : Union [bool , nnModule , str ], info : FieldValidationInfo ,
80
+ cls ,
81
+ value : Union [bool , nnModule , str ],
82
+ info : FieldValidationInfo ,
73
83
) -> nnModule :
74
84
if isinstance (value , (int , bool )):
75
85
return DEFAULT_SE_SA [info .field_name ]
@@ -154,8 +164,8 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
154
164
155
165
156
166
def make_body (
157
- cfg : ModelCfg ,
158
- layer_constructor : Callable [[ModelCfg , int ], nn .Sequential ] = make_layer ,
167
+ cfg : ModelCfg ,
168
+ layer_constructor : Callable [[ModelCfg , int ], nn .Sequential ] = make_layer ,
159
169
) -> nn .Sequential :
160
170
"""Create model body."""
161
171
if hasattr (cfg , "make_layer" ):
@@ -203,7 +213,7 @@ def from_cfg(cls, cfg: ModelCfg):
203
213
204
214
@classmethod
205
215
def create_model (
206
- cls , cfg : Optional [ModelCfg ] = None , ** kwargs : dict [str , Any ]
216
+ cls , cfg : Optional [ModelCfg ] = None , ** kwargs : Dict [str , Any ]
207
217
) -> nn .Sequential :
208
218
if cfg :
209
219
return cls (** cfg .model_dump (exclude_none = True ))()
@@ -226,9 +236,9 @@ def __call__(self) -> nn.Sequential:
226
236
227
237
228
238
class ResNet34 (ModelConstructor ):
229
- layers : list [int ] = [3 , 4 , 6 , 3 ]
239
+ layers : List [int ] = [3 , 4 , 6 , 3 ]
230
240
231
241
232
242
class ResNet50 (ResNet34 ):
233
- block : type [nn .Module ] = BottleneckBlock
234
- block_sizes : list [int ] = [256 , 512 , 1024 , 2048 ]
243
+ block : Type [nn .Module ] = BottleneckBlock
244
+ block_sizes : List [int ] = [256 , 512 , 1024 , 2048 ]
0 commit comments