Skip to content

Commit 2dffdb1

Browse files
committed
Fix treatment of dataset visualizations
1 parent 215b80a commit 2dffdb1

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

src/reward_preprocessing/interpret.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,14 @@ def interpret(
362362
for feature_i in range(num_features):
363363
custom_logger.log(f"Feature {feature_i}")
364364

365-
np_trans_tens, indices = nmf.vis_dataset_thumbnail(
365+
dataset_thumbnails, indices = nmf.vis_dataset_thumbnail(
366366
feature=feature_i, num_mult=4, expand_mult=1
367367
)
368368

369+
# remove opacity channel from dataset thumbnails
370+
num_channels_total = dataset_thumbnails.shape[0]
371+
np_trans_tens = dataset_thumbnails[0 : num_channels_total - 1, :, :]
372+
369373
obs, _, next_obs = ndarray_to_transition(np_trans_tens)
370374

371375
_log_single_transition_wandb(
@@ -380,6 +384,15 @@ def interpret(
380384
rows,
381385
)
382386

387+
if img_save_path is not None:
388+
obs_PIL = array_to_image(obs, vis_scale)
389+
obs_PIL.save(img_save_path + f"{feature_i}_obs.png")
390+
next_obs_PIL = array_to_image(next_obs, vis_scale)
391+
next_obs_PIL.save(img_save_path + f"{feature_i}_next_obs.png")
392+
custom_logger.log(
393+
f"Saved feature {feature_i} viz in dir {img_save_path}."
394+
)
395+
383396
if pyplot:
384397
plt.show()
385398
custom_logger.log("Done with visualization.")

0 commit comments

Comments
 (0)