Skip to content

Commit 21175de

Browse files
committed
Merge branch 'main' into fix_gan_vis
2 parents b87b225 + 8781b1b commit 21175de

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

src/reward_preprocessing/common/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,10 @@ def visualize_samples(samples: np.ndarray, save_dir):
141141

142142

143143
def process_image_array(img: np.ndarray) -> np.ndarray:
144-
"""Process a numpy array for feeding into PIL.Image.fromarray."""
144+
"""Process a numpy array for feeding into PIL.Image.fromarray.
145+
146+
Should already be in (h,w,c) format.
147+
"""
145148
up_multiplied = img * 255
146149
clipped = np.clip(up_multiplied, 0, 255)
147150
cast = clipped.astype(np.uint8)

src/reward_preprocessing/interpret.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -381,10 +381,13 @@ def param_f():
381381
for feature_i in range(num_features):
382382
custom_logger.log(f"Feature {feature_i}")
383383

384-
np_trans_tens, indices = nmf.vis_dataset_thumbnail(
384+
dataset_thumbnails, indices = nmf.vis_dataset_thumbnail(
385385
feature=feature_i, num_mult=4, expand_mult=1
386386
)
387387

388+
# remove opacity channel from dataset thumbnails
389+
np_trans_tens = dataset_thumbnails[:-1, :, :]
390+
388391
obs, _, next_obs = ndarray_to_transition(np_trans_tens)
389392

390393
_log_single_transition_wandb(
@@ -399,6 +402,15 @@ def param_f():
399402
rows,
400403
)
401404

405+
if img_save_path is not None:
406+
obs_PIL = array_to_image(obs, vis_scale)
407+
obs_PIL.save(img_save_path + f"{feature_i}_obs.png")
408+
next_obs_PIL = array_to_image(next_obs, vis_scale)
409+
next_obs_PIL.save(img_save_path + f"{feature_i}_next_obs.png")
410+
custom_logger.log(
411+
f"Saved feature {feature_i} viz in dir {img_save_path}."
412+
)
413+
402414
if pyplot:
403415
plt.show()
404416
custom_logger.log("Done with visualization.")

0 commit comments

Comments
 (0)