Skip to content

Commit 68d7275

Browse files
authored
Merge pull request #21 from HumanCompatibleAI/log_max_acts
Log max acts
2 parents 190e0cc + ee594f6 commit 68d7275

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/reward_preprocessing/common/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,15 @@ def visualize_samples(samples: np.ndarray, save_dir):
135135
act = transition[3 : 3 + num_acts, :, :]
136136
s_ = transition[3 + num_acts : transition.shape[0], :, :]
137137
s_ = process_image_array(s_)
138-
act_slim = np.mean(act, axis=(1, 2))
138+
act_slim_mean = np.mean(act, axis=(1, 2))
139+
act_slim_max = np.max(np.abs(act), axis=(1, 2))
139140
s_img = PIL.Image.fromarray(s)
140141
s__img = PIL.Image.fromarray(s_)
141142
(Path(save_dir) / str(i)).mkdir()
142143
s_img.save(Path(save_dir) / str(i) / "first_obs.png")
143144
s__img.save(Path(save_dir) / str(i) / "second_obs.png")
144-
np.save(Path(save_dir) / str(i) / "act_vec.npy", act_slim)
145+
np.save(Path(save_dir) / str(i) / "act_vec_mean.npy", act_slim_mean)
146+
np.save(Path(save_dir) / str(i) / "act_vec_max.npy", act_slim_max)
145147

146148

147149
def process_image_array(img: np.ndarray) -> np.ndarray:

0 commit comments

Comments
 (0)