
Commit 91c3318

Merge pull request #96 from ayasyrev/refactor_modules_cfg
Refactor modules cfg
2 parents e3d713d + 7d6ffe4 commit 91c3318

14 files changed, +362 -331 lines changed


.flake8

Lines changed: 6 additions & 3 deletions
@@ -2,6 +2,9 @@
 select = C,E,F,W
 max-complexity = 10
 max-line-length = 120
-
-application-import-names = model_constructor
-import-order-style = google
+disable-noqa = True
+application-import-names = model_constructor, tests
+import-order-style = google
+per-file-ignores =
+    # imported but unused
+    __init__.py: F401
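
F401 is the pyflakes "imported but unused" check, so the new per-file-ignores entry silences that single code only inside __init__.py, where imports exist purely as re-exports. A minimal hypothetical sketch of the pattern this allows (not part of this commit):

# __init__.py of some package: re-export only, never used in this file.
# Plain flake8 reports F401 here; with "per-file-ignores = __init__.py: F401"
# it passes without a per-line "# noqa" comment.
from .model_constructor import ModelConstructor

The src/model_constructor/__init__.py diff below drops its "# noqa F401" comments for exactly this reason.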

src/model_constructor/__init__.py

Lines changed: 3 additions & 7 deletions
@@ -1,7 +1,3 @@
-from model_constructor.convmixer import ConvMixer # noqa F401
-from model_constructor.model_constructor import (
-    ModelConstructor,
-    ModelCfg,
-) # noqa F401
-
-from model_constructor.version import __version__ # noqa F401
+from .convmixer import ConvMixer
+from .model_constructor import ModelConstructor, ModelCfg
+from .version import __version__
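
The same names stay importable from the package root after the cleanup. A quick check, assuming the package is installed (for example with pip install -e .):

from model_constructor import ModelCfg, ModelConstructor, __version__

print(__version__)                  # version string re-exported from .version
print(ModelConstructor, ModelCfg)   # classes re-exported from .model_constructor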

src/model_constructor/activations.py

Lines changed: 33 additions & 12 deletions
@@ -5,9 +5,24 @@
 from torch.nn import Mish


-__all__ = ['mish', 'Mish', 'mish_jit', 'MishJit', 'mish_jit_fwd', 'mish_jit_bwd', 'MishJitAutoFn', 'mish_me', 'MishMe',
-           'hard_mish_jit', 'HardMishJit', 'hard_mish_jit_fwd', 'hard_mish_jit_bwd', 'HardMishJitAutoFn',
-           'hard_mish_me', 'HardMishMe']
+__all__ = [
+    "mish",
+    "Mish",
+    "mish_jit",
+    "MishJit",
+    "mish_jit_fwd",
+    "mish_jit_bwd",
+    "MishJitAutoFn",
+    "mish_me",
+    "MishMe",
+    "hard_mish_jit",
+    "HardMishJit",
+    "hard_mish_jit_fwd",
+    "hard_mish_jit_bwd",
+    "HardMishJitAutoFn",
+    "hard_mish_me",
+    "HardMishMe",
+]


 def mish(x, inplace: bool = False):
@@ -40,7 +55,8 @@ def mish_jit(x, _inplace: bool = False):
 class MishJit(nn.Module):
     def __init__(self, inplace: bool = False):
         """Jit version of Mish.
-        Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681"""
+        Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+        """
         super(MishJit, self).__init__()

     def forward(self, x):
@@ -61,8 +77,9 @@ def mish_jit_bwd(x, grad_output):


 class MishJitAutoFn(torch.autograd.Function):
-    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
     A memory efficient, jit scripted variant of Mish"""
+
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
@@ -79,8 +96,9 @@ def mish_me(x, inplace=False):


 class MishMe(nn.Module):
-    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
     A memory efficient, jit scripted variant of Mish"""
+
     def __init__(self, inplace: bool = False):
         super(MishMe, self).__init__()

@@ -90,18 +108,19 @@ def forward(self, x):

 @torch.jit.script
 def hard_mish_jit(x, inplace: bool = False):
-    """ Hard Mish
+    """Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
     return 0.5 * x * (x + 2).clamp(min=0, max=2)


 class HardMishJit(nn.Module):
-    """ Hard Mish
+    """Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
+
     def __init__(self, inplace: bool = False):
         super(HardMishJit, self).__init__()

@@ -116,16 +135,17 @@ def hard_mish_jit_fwd(x):

 @torch.jit.script
 def hard_mish_jit_bwd(x, grad_output):
-    m = torch.ones_like(x) * (x >= -2.)
-    m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
+    m = torch.ones_like(x) * (x >= -2.0)
+    m = torch.where((x >= -2.0) & (x <= 0.0), x + 1.0, m)
     return grad_output * m


 class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+    """A memory efficient, jit scripted variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
+
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
@@ -142,10 +162,11 @@ def hard_mish_me(x, inplace: bool = False):


 class HardMishMe(nn.Module):
-    """ A memory efficient, jit scripted variant of Hard Mish
+    """A memory efficient, jit scripted variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
+
     def __init__(self, inplace: bool = False):
         super(HardMishMe, self).__init__()

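
The edits in this file are cosmetic (a reformatted __all__, docstring spacing, explicit float literals such as -2.0), so runtime behaviour should be unchanged. A small sanity-check sketch, assuming torch is available and the import path model_constructor.activations:

import torch

from model_constructor.activations import MishJit, hard_mish_jit

x = torch.randn(2, 8)
y = MishJit()(x)          # module wrapper around the jit-friendly Mish
z = hard_mish_jit(x)      # scripted Hard Mish: 0.5 * x * (x + 2).clamp(min=0, max=2)
print(y.shape, z.shape)   # both are elementwise, so shapes match the input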

src/model_constructor/base_constructor.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 # Used in examples.
 # first implementation of xresnet - inspired by fastai version.
 from collections import OrderedDict
-from functools import partial

 import torch.nn as nn


src/model_constructor/blocks.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+from typing import Callable, Union
+
+import torch
+from torch import nn
+
+from .helpers import ListStrMod, nn_seq
+from .layers import ConvBnAct, get_act
+
+
+class BasicBlock(nn.Module):
+    """Basic Resnet block.
+    Configurable - can use pool to reduce at identity path, change act etc."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int = 1,
+        conv_layer: type[ConvBnAct] = ConvBnAct,
+        act_fn: type[nn.Module] = nn.ReLU,
+        zero_bn: bool = True,
+        bn_1st: bool = True,
+        groups: int = 1,
+        dw: bool = False,
+        div_groups: Union[None, int] = None,
+        pool: Union[Callable[[], nn.Module], None] = None,
+        se: Union[nn.Module, None] = None,
+        sa: Union[nn.Module, None] = None,
+    ):
+        super().__init__()
+        # pool defined at ModelConstructor.
+        if div_groups is not None:  # check if groups != 1 and div_groups
+            groups = int(out_channels / div_groups)
+        layers: ListStrMod = [
+            (
+                "conv_0",
+                conv_layer(
+                    in_channels,
+                    out_channels,
+                    3,
+                    stride=stride,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=in_channels if dw else groups,
+                ),
+            ),
+            (
+                "conv_1",
+                conv_layer(
+                    out_channels,
+                    out_channels,
+                    3,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                    groups=out_channels if dw else groups,
+                ),
+            ),
+        ]
+        if se:
+            layers.append(("se", se(out_channels)))
+        if sa:
+            layers.append(("sa", sa(out_channels)))
+        self.convs = nn_seq(layers)
+        if stride != 1 or in_channels != out_channels:
+            id_layers: ListStrMod = []
+            if (
+                stride != 1 and pool is not None
+            ):  # if pool - reduce by pool else stride 2 art id_conv
+                id_layers.append(("pool", pool()))
+            if in_channels != out_channels or (stride != 1 and pool is None):
+                id_layers.append(
+                    (
+                        "id_conv",
+                        conv_layer(
+                            in_channels,
+                            out_channels,
+                            1,
+                            stride=1 if pool else stride,
+                            act_fn=False,
+                        ),
+                    )
+                )
+            self.id_conv = nn_seq(id_layers)
+        else:
+            self.id_conv = None
+        self.act_fn = get_act(act_fn)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
+        identity = self.id_conv(x) if self.id_conv is not None else x
+        return self.act_fn(self.convs(x) + identity)
+
+
+class BottleneckBlock(nn.Module):
+    """Bottleneck Resnet block.
+    Configurable - can use pool to reduce at identity path, change act etc."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int = 1,
+        expansion: int = 4,
+        conv_layer: type[ConvBnAct] = ConvBnAct,
+        act_fn: type[nn.Module] = nn.ReLU,
+        zero_bn: bool = True,
+        bn_1st: bool = True,
+        groups: int = 1,
+        dw: bool = False,
+        div_groups: Union[None, int] = None,
+        pool: Union[Callable[[], nn.Module], None] = None,
+        se: Union[nn.Module, None] = None,
+        sa: Union[nn.Module, None] = None,
+    ):
+        super().__init__()
+        # pool defined at ModelConstructor.
+        mid_channels = out_channels // expansion
+        if div_groups is not None:  # check if groups != 1 and div_groups
+            groups = int(mid_channels / div_groups)
+        layers: ListStrMod = [
+            (
+                "conv_0",
+                conv_layer(
+                    in_channels,
+                    mid_channels,
+                    1,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                ),
+            ),
+            (
+                "conv_1",
+                conv_layer(
+                    mid_channels,
+                    mid_channels,
+                    3,
+                    stride=stride,
+                    act_fn=act_fn,
+                    bn_1st=bn_1st,
+                    groups=mid_channels if dw else groups,
+                ),
+            ),
+            (
+                "conv_2",
+                conv_layer(
+                    mid_channels,
+                    out_channels,
+                    1,
+                    zero_bn=zero_bn,
+                    act_fn=False,
+                    bn_1st=bn_1st,
+                ),
+            ),
+        ]
+        if se:
+            layers.append(("se", se(out_channels)))
+        if sa:
+            layers.append(("sa", sa(out_channels)))
+        self.convs = nn_seq(layers)
+        if stride != 1 or in_channels != out_channels:
+            id_layers: ListStrMod = []
+            if (
+                stride != 1 and pool is not None
+            ):  # if pool - reduce by pool else stride 2 art id_conv
+                id_layers.append(("pool", pool()))
+            if in_channels != out_channels or (stride != 1 and pool is None):
+                id_layers.append(
+                    (
+                        "id_conv",
+                        conv_layer(
+                            in_channels,
+                            out_channels,
+                            1,
+                            stride=1 if pool else stride,
+                            act_fn=False,
+                        ),
+                    )
+                )
+            self.id_conv = nn_seq(id_layers)
+        else:
+            self.id_conv = None
+        self.act_fn = get_act(act_fn)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
+        identity = self.id_conv(x) if self.id_conv is not None else x
+        return self.act_fn(self.convs(x) + identity)
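
A usage sketch for the two new blocks, based only on the constructor signatures above. The shape comments assume ConvBnAct pads its 3x3 convs to preserve spatial size (the usual convention, not shown in this diff):

import torch

from model_constructor.blocks import BasicBlock, BottleneckBlock

x = torch.randn(1, 64, 32, 32)

# Two 3x3 convs; id_conv is added because channels and stride both change.
basic = BasicBlock(in_channels=64, out_channels=128, stride=2)
print(basic(x).shape)        # expected: torch.Size([1, 128, 16, 16])

# 1x1 -> 3x3 -> 1x1 with mid_channels = out_channels // expansion = 64.
bottleneck = BottleneckBlock(in_channels=64, out_channels=256, expansion=4)
print(bottleneck(x).shape)   # expected: torch.Size([1, 256, 32, 32])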

src/model_constructor/convmixer.py

Lines changed: 0 additions & 1 deletion
@@ -67,7 +67,6 @@ def __init__(
         bn_1st: bool = False,
         pre_act: bool = False,
     ):
-
         conv_layer: List[tuple[str, nn.Module]] = [
             (
                 "conv",
