|
| 1 | +import torch as th |
| 2 | +import torch.nn as nn |
| 3 | + |
# TODO(df): add an initialization test checking that module outputs have the expected shape.
| 5 | + |
| 6 | + |
class Small21To84Generator(nn.Module):
    """
    Small generative model that takes 21 x 21 noise to an 84 x 84 image.

    Two stride-2 transposed convolutions double the spatial size twice:
    21 -> 42 -> 84.
    """

    def __init__(self, latent_shape, data_shape):
        super().__init__()
        # Alternate 3x3 convs with stride-2 upsampling deconvs.
        hidden_layers = [
            nn.Conv2d(latent_shape[0], 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(32, 32, kernel_size=4, padding=1, stride=2),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(32, 32, kernel_size=4, padding=1, stride=2),
            # NOTE(review): final hidden activation is ReLU while the rest are
            # LeakyReLU — preserved as-is; confirm this asymmetry is intended.
            nn.ReLU(),
        ]
        self.hidden_part = nn.Sequential(*hidden_layers)
        self.output = nn.Conv2d(32, data_shape[0], kernel_size=3, padding=1)

    def forward(self, x):
        """Map latent noise of spatial size 21 x 21 to an 84 x 84 image tensor."""
        return self.output(self.hidden_part(x))
| 30 | + |
| 31 | + |
class SmallFourTo64Generator(nn.Module):
    """
    Small generative model that takes 4 x 4 noise to a 64 x 64 image.

    Of use for generative modelling of procgen rollouts. Four stride-2
    transposed convolutions double the resolution each time:
    4 -> 8 -> 16 -> 32 -> 64.
    """

    def __init__(self, latent_shape, data_shape):
        super().__init__()
        upsample_layers = []
        in_channels = latent_shape[0]
        # Each iteration doubles the spatial resolution.
        for _ in range(4):
            upsample_layers.append(
                nn.ConvTranspose2d(in_channels, 32, kernel_size=4, padding=1, stride=2)
            )
            upsample_layers.append(nn.LeakyReLU(0.1))
            in_channels = 32
        self.hidden_part = nn.Sequential(*upsample_layers)
        self.output = nn.Conv2d(32, data_shape[0], kernel_size=3, padding=1)

    def forward(self, x):
        """Upsample latent noise of spatial size 4 x 4 to a 64 x 64 image tensor."""
        return self.output(self.hidden_part(x))
| 61 | + |
| 62 | + |
class DCGanFourTo64Generator(nn.Module):
    """
    DCGAN-based generative model that takes a 1-D latent vector to a 64x64 image.

    Of use for generative modelling of procgen rollouts. The latent vector is
    linearly projected to a (1024, 4, 4) volume, then decoded by a stack of
    stride-2 transposed convolutions: 4 -> 8 -> 16 -> 32 -> 64.
    """

    def __init__(self, latent_shape, data_shape):
        super().__init__()
        # Project the flat latent vector to a 1024-channel 4x4 feature map.
        self.project = nn.Linear(latent_shape[0], 1024 * 4 * 4)
        body = []
        # Three normalized upsampling stages, halving channels each time.
        for c_in, c_out in [(1024, 512), (512, 256), (256, 128)]:
            body.append(nn.BatchNorm2d(c_in))
            body.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, padding=1, stride=2))
            body.append(nn.LeakyReLU(0.1))
        # Final upsampling stage to the data's channel count, no normalization.
        body.append(nn.ConvTranspose2d(128, data_shape[0], kernel_size=4, padding=1, stride=2))
        body.append(nn.LeakyReLU(0.1))
        self.conv_body = nn.Sequential(*body)

    def forward(self, x):
        """Project latent batch x of shape (N, latent_dim) and decode to (N, C, 64, 64)."""
        projected = self.project(x)
        # Linear output is contiguous, so view is equivalent to reshape here.
        volume = projected.view(x.shape[0], 1024, 4, 4)
        activated = nn.functional.leaky_relu(volume, negative_slope=0.1)
        return self.conv_body(activated)
| 98 | + |
| 99 | + |
class SmallWassersteinCritic(nn.Module):
    """
    Small critic for use in the Wasserstein GAN algorithm.

    Produces one unbounded scalar score per input image; the global average
    pool makes it agnostic to the input's spatial size.
    """

    def __init__(self, data_shape):
        super().__init__()
        feature_layers = []
        in_channels = data_shape[0]
        # Three 3x3 conv + LeakyReLU stages at a fixed width of 32 channels.
        for _ in range(3):
            feature_layers.append(nn.Conv2d(in_channels, 32, kernel_size=3, padding=1))
            feature_layers.append(nn.LeakyReLU(0.1))
            in_channels = 32
        head_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, 1),
        ]
        self.hidden_part = nn.Sequential(*feature_layers, *head_layers)
        # Identity output stage mirrors the two-stage layout of the generators.
        self.output = nn.Identity()

    def forward(self, x):
        """Score a batch of images; returns a tensor of shape (N, 1)."""
        return self.output(self.hidden_part(x))
| 124 | + |
| 125 | + |
class DCGanWassersteinCritic(nn.Module):
    """
    Wasserstein-GAN critic based off the DCGAN architecture.

    Expects 64 x 64 inputs (the LayerNorm shapes are hard-coded for the
    intermediate sizes) and returns one scalar score per example.
    """

    def __init__(self, data_shape):
        super().__init__()
        stages = []
        in_channels = data_shape[0]
        # Each stride-2 conv halves the spatial size: 64 -> 32 -> 16 -> 8 -> 4.
        for out_channels, size in [(128, 32), (256, 16), (512, 8), (1024, 4)]:
            stages.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=4, padding=1, stride=2)
            )
            stages.append(nn.LeakyReLU(0.1))
            if out_channels != 128:
                # LayerNorm over the full feature map; the first stage is
                # deliberately left unnormalized, matching the original.
                stages.append(nn.LayerNorm([out_channels, size, size]))
            in_channels = out_channels
        stages.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1024, 1),
        ])
        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Score a batch of 64 x 64 images; returns a tensor of shape (N, 1)."""
        return self.network(x)
0 commit comments