Skip to content

Commit 6e1ac23

Browse files
authored
Merge pull request #28 from HumanCompatibleAI/fix_gan_vis
Fix gan vis
2 parents 8781b1b + 21175de commit 6e1ac23

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

src/reward_preprocessing/interpret.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from imitation.util.logger import HierarchicalLogger
88
from lucent.modelzoo.util import get_model_layers
99
from lucent.optvis import transform
10+
from lucent.optvis.param.spatial import pixel_image
1011
import matplotlib
1112
from matplotlib import pyplot as plt
1213
import numpy as np
@@ -264,14 +265,32 @@ def interpret(
264265
# We do not require the latent vectors to be transformed before optimizing.
265266
# However, we do regularize the L2 norm of latent vectors, to ensure the
266267
# resulting generated images are realistic.
268+
z_dim_as_expected = (
269+
isinstance(gan.z_dim, tuple)
270+
and len(gan.z_dim) == 1
271+
and isinstance(gan.z_dim[0], int)
272+
)
273+
if not z_dim_as_expected:
274+
error_string = (
275+
"interpret.py expects the GAN's latent input shape to "
276+
+ f"be a tuple of length 1, instead it is {gan.z_dim}."
277+
)
278+
raise TypeError(error_string)
279+
# ensure visualization doesn't treat the latent vector as an image.
280+
latent_shape = (num_features, gan.z_dim[0], 1, 1)
281+
282+
def param_f():
283+
return pixel_image(shape=latent_shape)
284+
267285
opt_latent = nmf.vis_traditional(
268286
transforms=[],
269287
l2_coeff=l2_coeff,
270288
l2_layer_name="generator_network_latent_vec",
289+
param_f=param_f,
271290
)
272291
# Now, we put the latent vector thru the generator to produce transition
273292
# tensors that we can get observations, actions, etc out of
274-
opt_latent = np.mean(opt_latent, axis=(1, 2))
293+
opt_latent = np.squeeze(opt_latent)
275294
opt_latent_th = th.from_numpy(opt_latent).to(th.device(device))
276295
opt_transitions = gan.generator(opt_latent_th)
277296
obs, acts, next_obs = tensor_to_transition(opt_transitions)

src/reward_preprocessing/vis/reward_vis.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Port of lucid.scratch.rl_util to PyTorch. APL2.0 licensed."""
22
from functools import reduce
33
import logging
4-
from typing import Callable, Dict, List, Optional, Union
4+
from typing import Callable, Dict, List, Optional, Tuple, Union
55

66
from lucent.optvis.objectives import handle_batch, wrap_objective
77
import lucent.optvis.param as param
@@ -253,6 +253,9 @@ def vis_traditional(
253253
transforms: List[Callable[[th.Tensor], th.Tensor]] = [transform.jitter(2)],
254254
l2_coeff: float = 0.0,
255255
l2_layer_name: Optional[str] = None,
256+
param_f: Optional[
257+
Callable[[], Tuple[th.Tensor, Callable[[], th.Tensor]]]
258+
] = None,
256259
) -> np.ndarray:
257260
if feature_list is None:
258261
# Feature dim is at index 1
@@ -291,13 +294,15 @@ def vis_traditional(
291294
obj -= l2_objective(l2_layer_name, l2_coeff)
292295
input_shape = tuple(self.model_inputs_preprocess.shape[1:])
293296

294-
def param_f():
295-
return param.image(
296-
channels=input_shape[0],
297-
h=input_shape[1],
298-
w=input_shape[2],
299-
batch=len(feature_list),
300-
)
297+
if param_f is None:
298+
299+
def param_f():
300+
return param.image(
301+
channels=input_shape[0],
302+
h=input_shape[1],
303+
w=input_shape[2],
304+
batch=len(feature_list),
305+
)
301306

302307
logging.info(f"Performing vis_traditional with transforms: {transforms}")
303308

0 commit comments

Comments
 (0)