Skip to content

Commit 60778b7

Browse files
committed
typing
1 parent 48a26de commit 60778b7

File tree

5 files changed

+35
-35
lines changed

5 files changed

+35
-35
lines changed

src/model_constructor/layers.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -258,14 +258,14 @@ class SEModule(nn.Module):
258258

259259
def __init__(
260260
self,
261-
channels,
262-
reduction=16,
263-
rd_channels=None,
264-
rd_max=False,
265-
se_layer=nn.Linear,
266-
act_fn=nn.ReLU(inplace=True),
267-
use_bias=True,
268-
gate=nn.Sigmoid,
261+
channels: int,
262+
reduction: int = 16,
263+
rd_channels: Optional[int] = None,
264+
rd_max: bool = False,
265+
se_layer: type[nn.Module] = nn.Linear,
266+
act_fn: nn.Module = nn.ReLU(inplace=True),
267+
use_bias: bool = True,
268+
gate: type[nn.Module] = nn.Sigmoid,
269269
):
270270
super().__init__()
271271
reducted = max(channels // reduction, 1) # preserve zero-element tensors
@@ -286,7 +286,7 @@ def __init__(
286286
)
287287
)
288288

289-
def forward(self, x):
289+
def forward(self, x: torch.Tensor) -> torch.Tensor:
290290
bs, c, _, _ = x.shape
291291
y = self.squeeze(x).view(bs, c)
292292
y = self.excitation(y).view(bs, c, 1, 1)
@@ -298,14 +298,14 @@ class SEModuleConv(nn.Module):
298298

299299
def __init__(
300300
self,
301-
channels,
302-
reduction=16,
303-
rd_channels=None,
304-
rd_max=False,
305-
se_layer=nn.Conv2d,
306-
act_fn=nn.ReLU(inplace=True),
307-
use_bias=True,
308-
gate=nn.Sigmoid,
301+
channels: int,
302+
reduction: int = 16,
303+
rd_channels: Optional[int] = None,
304+
rd_max: bool = False,
305+
se_layer: type[nn.Module] = nn.Conv2d,
306+
act_fn: nn.Module = nn.ReLU(inplace=True),
307+
use_bias: bool = True,
308+
gate: type[nn.Module] = nn.Sigmoid,
309309
):
310310
super().__init__()
311311
# rd_channels = math.ceil(channels//reduction/8)*8
@@ -327,7 +327,7 @@ def __init__(
327327
)
328328
)
329329

330-
def forward(self, x):
330+
def forward(self, x: torch.Tensor) -> torch.Tensor:
331331
y = self.squeeze(x)
332332
y = self.excitation(y)
333333
return x * y.expand_as(x)

src/model_constructor/model_constructor.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
4242
zero_bn: bool = True
4343
stem_stride_on: int = 0
4444
stem_sizes: list[int] = [64]
45-
stem_pool: Union[Callable[[], nn.Module], None] = partial(
45+
stem_pool: Optional[Callable[[], nn.Module]] = partial(
4646
nn.MaxPool2d, kernel_size=3, stride=2, padding=1
4747
)
4848
stem_bn_end: bool = False
@@ -61,7 +61,7 @@ def __repr__(self) -> str:
6161
)
6262

6363

64-
def make_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
64+
def make_stem(cfg: ModelCfg) -> nn.Sequential:
6565
"""Create Resnet stem."""
6666
stem: ListStrMod = [
6767
(
@@ -116,15 +116,15 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
116116
)
117117

118118

119-
def make_body(cfg: ModelCfg) -> nn.Sequential: # type: ignore
119+
def make_body(cfg: ModelCfg) -> nn.Sequential:
120120
"""Create model body."""
121121
return nn_seq(
122122
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
123123
for layer_num in range(len(cfg.layers))
124124
)
125125

126126

127-
def make_head(cfg: ModelCfg) -> nn.Sequential: # type: ignore
127+
def make_head(cfg: ModelCfg) -> nn.Sequential:
128128
"""Create head."""
129129
head = [
130130
("pool", nn.AdaptiveAvgPool2d(1)),
@@ -138,10 +138,10 @@ class ModelConstructor(ModelCfg):
138138
"""Model constructor. As default - resnet18"""
139139

140140
init_cnn: Callable[[nn.Module], None] = init_cnn
141-
make_stem: Callable[[ModelCfg], ModSeq] = make_stem # type: ignore
142-
make_layer: Callable[[ModelCfg, int], ModSeq] = make_layer # type: ignore
143-
make_body: Callable[[ModelCfg], ModSeq] = make_body # type: ignore
144-
make_head: Callable[[ModelCfg], ModSeq] = make_head # type: ignore
141+
make_stem: Callable[[ModelCfg], ModSeq] = make_stem
142+
make_layer: Callable[[ModelCfg, int], ModSeq] = make_layer
143+
make_body: Callable[[ModelCfg], ModSeq] = make_body
144+
make_head: Callable[[ModelCfg], ModSeq] = make_head
145145

146146
@field_validator("se")
147147
def set_se( # pylint: disable=no-self-argument

src/model_constructor/universal_blocks.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -134,7 +134,7 @@ def __init__(
134134
self.id_conv = None
135135
self.act_fn = get_act(act_fn)
136136

137-
def forward(self, x: torch.Tensor): # type: ignore
137+
def forward(self, x: torch.Tensor) -> torch.Tensor:
138138
identity = self.id_conv(x) if self.id_conv is not None else x
139139
return self.act_fn(self.convs(x) + identity)
140140

@@ -255,14 +255,14 @@ def __init__(
255255
self.id_conv = None
256256
self.merge = get_act(act_fn)
257257

258-
def forward(self, x: torch.Tensor): # type: ignore
258+
def forward(self, x: torch.Tensor) -> torch.Tensor:
259259
if self.reduce:
260260
x = self.reduce(x)
261261
identity = self.id_conv(x) if self.id_conv is not None else x
262262
return self.merge(self.convs(x) + identity)
263263

264264

265-
def make_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
265+
def make_stem(cfg: ModelCfg) -> nn.Sequential:
266266
"""Create xResnet stem -> 3 conv 3*3 instead of 1 conv 7*7"""
267267
len_stem = len(cfg.stem_sizes)
268268
stem: list[tuple[str, nn.Module]] = [
@@ -286,7 +286,7 @@ def make_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
286286
return nn_seq(stem)
287287

288288

289-
def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
289+
def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential:
290290
"""Create layer (stage)"""
291291
# if no pool on stem - stride = 2 for first layer block in body
292292
stride = 1 if cfg.stem_pool and layer_num == 0 else 2
@@ -316,15 +316,15 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
316316
)
317317

318318

319-
def make_body(cfg: ModelCfg) -> nn.Sequential: # type: ignore
319+
def make_body(cfg: ModelCfg) -> nn.Sequential:
320320
"""Create model body."""
321321
return nn_seq(
322322
(f"l_{layer_num}", cfg.make_layer(cfg, layer_num)) # type: ignore
323323
for layer_num in range(len(cfg.layers))
324324
)
325325

326326

327-
def make_head(cfg: ModelCfg) -> nn.Sequential: # type: ignore
327+
def make_head(cfg: ModelCfg) -> nn.Sequential:
328328
"""Create head."""
329329
head = [
330330
("pool", nn.AdaptiveAvgPool2d(1)),

src/model_constructor/xresnet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
]
1515

1616

17-
def xresnet_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
17+
def xresnet_stem(cfg: ModelCfg) -> nn.Sequential:
1818
"""Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
1919
len_stem = len(cfg.stem_sizes)
2020
stem: ListStrMod = [

src/model_constructor/yaresnet.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ def __init__(
100100
self.id_conv = None
101101
self.merge = get_act(act_fn)
102102

103-
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
103+
def forward(self, x: torch.Tensor) -> torch.Tensor:
104104
if self.reduce:
105105
x = self.reduce(x)
106106
identity = self.id_conv(x) if self.id_conv is not None else x
@@ -195,7 +195,7 @@ def __init__(
195195
self.id_conv = None
196196
self.merge = get_act(act_fn)
197197

198-
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
198+
def forward(self, x: torch.Tensor) -> torch.Tensor:
199199
if self.reduce:
200200
x = self.reduce(x)
201201
identity = self.id_conv(x) if self.id_conv is not None else x

0 commit comments

Comments
 (0)