Skip to content

Commit 8781b1b

Browse files
authored
Merge pull request #27 from HumanCompatibleAI/dataset_viz_fixes
Dataset viz fixes
2 parents 8cf91c3 + 36c7ee4 commit 8781b1b

File tree

3 files changed

+48
-18
lines changed

3 files changed

+48
-18
lines changed

src/reward_preprocessing/common/utils.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -129,30 +129,26 @@ def visualize_samples(samples: np.ndarray, save_dir):
129129
to turn act into a numpy array, before saving it.
130130
"""
131131
for i, transition in enumerate(samples):
132-
num_acts = transition.shape[0] - 6
133-
s = transition[0:3, :, :]
132+
s, act, s_ = ndarray_to_transition(transition)
134133
s = process_image_array(s)
135-
act = transition[3 : 3 + num_acts, :, :]
136-
s_ = transition[3 + num_acts : transition.shape[0], :, :]
137134
s_ = process_image_array(s_)
138-
act_slim_mean = np.mean(act, axis=(1, 2))
139-
act_slim_max = np.max(np.abs(act), axis=(1, 2))
140135
s_img = PIL.Image.fromarray(s)
141136
s__img = PIL.Image.fromarray(s_)
142137
(Path(save_dir) / str(i)).mkdir()
143138
s_img.save(Path(save_dir) / str(i) / "first_obs.png")
144139
s__img.save(Path(save_dir) / str(i) / "second_obs.png")
145-
np.save(Path(save_dir) / str(i) / "act_vec_mean.npy", act_slim_mean)
146-
np.save(Path(save_dir) / str(i) / "act_vec_max.npy", act_slim_max)
140+
np.save(Path(save_dir) / str(i) / "act.npy", act)
147141

148142

149143
def process_image_array(img: np.ndarray) -> np.ndarray:
150-
"""Process a numpy array for feeding into PIL.Image.fromarray."""
144+
"""Process a numpy array for feeding into PIL.Image.fromarray.
145+
146+
Should already be in (h,w,c) format.
147+
"""
151148
up_multiplied = img * 255
152149
clipped = np.clip(up_multiplied, 0, 255)
153150
cast = clipped.astype(np.uint8)
154-
transposed = np.transpose(cast, axes=(1, 2, 0))
155-
return transposed
151+
return cast
156152

157153

158154
def tensor_to_transition(
@@ -178,6 +174,22 @@ def tensor_to_transition(
178174
return obs_proc, act_proc, next_obs_proc
179175

180176

177+
def ndarray_to_transition(
178+
np_trans: np.ndarray,
179+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
180+
"""Turn a numpy transition tensor into three bona fide transitions."""
181+
if len(np_trans.shape) != 3:
182+
raise ValueError("ndarray_to_transition assumes input has shape of length 3")
183+
boosted_np_trans = np_trans[None, :, :, :]
184+
th_trans = th.from_numpy(boosted_np_trans)
185+
th_obs, th_act, th_next_obs = tensor_to_transition(th_trans)
186+
np_obs, np_act, np_next_obs = map(
187+
lambda th_result: th_result[0].detach().cpu().numpy(),
188+
(th_obs, th_act, th_next_obs),
189+
)
190+
return np_obs, np_act, np_next_obs
191+
192+
181193
def process_image_tensor(obs: th.Tensor) -> th.Tensor:
182194
"""Take a GAN image and processes it for use in a reward net."""
183195
clipped_obs = th.clamp(obs, 0, 1)

src/reward_preprocessing/interpret.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
TensorTransitionWrapper,
1919
array_to_image,
2020
log_img_wandb,
21+
ndarray_to_transition,
2122
rollouts_to_dataloader,
2223
tensor_to_transition,
2324
)
@@ -361,22 +362,36 @@ def interpret(
361362
for feature_i in range(num_features):
362363
custom_logger.log(f"Feature {feature_i}")
363364

364-
img, indices = nmf.vis_dataset_thumbnail(
365+
dataset_thumbnails, indices = nmf.vis_dataset_thumbnail(
365366
feature=feature_i, num_mult=4, expand_mult=1
366367
)
367368

369+
# remove opacity channel from dataset thumbnails
370+
np_trans_tens = dataset_thumbnails[:-1, :, :]
371+
372+
obs, _, next_obs = ndarray_to_transition(np_trans_tens)
373+
368374
_log_single_transition_wandb(
369-
custom_logger, feature_i, img, vis_scale, wandb_logging
375+
custom_logger, feature_i, (obs, next_obs), vis_scale, wandb_logging
370376
)
371377
_plot_img(
372378
columns,
373379
feature_i,
374380
num_features,
375381
fig,
376-
img,
382+
(obs, next_obs),
377383
rows,
378384
)
379385

386+
if img_save_path is not None:
387+
obs_PIL = array_to_image(obs, vis_scale)
388+
obs_PIL.save(img_save_path + f"{feature_i}_obs.png")
389+
next_obs_PIL = array_to_image(next_obs, vis_scale)
390+
next_obs_PIL.save(img_save_path + f"{feature_i}_next_obs.png")
391+
custom_logger.log(
392+
f"Saved feature {feature_i} viz in dir {img_save_path}."
393+
)
394+
380395
if pyplot:
381396
plt.show()
382397
custom_logger.log("Done with visualization.")

src/reward_preprocessing/vis/reward_vis.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,18 +340,21 @@ def pad_obses(self, *, expand_mult=1):
340340
% 2
341341
) # Checkered pattern.
342342
self.padded_obses = self.padded_obses * 0.25 + 0.75 # Adjust color.
343-
self.padded_obses = self.padded_obses.astype(self.model_inputs_full.dtype)
343+
self.padded_obses = self.padded_obses.astype(
344+
self.model_inputs_full.detach().cpu().numpy().dtype
345+
)
344346
# Add dims for batch and channel.
345347
self.padded_obses = self.padded_obses[None, None, ...]
346348
# Repeat for correct number of images.
347349
self.padded_obses = self.padded_obses.repeat(
348350
self.model_inputs_full.shape[0], axis=0
349351
)
350352
# Repeat channel dimension.
351-
self.padded_obses = self.padded_obses.repeat(3, axis=1)
353+
num_channels = self.model_inputs_full.shape[1]
354+
self.padded_obses = self.padded_obses.repeat(num_channels, axis=1)
352355
self.padded_obses[
353356
:, :, self.pad_h : -self.pad_h, self.pad_w : -self.pad_w
354-
] = self.model_inputs_full
357+
] = (self.model_inputs_full.detach().cpu().numpy())
355358

356359
def get_patch(self, obs_index, pos_h, pos_w, *, expand_mult=1):
357360
left_h = self.pad_h + (pos_h - 0.5 * expand_mult) * self.patch_h
@@ -468,7 +471,7 @@ def vis_dataset_thumbnail(
468471
acts_single = acts_feature[
469472
range(acts_feature.shape[0]), pos_indices[0], pos_indices[1]
470473
]
471-
# Sort the activations in descending order and take the num_mult**2 strongest.
474+
# Sort the activations in descending order and take the num_mult**2 strongest
472475
# activations.
473476
obs_indices = np.argsort(-acts_single, axis=0)[: num_mult**2]
474477
# Coordinates of the strongest activation in each observation.

0 commit comments

Comments
 (0)