Commit 161dfa5

Merge pull request #10 from HumanCompatibleAI/add_generator_minimal
Add generator (minimal)
2 parents b354cc3 + c1fca19

File tree

5 files changed: +316 -1 lines changed


requirements.txt

Lines changed: 2 additions & 0 deletions

@@ -18,3 +18,5 @@ git+https://github.com/openai/gym3.git@4c38246
 procgen @ git+https://github.com/JacobPfau/procgenAISC.git@7821f2c00b
 # Revert to this older version because some library won't work otherwise
 protobuf==3.19
+git+https://github.com/dfilan/vegans.git@76a3c45
+Pillow==9.2.0
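The new `vegans` pin supplies the GAN training algorithms (`WassersteinGAN` and `WassersteinGANGP`) that the `train_gan` script added below relies on.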
src/reward_preprocessing/generative_modelling/gen_models.py

Lines changed: 155 additions & 0 deletions

@@ -0,0 +1,155 @@

import torch as th
import torch.nn as nn

# TODO(df): add test on initialization that outputs have the right shape?


class Small21To84Generator(nn.Module):
    """
    Small generative model that takes 21 x 21 noise to an 84 x 84 image.
    """

    def __init__(self, latent_shape, data_shape):
        super(Small21To84Generator, self).__init__()
        self.hidden_part = nn.Sequential(
            nn.Conv2d(latent_shape[0], 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(32, 32, kernel_size=4, padding=1, stride=2),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(32, 32, kernel_size=4, padding=1, stride=2),
            nn.ReLU(),
        )
        self.output = nn.Conv2d(32, data_shape[0], kernel_size=3, padding=1)

    def forward(self, x):
        x = self.hidden_part(x)
        x = self.output(x)
        return x


class SmallFourTo64Generator(nn.Module):
    """
    Small generative model that takes 4 x 4 noise to a 64 x 64 image.

    Of use for generative modelling of procgen rollouts.
    """

    def __init__(self, latent_shape, data_shape):
        super(SmallFourTo64Generator, self).__init__()
        self.hidden_part = nn.Sequential(
            nn.ConvTranspose2d(latent_shape[0], 32, kernel_size=4, padding=1, stride=2),
            # now 8x8
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(32, 32, kernel_size=4, padding=1, stride=2),
            # now 16x16
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(32, 32, kernel_size=4, padding=1, stride=2),
            # now 32x32
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(32, 32, kernel_size=4, padding=1, stride=2),
            # now 64x64
            nn.LeakyReLU(0.1),
        )
        self.output = nn.Conv2d(32, data_shape[0], kernel_size=3, padding=1)

    def forward(self, x):
        x = self.hidden_part(x)
        x = self.output(x)
        return x


class DCGanFourTo64Generator(nn.Module):
    """
    DCGAN-based generative model that takes a 1-D latent vector to a 64x64 image.

    Of use for generative modelling of procgen rollouts.
    """

    def __init__(self, latent_shape, data_shape):
        super(DCGanFourTo64Generator, self).__init__()
        self.project = nn.Linear(latent_shape[0], 1024 * 4 * 4)
        self.conv_body = nn.Sequential(
            nn.BatchNorm2d(1024),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, padding=1, stride=2),
            # now 8x8
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(512),
            nn.ConvTranspose2d(512, 256, kernel_size=4, padding=1, stride=2),
            # now 16x16
            nn.LeakyReLU(0.1),
            nn.BatchNorm2d(256),
            nn.ConvTranspose2d(256, 128, kernel_size=4, padding=1, stride=2),
            # now 32x32
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(128, data_shape[0], kernel_size=4, padding=1, stride=2),
            # now 64x64
            nn.LeakyReLU(0.1),
        )

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.project(x)
        x = th.reshape(x, (batch_size, 1024, 4, 4))
        x = nn.functional.leaky_relu(x, negative_slope=0.1)
        x = self.conv_body(x)
        return x


class SmallWassersteinCritic(nn.Module):
    """
    Small critic for use in the Wasserstein GAN algorithm.
    """

    def __init__(self, data_shape):
        super(SmallWassersteinCritic, self).__init__()
        self.hidden_part = nn.Sequential(
            nn.Conv2d(data_shape[0], 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, 1),
        )
        self.output = nn.Identity()

    def forward(self, x):
        x = self.hidden_part(x)
        x = self.output(x)
        return x


class DCGanWassersteinCritic(nn.Module):
    """
    Wasserstein-GAN critic based off the DCGAN architecture.
    """

    def __init__(self, data_shape):
        super(DCGanWassersteinCritic, self).__init__()
        self.network = nn.Sequential(
            nn.Conv2d(data_shape[0], 128, kernel_size=4, padding=1, stride=2),
            # now 32 x 32
            nn.LeakyReLU(0.1),
            nn.Conv2d(128, 256, kernel_size=4, padding=1, stride=2),
            # now 16 x 16
            nn.LeakyReLU(0.1),
            nn.LayerNorm([256, 16, 16]),
            nn.Conv2d(256, 512, kernel_size=4, padding=1, stride=2),
            # now 8 x 8
            nn.LeakyReLU(0.1),
            nn.LayerNorm([512, 8, 8]),
            nn.Conv2d(512, 1024, kernel_size=4, padding=1, stride=2),
            # now 4 x 4
            nn.LeakyReLU(0.1),
            nn.LayerNorm([1024, 4, 4]),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        return self.network(x)
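As a quick sanity check along the lines of the TODO above, one can instantiate these networks and confirm their output shapes. This snippet is illustrative only, not part of the commit; the channel counts are arbitrary:

import torch as th

from reward_preprocessing.generative_modelling import gen_models

# 3-channel 21x21 noise -> 84x84 image (6 output channels chosen arbitrarily)
gen = gen_models.Small21To84Generator(latent_shape=[3, 21, 21], data_shape=[6, 84, 84])
z = th.randn(8, 3, 21, 21)
assert gen(z).shape == (8, 6, 84, 84)

# 100-dim latent vector -> 64x64 image via the DCGAN-style generator
dcgen = gen_models.DCGanFourTo64Generator(latent_shape=[100], data_shape=[6, 64, 64])
assert dcgen(th.randn(8, 100)).shape == (8, 6, 64, 64)

# the critic maps a batch of images to one realism score per example
critic = gen_models.SmallWassersteinCritic(data_shape=[6, 84, 84])
assert critic(gen(z)).shape == (8, 1)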

src/reward_preprocessing/interpret.py

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ def interpret(
         sacred ingredient 'common' in imitation.scripts.common.
     reward_path: Path to the learned supervised reward net.
     rollout_path:
-        Rollouts to use vor dataset visualization, dimensionality
+        Rollouts to use for dataset visualization, dimensionality
         reduction, and determining the shape of the features.
     limit_num_obs:
         Limit how many of the transitions from `rollout_path` are used for
src/reward_preprocessing/scripts/config/train_gan.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@

"""Configuration settings for train_gan, training a generative model of transitions."""

import sacred
import vegans.GAN

from reward_preprocessing.generative_modelling import gen_models

train_gan_ex = sacred.Experiment("train_gan")


@train_gan_ex.config
def train_gan_defaults():
    generator_class = gen_models.Small21To84Generator
    discriminator_class = gen_models.SmallWassersteinCritic
    gan_algorithm = vegans.GAN.WassersteinGAN
    optim_kwargs = {
        "Generator": {"lr": 5e-4},
        "Adversary": {"lr": 1e-4},
    }  # keyword arguments for GAN optimizer
    num_training_epochs = 50
    batch_size = 256  # batch size for transition dataloader
    latent_shape = [3, 21, 21]  # shape of latent vector input to generator
    locals()  # make flake8 happy


@train_gan_ex.named_config
def procgen():
    generator_class = gen_models.DCGanFourTo64Generator
    discriminator_class = gen_models.DCGanWassersteinCritic
    gan_algorithm = vegans.GAN.WassersteinGANGP
    optim_kwargs = {
        "Generator": {"lr": 1e-4, "betas": (0.0, 0.9)},
        "Adversary": {"lr": 1e-4, "betas": (0.0, 0.9), "weight_decay": 1e-3},
    }
    num_training_epochs = 10
    batch_size = 128
    latent_shape = [100]
    print_every = "0.1e"
    save_losses_every = "0.1e"
    save_model_every = "1e"
    num_acts = 15
    device = "cuda"
    locals()
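Since `train_gan_ex` is a sacred experiment, the `procgen` named config is layered on top of `train_gan_defaults` at run time. A minimal sketch of driving it from Python, with hypothetical paths for the two options neither config sets:

from reward_preprocessing.scripts.train_gan import train_gan_ex

train_gan_ex.run(
    named_configs=["procgen"],
    config_updates={
        # hypothetical locations; rollouts_paths and gan_save_path
        # have no default in either config
        "rollouts_paths": "runs/rollouts/procgen.pkl",
        "gan_save_path": "runs/gan",
    },
)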
src/reward_preprocessing/scripts/train_gan.py

Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@

"""Train a generative model of transitions.

For use in reward function feature visualization.
"""
import torch as th
from sacred.observers import FileStorageObserver

from reward_preprocessing.common import utils
from reward_preprocessing.scripts.config.train_gan import train_gan_ex

# TODO: write script to use this in feature viz.


@train_gan_ex.main
def train_gan(
    generator_class,
    discriminator_class,
    gan_algorithm,
    optim_kwargs,
    rollouts_paths,
    num_acts,
    num_training_epochs,
    batch_size,
    latent_shape,
    gan_save_path,
    device="cpu",
    ngpu=None,
    optimizer=th.optim.Adam,
    adv_steps=5,
    print_every="1e",
    save_losses_every="0.25e",
    save_model_every="1e",
):
    """Train a GAN on a set of transitions.

    Assumes that observations are image-shaped and actions are discrete.

    Args:
        generator_class: Upon initialization, takes a shape for the latent
            space and the shape of the transition tensors. Instantiates a
            network that takes latent vectors and returns transition tensors.
        discriminator_class: Upon initialization, takes a shape for the
            transition tensors. Instantiates a network that takes a
            transition tensor and gives it a realism score.
        gan_algorithm: A GAN training algorithm imported from `vegans`.
        optim_kwargs: A dictionary of keyword arguments for the generator and
            adversary networks.
        rollouts_paths: Path of rollouts saved by `imitation`, or list of paths.
        num_acts: Number of actions in the training environment.
        num_training_epochs: How many epochs to train the GAN for.
        batch_size: Number of transitions per batch to be trained on.
        latent_shape: Shape of the latent tensor to be fed into the generator
            network. Should be in (c,h,w) format.
        gan_save_path: Directory in which to save GAN training details.
        device: "cpu" or "cuda", depending on what you're training on.
        ngpu: Number of GPUs to train on, if training on GPUs.
        optimizer: torch.optim optimizer to train the GAN with.
        adv_steps: Number of steps to train the adversary for each step the
            generator is trained for.
        print_every: String specifying how many epochs should elapse between
            successive printings of training information.
        save_losses_every: String specifying how many epochs should elapse
            between successive savings of loss information.
        save_model_every: String specifying how many epochs should elapse
            between successive savings of the model.
    """
    # create data loader of transitions
    transitions_loader = utils.rollouts_to_dataloader(
        rollouts_paths, num_acts, batch_size
    )
    # define gan
    transitions_batch = next(iter(transitions_loader))
    trans_shape = list(transitions_batch.shape)[1:]
    generator = generator_class(latent_shape, trans_shape)
    discriminator = discriminator_class(trans_shape)
    gan = gan_algorithm(
        generator,
        discriminator,
        z_dim=latent_shape,
        x_dim=trans_shape,
        optim=optimizer,
        optim_kwargs=optim_kwargs,
        folder=gan_save_path,
        device=device,
        ngpu=ngpu,
    )
    # print out summary
    gan.summary()
    # fit gan
    steps = {"Adversary": adv_steps}
    gan.fit(
        transitions_loader,
        batch_size=batch_size,
        print_every=print_every,
        save_losses_every=save_losses_every,
        save_model_every=save_model_every,
        epochs=num_training_epochs,
        steps=steps,
    )
    # save samples, return losses, save plot of losses
    samples, losses = gan.get_training_results()
    utils.save_loss_plots(losses, gan.folder)
    utils.visualize_samples(samples, num_acts, gan.folder)
    return losses


def main_console():
    observer = FileStorageObserver("train_gan")
    train_gan_ex.observers.append(observer)
    train_gan_ex.run_commandline()


if __name__ == "__main__":  # pragma: no cover
    main_console()
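Per the TODO at the top of the script, the trained generator is eventually meant to feed feature visualization. A hedged sketch of that downstream sampling step, assuming the procgen settings and a checkpoint file name that vegans may or may not use:

import torch as th

from reward_preprocessing.generative_modelling import gen_models

# Rebuild the generator with the shapes used in training. The transition
# tensor is assumed here to have 6 channels, and "generator.pt" is a
# placeholder for whatever checkpoint vegans writes under gan_save_path.
generator = gen_models.DCGanFourTo64Generator([100], [6, 64, 64])
generator.load_state_dict(th.load("runs/gan/generator.pt"))
generator.eval()

with th.no_grad():
    z = th.randn(16, 100)            # batch of latent vectors
    fake_transitions = generator(z)  # shape (16, 6, 64, 64)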
