Skip to content

Commit cad905e

Browse files
committed
black
1 parent baaa733 commit cad905e

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

src/model_constructor/model_constructor.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@
77
from torch import nn
88

99
from .blocks import BasicBlock, BottleneckBlock
10-
from .helpers import (Cfg, ListStrMod, ModSeq, init_cnn, instantiate_module,
11-
is_module, nn_seq)
10+
from .helpers import (
11+
Cfg,
12+
ListStrMod,
13+
ModSeq,
14+
init_cnn,
15+
instantiate_module,
16+
is_module,
17+
nn_seq,
18+
)
1219
from .layers import ConvBnAct, SEModule, SimpleSelfAttention
1320

1421
__all__ = [
@@ -60,7 +67,8 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
6067

6168
@field_validator("act_fn", "block", "conv_layer", "norm", "pool", "stem_pool")
6269
def set_modules( # pylint: disable=no-self-argument
63-
cls, value: Union[nnModule, str],
70+
cls,
71+
value: Union[nnModule, str],
6472
) -> nnModule:
6573
"""Check values, if string, convert to nn.Module."""
6674
if is_module(value):
@@ -69,7 +77,9 @@ def set_modules( # pylint: disable=no-self-argument
6977

7078
@field_validator("se", "sa")
7179
def set_se( # pylint: disable=no-self-argument
72-
cls, value: Union[bool, nnModule, str], info: FieldValidationInfo,
80+
cls,
81+
value: Union[bool, nnModule, str],
82+
info: FieldValidationInfo,
7383
) -> nnModule:
7484
if isinstance(value, (int, bool)):
7585
return DEFAULT_SE_SA[info.field_name]
@@ -154,8 +164,8 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
154164

155165

156166
def make_body(
157-
cfg: ModelCfg,
158-
layer_constructor: Callable[[ModelCfg, int], nn.Sequential] = make_layer,
167+
cfg: ModelCfg,
168+
layer_constructor: Callable[[ModelCfg, int], nn.Sequential] = make_layer,
159169
) -> nn.Sequential:
160170
"""Create model body."""
161171
if hasattr(cfg, "make_layer"):

0 commit comments

Comments
 (0)