We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6e1ac23 commit d64ff57Copy full SHA for d64ff57
src/reward_preprocessing/interpret.py
@@ -290,7 +290,9 @@ def param_f():
290
)
291
# Now, we put the latent vector thru the generator to produce transition
292
# tensors that we can get observations, actions, etc out of
293
- opt_latent = np.squeeze(opt_latent)
+ squeeze_shape = [opt_latent.shape[0], opt_latent.shape[3]]
294
+ opt_latent = opt_latent.reshape(squeeze_shape)
295
+ # ^ squeeze out extraneous "height" and "width" dimensions
296
opt_latent_th = th.from_numpy(opt_latent).to(th.device(device))
297
opt_transitions = gan.generator(opt_latent_th)
298
obs, acts, next_obs = tensor_to_transition(opt_transitions)
0 commit comments