Skip to content

Commit cd6dcde

Browse files
authored
Merge pull request #32 from ayasyrev/ayasyrev/issue31
Add ConvMixer
2 parents 3508c69 + e9c5db4 commit cd6dcde

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

model_constructor/convmixer.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Implementation of ConvMixer for the ICLR 2022 submission "Patches Are All You Need?".
2+
# Adopted from https://github.com/tmp-iclr/convmixer
3+
from collections import OrderedDict
4+
import torch.nn as nn
5+
6+
7+
class Residual(nn.Module):
    """Skip-connection wrapper: forward returns ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        # Apply the wrapped module, then add the identity shortcut.
        branch = self.fn(x)
        return branch + x
14+
15+
16+
# As original version, act_fn as argument.
17+
def ConvMixerOriginal(dim, depth,
                      kernel_size=9, patch_size=7, n_classes=1000,
                      act_fn=nn.GELU()):
    """Build the original ConvMixer, as in https://github.com/tmp-iclr/convmixer.

    Patch-embedding stem, ``depth`` repeated mixer blocks (depthwise conv in a
    residual branch, then a pointwise conv), followed by pooling and a linear
    head. The single ``act_fn`` module instance is shared by every position,
    matching the upstream implementation.
    """
    def _mixer_block():
        # One ConvMixer block: residual depthwise conv + pointwise conv.
        return nn.Sequential(
            Residual(nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                act_fn,
                nn.BatchNorm2d(dim))),
            nn.Conv2d(dim, dim, kernel_size=1),
            act_fn,
            nn.BatchNorm2d(dim))

    stem = [
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        act_fn,
        nn.BatchNorm2d(dim),
    ]
    head = [
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes),
    ]
    blocks = [_mixer_block() for _ in range(depth)]
    return nn.Sequential(*stem, *blocks, *head)
38+
39+
40+
class ConvLayer(nn.Sequential):
    """Basic conv block: convolution plus activation and batch norm.

    Layer order is controlled by two flags:
      * default:          conv -> act_fn -> bn
      * ``bn_1st=True``:  conv -> bn -> act_fn
      * ``pre_act=True``: act_fn -> conv -> bn (caller should keep bn_1st False)

    Args:
        ch_in (int): Number of input channels.
        ch_out (int): Number of output channels.
        kernel_size (int): Convolution kernel size.
        stride (int, optional): Convolution stride. Defaults to 1.
        act_fn (nn.Module, optional): Activation module. If None, a fresh
            nn.GELU() is created per instance. (Fixes the mutable default
            argument ``act_fn=nn.GELU()`` that shared one module instance
            across every ConvLayer created with the default; behavior is
            unchanged since GELU is stateless.)
        padding (int or str, optional): Convolution padding. Defaults to 0.
        groups (int, optional): Convolution groups. Defaults to 1.
        bn_1st (bool, optional): If True, BatchNorm before activation.
        pre_act (bool, optional): If True, activation before convolution.
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride=1,
                 act_fn=None, padding=0, groups=1,
                 bn_1st=False, pre_act=False):
        if act_fn is None:
            act_fn = nn.GELU()

        conv_layer = [('conv', nn.Conv2d(ch_in, ch_out, kernel_size, stride=stride,
                                         padding=padding, groups=groups))]
        act_bn = [
            ('act_fn', act_fn),
            ('bn', nn.BatchNorm2d(ch_out))
        ]
        if bn_1st:
            act_bn.reverse()
        if pre_act:
            # Move the conv between activation and norm: act -> conv -> bn.
            act_bn.insert(1, conv_layer[0])
            layers = act_bn
        else:
            layers = conv_layer + act_bn
        super().__init__(OrderedDict(layers))
61+
62+
63+
def ConvMixer(dim: int, depth: int,
              kernel_size: int = 9, patch_size: int = 7, n_classes: int = 1000,
              act_fn: nn.Module = nn.GELU(),
              stem_ch: int = 0, stem_ks: int = 1,
              bn_1st: bool = False, pre_act: bool = False) -> nn.Sequential:
    """ConvMixer constructor.
    Adopted from https://github.com/tmp-iclr/convmixer

    Args:
        dim (int): Dimension of model.
        depth (int): Depth of model.
        kernel_size (int, optional): Kernel size. Defaults to 9.
        patch_size (int, optional): Patch size. Defaults to 7.
        n_classes (int, optional): Number of classes. Defaults to 1000.
        act_fn (nn.Module, optional): Activation function. Defaults to nn.GELU().
        stem_ch (int, optional): If not 0 - add additional 'stem' layer with stem_ch channels. Defaults to 0.
        stem_ks (int, optional): If stem_ch not 0 - kernel size for additional layer. Defaults to 1.
        bn_1st (bool, optional): If True - BatchNorm before activation function. Defaults to False.
        pre_act (bool, optional): If True - activation function before convolution layer. Defaults to False.

    Returns:
        nn.Sequential: nn.Module as Sequential model.
    """
    # pre-activation and bn-first describe incompatible layer orders;
    # pre_act wins, matching ConvLayer's expectations.
    if pre_act:
        bn_1st = False
    if stem_ch:
        # Two-stage stem: patch embedding to stem_ch, then a stem_ks conv to dim.
        stem = [ConvLayer(3, stem_ch, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st),
                ConvLayer(stem_ch, dim, kernel_size=stem_ks, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act)]
    else:
        stem = [ConvLayer(3, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)]
    return nn.Sequential(
        *stem,
        *[nn.Sequential(
            # BUG FIX: act_fn was not forwarded to the depthwise ConvLayer,
            # so a user-supplied activation was silently ignored in the
            # residual branch (it fell back to the default GELU).
            Residual(ConvLayer(dim, dim, kernel_size, groups=dim, padding="same",
                               act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act)),
            ConvLayer(dim, dim, kernel_size=1, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act))
          for _ in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)
    )

0 commit comments

Comments
 (0)