Skip to content

Commit f501e9e

Browse files
authored
Merge pull request #16 from HumanCompatibleAI/add_generator
Visualize using GAN
2 parents 95be3bc + 7bd27cc commit f501e9e

File tree

6 files changed

+199
-54
lines changed

6 files changed

+199
-54
lines changed

Diff for: requirements-dev.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pytest==6.2.5
22
black[jupyter]==22.10
33
flake8==3.9.2
4-
pytype==2021.8.24
4+
pytype==2022.10.26
55
flake8-isort==4.0.0

Diff for: src/reward_preprocessing/common/utils.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,10 @@ def rollouts_to_dataloader(
120120

121121

122122
def visualize_samples(samples: np.ndarray, save_dir):
123-
"""Visualize samples from a GAN. Saves obs and next obs as png files, and takes
124-
mean over height and width dimensions to turn act into a numpy array, before
125-
saving it.
123+
"""Visualize samples from a GAN.
124+
125+
Saves obs and next obs as png files, and takes mean over height and width dimensions
126+
to turn act into a numpy array, before saving it.
126127
"""
127128
for i, transition in enumerate(samples):
128129
num_acts = transition.shape[0] - 6
@@ -209,6 +210,25 @@ def forward(self, transition_tensor: th.Tensor) -> th.Tensor:
209210
return self.rew_net(state=obs, action=act, next_state=next_obs, done=dones)
210211

211212

213+
class RewardGeneratorCombo(nn.Module):
214+
"""Composition of a generative model and a RewardNet.
215+
216+
Assumes that the RewardNet normalizes observations to [0,1].
217+
"""
218+
219+
def __init__(self, reward_net: RewardNet, generator: nn.Module):
220+
super().__init__()
221+
self.reward_net = reward_net
222+
self.generator = generator
223+
224+
def forward(self, latent_tens: th.Tensor):
225+
latent_vec = th.mean(latent_tens, dim=[2, 3])
226+
transition_tensor = self.generator(latent_vec)
227+
obs, action_vec, next_obs = tensor_to_transition(transition_tensor)
228+
done = th.zeros(action_vec.shape)
229+
return self.reward_net.forward(obs, action_vec, next_obs, done)
230+
231+
212232
def save_loss_plots(losses, save_dir):
213233
"""Save plots of generator/adversary losses over training."""
214234
fig, _ = vegans.utils.plot_losses(losses, show=False)

Diff for: src/reward_preprocessing/generative_modelling/gen_models.py

+3
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class DCGanFourTo64Generator(nn.Module):
6969

7070
def __init__(self, latent_shape, data_shape):
7171
super(DCGanFourTo64Generator, self).__init__()
72+
# Identity op so that lucent can regularize L2 norm of input.
73+
self.latent_vec = nn.Identity()
7274
self.project = nn.Linear(latent_shape[0], 1024 * 4 * 4)
7375
self.conv_body = nn.Sequential(
7476
nn.BatchNorm2d(1024),
@@ -90,6 +92,7 @@ def __init__(self, latent_shape, data_shape):
9092

9193
def forward(self, x):
9294
batch_size = x.shape[0]
95+
x = self.latent_vec(x)
9396
x = self.project(x)
9497
x = th.reshape(x, (batch_size, 1024, 4, 4))
9598
x = nn.functional.leaky_relu(x, negative_slope=0.1)

Diff for: src/reward_preprocessing/interpret.py

+133-42
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os.path as osp
2-
from typing import Optional
2+
from typing import Optional, Tuple, Union
33

44
from PIL import Image
55
from imitation.scripts.common import common as common_config
@@ -14,6 +14,7 @@
1414
import wandb
1515

1616
from reward_preprocessing.common.utils import (
17+
RewardGeneratorCombo,
1718
TensorTransitionWrapper,
1819
rollouts_to_dataloader,
1920
tensor_to_transition,
@@ -40,7 +41,9 @@ def interpret(
4041
vis_type: str,
4142
layer_name: str,
4243
num_features: Optional[int],
43-
gan_path: Optional[str],
44+
gan_path: Optional[str] = None,
45+
l2_coeff: Optional[float] = None,
46+
img_save_path: Optional[str] = None,
4447
):
4548
"""Run visualization for interpretability.
4649
@@ -74,6 +77,12 @@ def interpret(
7477
Path to the GAN model. This is used to regularize the output of the
7578
visualization. If None simply visualize reward net without the use
7679
of a GAN in the pipeline.
80+
l2_coeff:
81+
Strength with which to penalize the L2 norm of generated latent vector
82+
"visualizations" of a GAN-reward model combination. If gan_path is not None,
83+
this must also not be None.
84+
img_save_path:
85+
Directory to save images in. Must end in a /. If None, do not save images.
7786
"""
7887
if limit_num_obs <= 0:
7988
raise ValueError(
@@ -82,6 +91,15 @@ def interpret(
8291
f"I don't think we actually ever want to use all so this is currently not "
8392
f"implemented."
8493
)
94+
if vis_type not in ["dataset", "traditional"]:
95+
raise ValueError(f"Unknown vis_type: {vis_type}")
96+
if vis_type == "dataset" and gan_path is not None:
97+
raise ValueError("GANs cannot be used with dataset visualization.")
98+
if gan_path is not None and l2_coeff is None:
99+
raise ValueError("When GANs are used, l2_coeff must be set.")
100+
if img_save_path is not None and img_save_path[-1] != "/":
101+
raise ValueError("img_save_path is not a directory, does not end in /")
102+
85103
# Set up imitation-style logging.
86104
custom_logger, log_dir = common_config.setup_logging()
87105
wandb_logging = "wandb" in common["log_format_strs"]
@@ -101,7 +119,8 @@ def interpret(
101119
rew_net = TensorTransitionWrapper(rew_net)
102120
else: # Use GAN
103121
# Combine rew net with GAN.
104-
raise NotImplementedError()
122+
gan = th.load(gan_path, map_location=th.device(device))
123+
rew_net = RewardGeneratorCombo(reward_net=rew_net, generator=gan.generator)
105124

106125
rew_net.eval() # Eval for visualization.
107126

@@ -129,17 +148,17 @@ def interpret(
129148
# Ensure loaded data is FloatTensor and not DoubleTensor.
130149
inputs = inputs.float()
131150
else: # When using GAN.
132-
# Inputs should be some samples of input vectors? Not sure if this is the best
133-
# way to do this, there might be better options.
134-
# The important part is that lucent expects 4D tensors as inputs, so increase
135-
# dimensionality accordingly.
136-
raise NotImplementedError()
151+
# Inputs are GAN samples
152+
samples = gan.sample(limit_num_obs)
153+
inputs = samples[:, :, None, None]
154+
inputs = inputs.to(device)
155+
inputs = inputs.float()
137156

138157
# The model to analyse should be a torch module that takes a single input, which
139158
# should be a torch Tensor.
140159
# In our case this is one of the following:
141160
# - A reward net that has been wrapped, so it accepts transition tensors.
142-
# - A combo of GAN and reward net that accepts latent inputs vectors. (TODO)
161+
# - A combo of GAN and reward net that accepts latent inputs vectors.
143162
model_to_analyse = rew_net
144163
nmf = LayerNMF(
145164
model=model_to_analyse,
@@ -157,43 +176,73 @@ def interpret(
157176
num_features = nmf.channel_dirs.shape[0]
158177
rows, columns = 1, num_features
159178
if pyplot:
160-
fig = plt.figure(figsize=(columns * 2, rows * 2)) # width, height in inches
179+
col_mult = 4 if vis_type == "traditional" else 2
180+
# figsize is width, height in inches
181+
fig = plt.figure(figsize=(columns * col_mult, rows * 2))
161182
else:
162183
fig = None
163184

164185
# Visualize
165186
if vis_type == "traditional":
166-
# List of transforms
167-
transforms = [
168-
transform.jitter(2), # Jitters input by 2 pixel
169-
]
170-
171-
opt_transitions = nmf.vis_traditional(transforms=transforms)
172-
# This gives as an array that optimizes the objectives, in the shape of the
173-
# input which is a transition tensor. However, lucent helpfully transposes the
174-
# output such that the channel dimension is last. Our functions expect channel
175-
# dim before spatial dims, so we need to transpose it back.
176-
opt_transitions = opt_transitions.transpose(0, 3, 1, 2)
177-
# Split the optimized transitions, one for each feature, into separate
178-
# observations and actions. This function only works with torch tensors.
179-
obs, acts, next_obs = tensor_to_transition(th.tensor(opt_transitions))
180-
# obs and next_obs output have channel dim last.
181-
# acts is output as one-hot vector.
187+
188+
if gan_path is None:
189+
# List of transforms
190+
transforms = [
191+
transform.jitter(2), # Jitters input by 2 pixel
192+
]
193+
194+
opt_transitions = nmf.vis_traditional(transforms=transforms)
195+
# This gives as an array that optimizes the objectives, in the shape of the
196+
# input which is a transition tensor. However, lucent helpfully transposes
197+
# the output such that the channel dimension is last. Our functions expect
198+
# channel dim before spatial dims, so we need to transpose it back.
199+
opt_transitions = opt_transitions.transpose(0, 3, 1, 2)
200+
# Split the optimized transitions, one for each feature, into separate
201+
# observations and actions. This function only works with torch tensors.
202+
obs, acts, next_obs = tensor_to_transition(th.tensor(opt_transitions))
203+
# obs and next_obs output have channel dim last.
204+
# acts is output as one-hot vector.
205+
206+
else:
207+
# We do not require the latent vectors to be transformed before optimizing.
208+
# However, we do regularize the L2 norm of latent vectors, to ensure the
209+
# resulting generated images are realistic.
210+
opt_latent = nmf.vis_traditional(
211+
transforms=[],
212+
l2_coeff=l2_coeff,
213+
l2_layer_name="generator_network_latent_vec",
214+
)
215+
# Now, we put the latent vector thru the generator to produce transition
216+
# tensors that we can get observations, actions, etc out of
217+
opt_latent = np.mean(opt_latent, axis=(1, 2))
218+
opt_latent_th = th.from_numpy(opt_latent).to(th.device(device))
219+
opt_transitions = gan.generator(opt_latent_th)
220+
obs, acts, next_obs = tensor_to_transition(opt_transitions)
182221

183222
# Set of images, one for each feature, add each to plot
184223
for feature_i in range(next_obs.shape[0]):
185-
sub_img = next_obs[feature_i]
224+
sub_img_obs = obs[feature_i].detach().cpu().numpy()
225+
sub_img_next_obs = next_obs[feature_i].detach().cpu().numpy()
186226
plot_img(
187227
columns,
188228
custom_logger,
189229
feature_i,
190230
fig,
191-
sub_img,
231+
(sub_img_obs, sub_img_next_obs),
192232
pyplot,
193233
rows,
194234
vis_scale,
195235
wandb_logging,
196236
)
237+
if img_save_path is not None:
238+
obs_PIL = array_to_image(sub_img_obs, vis_scale)
239+
obs_PIL.save(img_save_path + f"{feature_i}_obs.png")
240+
next_obs_PIL = array_to_image(sub_img_next_obs, vis_scale)
241+
next_obs_PIL.save(img_save_path + f"{feature_i}_next_obs.png")
242+
custom_logger.log(
243+
f"Saved feature {feature_i} viz in dir {img_save_path}."
244+
)
245+
197246
elif vis_type == "dataset":
198247
for feature_i in range(num_features):
199248
custom_logger.log(f"Feature {feature_i}")
@@ -213,51 +262,93 @@ def interpret(
213262
vis_scale,
214263
wandb_logging,
215264
)
216-
else:
217-
raise ValueError(f"Unknown vis_type: {vis_type}.")
218265

219266
if pyplot:
220267
plt.show()
221268
custom_logger.log("Done with dataset visualization.")
222269

223270

271+
def array_to_image(arr: np.ndarray, scale: int) -> Image:
272+
"""Take numpy array on [0,1] scale, return PIL image."""
273+
return Image.fromarray(np.uint8(arr * 255), mode="RGB").resize(
274+
size=(arr.shape[0] * scale, arr.shape[1] * scale),
275+
resample=Image.NEAREST,
276+
)
277+
278+
224279
def plot_img(
225280
columns: int,
226281
custom_logger: HierarchicalLogger,
227282
feature_i: int,
228283
fig: Optional[matplotlib.figure.Figure],
229-
img: np.ndarray,
284+
img: Union[Tuple[np.ndarray, np.ndarray], np.ndarray],
230285
pyplot: bool,
231286
rows: int,
232287
vis_scale: int,
233288
wandb_logging: bool,
234289
):
235-
"""Plot the passed image to pyplot and wandb as appropriate."""
290+
"""Plot the passed image(s) to pyplot and wandb as appropriate."""
236291
_wandb_log(custom_logger, feature_i, img, vis_scale, wandb_logging)
237-
if fig is not None and pyplot:
238-
fig.add_subplot(rows, columns, feature_i + 1)
239-
plt.imshow(img)
292+
if pyplot:
293+
if isinstance(img, tuple):
294+
img_obs = img[0]
295+
img_next_obs = img[1]
296+
fig.add_subplot(rows, columns, 2 * feature_i + 1)
297+
plt.imshow(img_obs)
298+
fig.add_subplot(rows, columns, 2 * feature_i + 2)
299+
plt.imshow(img_next_obs)
300+
else:
301+
fig.add_subplot(rows, columns, feature_i + 1)
302+
plt.imshow(img)
240303

241304

242305
def _wandb_log(
243306
custom_logger: HierarchicalLogger,
244307
feature_i: int,
245-
img: np.ndarray,
308+
img: Union[Tuple[np.ndarray, np.ndarray], np.ndarray],
246309
vis_scale: int,
247310
wandb_logging: bool,
248311
):
249312
"""Plot to wandb if wandb logging is enabled."""
250313
if wandb_logging:
251-
p_img = Image.fromarray(np.uint8(img * 255), mode="RGB").resize(
252-
size=(img.shape[0] * vis_scale, img.shape[1] * vis_scale),
253-
resample=Image.NEAREST,
254-
)
255-
wb_img = wandb.Image(p_img, caption=f"Feature {feature_i}")
256-
custom_logger.record(f"feature_{feature_i}", wb_img)
314+
if isinstance(img, tuple):
315+
img_obs = img[0]
316+
img_next_obs = img[1]
317+
# TODO(df): check if I have to dump between these
318+
_wandb_log_(img_obs, vis_scale, feature_i, "obs", custom_logger)
319+
_wandb_log_(img_next_obs, vis_scale, feature_i, "next_obs", custom_logger)
320+
else:
321+
_wandb_log_(img, vis_scale, feature_i, "dataset_vis", custom_logger)
322+
257323
# Can't re-use steps unfortunately, so each feature img gets its own step.
258324
custom_logger.dump(step=feature_i)
259325

260326

327+
def _wandb_log_(
328+
arr: np.ndarray,
329+
scale: int,
330+
feature: int,
331+
img_type: str,
332+
logger: HierarchicalLogger,
333+
) -> None:
334+
"""Log visualized np.ndarray to wandb using given logger.
335+
336+
Args:
337+
- arr: array to turn into image, save.
338+
- scale: ratio by which to scale up the image.
339+
- feature: which number feature is being visualized.
340+
- img_type: "obs" or "next_obs"
341+
- logger: logger to use.
342+
"""
343+
if img_type not in ["obs", "next_obs"]:
344+
err_str = f"img_type should be 'obs' or 'next_obs', but instead is {img_type}"
345+
raise ValueError(err_str)
346+
347+
pil_img = array_to_image(arr, scale)
348+
wb_img = wandb.Image(pil_img, caption=f"Feature {feature}, {img_type}")
349+
logger.record(f"feature_{feature}_{img_type}", wb_img)
350+
351+
261352
def main():
262353
observer = FileStorageObserver(osp.join("output", "sacred", "interpret"))
263354
interpret_ex.observers.append(observer)
+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch as th
2+
3+
from reward_preprocessing.common import utils
4+
5+
GAN_TIMESTAMP = "20221104_163134"
6+
MODEL_NUMBER = "13720"
7+
8+
if __name__ == "__main__":
9+
gan_path = (
10+
"/nas/ucb/daniel/gan_test_data_"
11+
+ GAN_TIMESTAMP
12+
+ "/models/model_"
13+
+ MODEL_NUMBER
14+
+ ".torch"
15+
)
16+
device = "cuda" if th.cuda.is_available() else "cpu"
17+
gan = th.load(gan_path, map_location=th.device(device))
18+
samples, _ = gan.get_training_results()
19+
utils.visualize_samples(samples.detach().cpu().numpy(), gan.folder)

0 commit comments

Comments
 (0)