7
7
from torch import nn
8
8
9
9
from .blocks import BasicBlock , BottleneckBlock
10
- from .helpers import (Cfg , ListStrMod , ModSeq , init_cnn , instantiate_module ,
11
- is_module , nn_seq )
10
+ from .helpers import (
11
+ Cfg ,
12
+ ListStrMod ,
13
+ ModSeq ,
14
+ init_cnn ,
15
+ instantiate_module ,
16
+ is_module ,
17
+ nn_seq ,
18
+ )
12
19
from .layers import ConvBnAct , SEModule , SimpleSelfAttention
13
20
14
21
__all__ = [
@@ -60,7 +67,8 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
60
67
61
68
@field_validator ("act_fn" , "block" , "conv_layer" , "norm" , "pool" , "stem_pool" )
62
69
def set_modules ( # pylint: disable=no-self-argument
63
- cls , value : Union [nnModule , str ],
70
+ cls ,
71
+ value : Union [nnModule , str ],
64
72
) -> nnModule :
65
73
"""Check values, if string, convert to nn.Module."""
66
74
if is_module (value ):
@@ -69,7 +77,9 @@ def set_modules( # pylint: disable=no-self-argument
69
77
70
78
@field_validator ("se" , "sa" )
71
79
def set_se ( # pylint: disable=no-self-argument
72
- cls , value : Union [bool , nnModule , str ], info : FieldValidationInfo ,
80
+ cls ,
81
+ value : Union [bool , nnModule , str ],
82
+ info : FieldValidationInfo ,
73
83
) -> nnModule :
74
84
if isinstance (value , (int , bool )):
75
85
return DEFAULT_SE_SA [info .field_name ]
@@ -154,8 +164,8 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
154
164
155
165
156
166
def make_body (
157
- cfg : ModelCfg ,
158
- layer_constructor : Callable [[ModelCfg , int ], nn .Sequential ] = make_layer ,
167
+ cfg : ModelCfg ,
168
+ layer_constructor : Callable [[ModelCfg , int ], nn .Sequential ] = make_layer ,
159
169
) -> nn .Sequential :
160
170
"""Create model body."""
161
171
if hasattr (cfg , "make_layer" ):
0 commit comments