 from imitation.util.logger import HierarchicalLogger
 from lucent.modelzoo.util import get_model_layers
 from lucent.optvis import transform
+from lucent.optvis.param.spatial import pixel_image
 import matplotlib
 from matplotlib import pyplot as plt
 import numpy as np

@@ -264,14 +265,32 @@ def interpret(
     # We do not require the latent vectors to be transformed before optimizing.
     # However, we do regularize the L2 norm of latent vectors, to ensure the
     # resulting generated images are realistic.
+    z_dim_as_expected = (
+        isinstance(gan.z_dim, tuple)
+        and len(gan.z_dim) == 1
+        and isinstance(gan.z_dim[0], int)
+    )
+    if not z_dim_as_expected:
+        error_string = (
+            "interpret.py expects the GAN's latent input shape to "
+            + f"be a tuple of length 1, instead it is {gan.z_dim}."
+        )
+        raise TypeError(error_string)
+    # ensure visualization doesn't treat the latent vector as an image.
+    latent_shape = (num_features, gan.z_dim[0], 1, 1)
+
+    def param_f():
+        return pixel_image(shape=latent_shape)
+
     opt_latent = nmf.vis_traditional(
         transforms=[],
         l2_coeff=l2_coeff,
         l2_layer_name="generator_network_latent_vec",
+        param_f=param_f,
     )
     # Now, we put the latent vector thru the generator to produce transition
     # tensors that we can get observations, actions, etc out of
-    opt_latent = np.mean(opt_latent, axis=(1, 2))
+    opt_latent = np.squeeze(opt_latent)
     opt_latent_th = th.from_numpy(opt_latent).to(th.device(device))
     opt_transitions = gan.generator(opt_latent_th)
     obs, acts, next_obs = tensor_to_transition(opt_transitions)
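A minimal sketch (not part of this commit) of what the new `pixel_image` parameterization produces, assuming lucent's `pixel_image(shape)` returns a `(params, image_f)` pair where `image_f()` yields the raw tensor being optimized; the `num_features` and `z_dim` values below are placeholders, not values from this repo:

```python
import numpy as np
from lucent.optvis.param.spatial import pixel_image

num_features, z_dim = 4, 128  # hypothetical sizes for illustration
# Parameterize the latent batch directly as raw pixels/values, with trailing
# singleton spatial dims so lucent's machinery treats it like a 1x1 "image".
params, image_f = pixel_image(shape=(num_features, z_dim, 1, 1))
latents = image_f()  # tensor of shape (num_features, z_dim, 1, 1)

# After optimization, the trailing singleton dims are dropped, giving a
# (num_features, z_dim) batch of latent vectors to feed to the generator.
flat = np.squeeze(latents.detach().cpu().numpy())
assert flat.shape == (num_features, z_dim)
```

One caveat on the design: `np.squeeze` drops every singleton axis, so if `num_features` were 1 the batch dimension would also be removed; squeezing only the last two axes would avoid that edge case.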