Skip to content

Commit 215b80a

Browse files
committed
Make dataset viz work
1 parent 481bc77 commit 215b80a

File tree

3 files changed

+33
-18
lines changed

3 files changed

+33
-18
lines changed

src/reward_preprocessing/common/utils.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,30 +129,23 @@ def visualize_samples(samples: np.ndarray, save_dir):
129129
to turn act into a numpy array, before saving it.
130130
"""
131131
for i, transition in enumerate(samples):
132-
num_acts = transition.shape[0] - 6
133-
s = transition[0:3, :, :]
132+
s, act, s_ = ndarray_to_transition(transition)
134133
s = process_image_array(s)
135-
act = transition[3 : 3 + num_acts, :, :]
136-
s_ = transition[3 + num_acts : transition.shape[0], :, :]
137134
s_ = process_image_array(s_)
138-
act_slim_mean = np.mean(act, axis=(1, 2))
139-
act_slim_max = np.max(np.abs(act), axis=(1, 2))
140135
s_img = PIL.Image.fromarray(s)
141136
s__img = PIL.Image.fromarray(s_)
142137
(Path(save_dir) / str(i)).mkdir()
143138
s_img.save(Path(save_dir) / str(i) / "first_obs.png")
144139
s__img.save(Path(save_dir) / str(i) / "second_obs.png")
145-
np.save(Path(save_dir) / str(i) / "act_vec_mean.npy", act_slim_mean)
146-
np.save(Path(save_dir) / str(i) / "act_vec_max.npy", act_slim_max)
140+
np.save(Path(save_dir) / str(i) / "act.npy", act)
147141

148142

149143
def process_image_array(img: np.ndarray) -> np.ndarray:
150144
"""Process a numpy array for feeding into PIL.Image.fromarray."""
151145
up_multiplied = img * 255
152146
clipped = np.clip(up_multiplied, 0, 255)
153147
cast = clipped.astype(np.uint8)
154-
transposed = np.transpose(cast, axes=(1, 2, 0))
155-
return transposed
148+
return cast
156149

157150

158151
def tensor_to_transition(
@@ -178,6 +171,22 @@ def tensor_to_transition(
178171
return obs_proc, act_proc, next_obs_proc
179172

180173

174+
def ndarray_to_transition(
175+
np_trans: np.ndarray,
176+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
177+
"""Turn a numpy transition tensor into three bona fide transitions."""
178+
if len(np_trans.shape) != 3:
179+
raise ValueError("ndarray_to_transition assumes input has shape of length 3")
180+
boosted_np_trans = np_trans[None, :, :, :]
181+
th_trans = th.from_numpy(boosted_np_trans)
182+
th_obs, th_act, th_next_obs = tensor_to_transition(th_trans)
183+
np_obs, np_act, np_next_obs = map(
184+
lambda th_result: th_result[0].detach().cpu().numpy(),
185+
(th_obs, th_act, th_next_obs),
186+
)
187+
return np_obs, np_act, np_next_obs
188+
189+
181190
def process_image_tensor(obs: th.Tensor) -> th.Tensor:
182191
"""Take a GAN image and processes it for use in a reward net."""
183192
clipped_obs = th.clamp(obs, 0, 1)

src/reward_preprocessing/interpret.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
TensorTransitionWrapper,
1919
array_to_image,
2020
log_img_wandb,
21+
ndarray_to_transition,
2122
rollouts_to_dataloader,
2223
tensor_to_transition,
2324
)
@@ -179,7 +180,7 @@ def interpret(
179180
# Combine rew net with GAN.
180181
gan = th.load(gan_path, map_location=th.device(device))
181182
model_to_analyse = RewardGeneratorCombo(
182-
reward_net=rew_net, generator=gan.generator
183+
rew_net=rew_net, generator=gan.generator
183184
)
184185

185186
model_to_analyse.eval() # Eval for visualization.
@@ -361,19 +362,21 @@ def interpret(
361362
for feature_i in range(num_features):
362363
custom_logger.log(f"Feature {feature_i}")
363364

364-
img, indices = nmf.vis_dataset_thumbnail(
365+
np_trans_tens, indices = nmf.vis_dataset_thumbnail(
365366
feature=feature_i, num_mult=4, expand_mult=1
366367
)
367368

369+
obs, _, next_obs = ndarray_to_transition(np_trans_tens)
370+
368371
_log_single_transition_wandb(
369-
custom_logger, feature_i, img, vis_scale, wandb_logging
372+
custom_logger, feature_i, (obs, next_obs), vis_scale, wandb_logging
370373
)
371374
_plot_img(
372375
columns,
373376
feature_i,
374377
num_features,
375378
fig,
376-
img,
379+
(obs, next_obs),
377380
rows,
378381
)
379382

src/reward_preprocessing/vis/reward_vis.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,18 +340,21 @@ def pad_obses(self, *, expand_mult=1):
340340
% 2
341341
) # Checkered pattern.
342342
self.padded_obses = self.padded_obses * 0.25 + 0.75 # Adjust color.
343-
self.padded_obses = self.padded_obses.astype(self.model_inputs_full.dtype)
343+
self.padded_obses = self.padded_obses.astype(
344+
self.model_inputs_full.detach().cpu().numpy().dtype
345+
)
344346
# Add dims for batch and channel.
345347
self.padded_obses = self.padded_obses[None, None, ...]
346348
# Repeat for correct number of images.
347349
self.padded_obses = self.padded_obses.repeat(
348350
self.model_inputs_full.shape[0], axis=0
349351
)
350352
# Repeat channel dimension.
351-
self.padded_obses = self.padded_obses.repeat(3, axis=1)
353+
num_channels = self.model_inputs_full.shape[1]
354+
self.padded_obses = self.padded_obses.repeat(num_channels, axis=1)
352355
self.padded_obses[
353356
:, :, self.pad_h : -self.pad_h, self.pad_w : -self.pad_w
354-
] = self.model_inputs_full
357+
] = (self.model_inputs_full.detach().cpu().numpy())
355358

356359
def get_patch(self, obs_index, pos_h, pos_w, *, expand_mult=1):
357360
left_h = self.pad_h + (pos_h - 0.5 * expand_mult) * self.patch_h
@@ -468,7 +471,7 @@ def vis_dataset_thumbnail(
468471
acts_single = acts_feature[
469472
range(acts_feature.shape[0]), pos_indices[0], pos_indices[1]
470473
]
471-
# Sort the activations in descending order and take the num_mult**2 strongest.
474+
# Sort the activations in descending order and take the num_mult**2 strongest
472475
# activations.
473476
obs_indices = np.argsort(-acts_single, axis=0)[: num_mult**2]
474477
# Coordinates of the strongest activation in each observation.

0 commit comments

Comments
 (0)