|
| 1 | +# Implementation of ConvMixer for the ICLR 2022 submission "Patches Are All You Need?". |
| 2 | +# Adopted from https://github.com/tmp-iclr/convmixer |
| 3 | +from collections import OrderedDict |
| 4 | +import torch.nn as nn |
| 5 | + |
| 6 | + |
class Residual(nn.Module):
    """Additive skip connection: ``forward(x)`` returns ``fn(x) + x``.

    Args:
        fn: module applied on the residual branch; its output must be
            broadcast-compatible with its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn  # transformation on the residual branch

    def forward(self, x):
        branch = self.fn(x)
        return x + branch
| 14 | + |
| 15 | + |
# As original version, act_fn as argument.
def ConvMixerOriginal(dim, depth,
                      kernel_size=9, patch_size=7, n_classes=1000,
                      act_fn=nn.GELU()):
    """Build the reference ConvMixer: patch-embedding stem, `depth` mixer
    blocks (depthwise residual conv + pointwise conv), then a pooled
    linear classifier head.

    NOTE: the single ``act_fn`` module instance is shared by every layer,
    as in the upstream implementation (safe for stateless activations).
    """
    def mixer_block():
        # One ConvMixer block: residual depthwise conv, then pointwise conv.
        return nn.Sequential(
            Residual(nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                act_fn,
                nn.BatchNorm2d(dim))),
            nn.Conv2d(dim, dim, kernel_size=1),
            act_fn,
            nn.BatchNorm2d(dim))

    stem = [nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
            act_fn,
            nn.BatchNorm2d(dim)]
    head = [nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(dim, n_classes)]
    return nn.Sequential(*stem,
                         *(mixer_block() for _ in range(depth)),
                         *head)
| 38 | + |
| 39 | + |
class ConvLayer(nn.Sequential):
    """Basic conv layers block: Conv2d plus activation and BatchNorm.

    The sub-module order is configurable:
      * default            -> conv, act_fn, bn
      * ``bn_1st=True``    -> conv, bn, act_fn
      * ``pre_act=True``   -> act_fn, conv, bn (bn_1st=True gives bn, conv, act_fn)
    """

    def __init__(self, ch_in, ch_out, kernel_size, stride=1,
                 act_fn=nn.GELU(), padding=0, groups=1,
                 bn_1st=False, pre_act=False):
        conv = ('conv', nn.Conv2d(ch_in, ch_out, kernel_size, stride=stride,
                                  padding=padding, groups=groups))
        act = ('act_fn', act_fn)
        bn = ('bn', nn.BatchNorm2d(ch_out))
        # Normalization/activation pair, in the requested relative order.
        pair = [bn, act] if bn_1st else [act, bn]
        if pre_act:
            # Pre-activation: the conv sits between the two pair members.
            ordered = [pair[0], conv, pair[1]]
        else:
            ordered = [conv, pair[0], pair[1]]
        super().__init__(OrderedDict(ordered))
| 61 | + |
| 62 | + |
def ConvMixer(dim: int, depth: int,
              kernel_size: int = 9, patch_size: int = 7, n_classes: int = 1000,
              act_fn: nn.Module = nn.GELU(),
              stem_ch: int = 0, stem_ks: int = 1,
              bn_1st: bool = False, pre_act: bool = False) -> nn.Sequential:
    """ConvMixer constructor.
    Adopted from https://github.com/tmp-iclr/convmixer

    Args:
        dim (int): Dimension (channel width) of the model.
        depth (int): Depth of the model (number of mixer blocks).
        kernel_size (int, optional): Depthwise-conv kernel size. Defaults to 9.
        patch_size (int, optional): Patch size. Defaults to 7.
        n_classes (int, optional): Number of classes. Defaults to 1000.
        act_fn (nn.Module, optional): Activation function. Defaults to nn.GELU().
        stem_ch (int, optional): If not 0 - add additional 'stem' layer with stem_ch channels. Defaults to 0.
        stem_ks (int, optional): If stem_ch not 0 - kernel size for the additional layer. Defaults to 1.
        bn_1st (bool, optional): If True - BatchNorm before activation function. Defaults to False.
        pre_act (bool, optional): If True - activation function before convolution layer. Defaults to False.

    Returns:
        nn.Sequential: nn.Module as Sequential model.
    """
    if pre_act:
        # Pre-activation already places act before conv; bn_1st has no meaning then.
        bn_1st = False
    if stem_ch:
        # Two-stage stem: patchify to stem_ch channels, then project to dim.
        stem = [ConvLayer(3, stem_ch, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st),
                ConvLayer(stem_ch, dim, kernel_size=stem_ks, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act)]
    else:
        stem = [ConvLayer(3, dim, kernel_size=patch_size, stride=patch_size, act_fn=act_fn, bn_1st=bn_1st)]
    return nn.Sequential(
        *stem,
        *[nn.Sequential(
            # BUGFIX: pass act_fn to the depthwise layer too; it previously
            # fell back to ConvLayer's default nn.GELU(), silently ignoring
            # a caller-supplied activation (unlike ConvMixerOriginal).
            Residual(ConvLayer(dim, dim, kernel_size, groups=dim, padding="same",
                               act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act)),
            ConvLayer(dim, dim, kernel_size=1, act_fn=act_fn, bn_1st=bn_1st, pre_act=pre_act)) for i in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)
    )
0 commit comments