Skip to content

Commit f6351de

Browse files
committed
Squeeze more concisely
1 parent 890152f commit f6351de

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

src/reward_preprocessing/interpret.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,7 @@ def param_f():
290290
)
291291
# Now, we put the latent vector thru the generator to produce transition
292292
# tensors that we can get observations, actions, etc out of
293-
assert opt_latent.shape[1] == 1
294-
assert opt_latent.shape[2] == 1
295-
squeeze_shape = [opt_latent.shape[0], opt_latent.shape[3]]
296-
opt_latent = opt_latent.reshape(squeeze_shape)
293+
opt_latent = np.squeeze(opt_latent, axis=(1, 2))
297294
# ^ squeeze out extraneous "height" and "width" dimensions
298295
opt_latent_th = th.from_numpy(opt_latent).to(th.device(device))
299296
opt_transitions = gan.generator(opt_latent_th)

0 commit comments

Comments
 (0)