Skip to content

Commit 68d8cc7

Browse files
authored
Merge pull request #30 from HumanCompatibleAI/fix_latent_squeeze
Fix latent squeezing
2 parents 6e1ac23 + f6351de commit 68d8cc7

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/reward_preprocessing/interpret.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ def param_f():
290290
)
291291
# Now, we put the latent vector thru the generator to produce transition
292292
# tensors that we can get observations, actions, etc out of
293-
opt_latent = np.squeeze(opt_latent)
293+
opt_latent = np.squeeze(opt_latent, axis=(1, 2))
294+
# ^ squeeze out extraneous "height" and "width" dimensions
294295
opt_latent_th = th.from_numpy(opt_latent).to(th.device(device))
295296
opt_transitions = gan.generator(opt_latent_th)
296297
obs, acts, next_obs = tensor_to_transition(opt_transitions)

0 commit comments

Comments
 (0)