Skip to content

Commit 8f9d370

Browse files
committed
Fix dots and dists gen, clarify code, add helper file
1 parent 1d30a45 commit 8f9d370

File tree

4 files changed

+60
-3
lines changed

4 files changed

+60
-3
lines changed

src/reward_preprocessing/common/utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -328,17 +328,26 @@ def flatten_trajectories_with_rew_double_info(
328328
parts = {key: [] for key in keys}
329329
long_trajs = filter(lambda traj: len(traj.acts) > 2, trajectories)
330330
for traj in long_trajs:
331+
# discard first and last action
331332
parts["acts"].append(traj.acts[1:-1])
333+
# discard first observation (with first action), as well as the second-last
334+
# and last observations (which go with the last action)
332335
parts["obs"].append(traj.obs[1:-2])
336+
# discard first observation (which can't be a next_obs), second observation
337+
# (which goes with the first action), and last observation (which goes with
338+
# the last action)
333339
parts["next_obs"].append(traj.obs[2:-1])
340+
# one done flag per kept action (len(traj.acts) - 2 of them), all False
334341
dones = np.zeros(len(traj.acts) - 2, dtype=bool)
335342
parts["dones"].append(dones)
343+
# rewards are aligned with actions, so drop the first and last reward as well
336344
parts["rews"].append(traj.rews[1:-1])
337345

338346
if traj.infos is None:
339347
infos = np.array([{}] * (len(traj) - 1))
340348
next_infos = np.array([{}] * (len(traj) - 1))
341349
else:
350+
# index 0 of traj.infos is associated with index 1 of traj.obs
342351
infos = traj.infos[:-2]
343352
next_infos = traj.infos[1:-1]
344353

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
"""Binarize the rewards of a saved trajectory archive.

Loads the ``.npz`` archive at ``DATA_PATH``, replaces its ``rews`` array with a
two-valued version (10.0 for the lowest rewards, 0.0 for everything else) and
writes the result to ``SAVE_PATH``. Every other array is copied unchanged.
"""
import os

import numpy as np

DATA_PATH = (
    "/nas/ucb/daniel/nas_reward_function_interpretability/"
    + "dots-and-dists-64-1e6-2023-11.npz"
)
SAVE_PATH = (
    "/nas/ucb/daniel/nas_reward_function_interpretability/"
    + "dots-and-dists-64-1e6-2023-11-binarized.npz"
)
# Position (as a fraction of the sorted rewards) of the cut-off value; rewards
# strictly below that value become non-zero.
NON_ZERO_FRAC = 0.0093


def binarize_rewards(
    rews: np.ndarray, non_zero_frac: float = NON_ZERO_FRAC
) -> np.ndarray:
    """Map the lowest rewards to 10.0 and all other rewards to 0.0.

    The cut-off is the element at index ``int(non_zero_frac * len(rews))`` of
    the ascending-sorted rewards. Only rewards *strictly* below the cut-off
    are mapped to 10.0, so ties with the cut-off value stay at 0.0.

    Args:
        rews: Rewards to binarize. Assumed to be one-dimensional (the sort and
            the length are taken along the first axis).
        non_zero_frac: Fraction in ``[0, 1)`` locating the cut-off within the
            sorted rewards.

    Returns:
        A ``float32`` array with the same shape as ``rews`` whose entries are
        either 10.0 or 0.0.

    Raises:
        IndexError: If ``non_zero_frac`` places the cut-off index outside the
            array (e.g. ``non_zero_frac >= 1``).
    """
    rews = np.asarray(rews)
    if len(rews) == 0:
        # Nothing to threshold; avoid indexing into an empty sorted array.
        return np.zeros(rews.shape, dtype=np.float32)

    # Vectorized sort and comparison instead of a Python-level sorted()/map().
    threshold = np.sort(rews)[int(non_zero_frac * len(rews))]
    return np.where(rews < threshold, 10.0, 0.0).astype(np.float32)


def main() -> None:
    """Load the source archive, binarize its rewards and save the result."""
    # NOTE(security): allow_pickle=True unpickles object arrays, which can
    # execute arbitrary code. Only run this on archives from a trusted source.
    with np.load(DATA_PATH, allow_pickle=True) as traj_data:
        # Indexing reads each array into memory, so the dict stays valid after
        # the archive is closed.
        new_traj_data = {
            "obs": traj_data["obs"],
            "acts": traj_data["acts"],
            "infos": traj_data["infos"],
            "terminal": traj_data["terminal"],
            "rews": binarize_rewards(traj_data["rews"]),
            "indices": traj_data["indices"],
        }

    # Write to a temporary file first and rename it into place, so a failed or
    # interrupted write never leaves a truncated archive at SAVE_PATH.
    tmp_path = SAVE_PATH + ".tmp"
    with open(tmp_path, "wb") as f:
        np.savez_compressed(f, **new_traj_data)

    os.replace(tmp_path, SAVE_PATH)
    print("Saved binarized trajectory")


if __name__ == "__main__":
    main()

src/reward_preprocessing/scripts/config/train_probe.py

+7
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,10 @@ def sort_distances():
114114
def exp_distances():
115115
attr_func = lambda vec: [20**x for x in vec] # noqa: E731
116116
locals()
117+
118+
119+
# Named config registered on ``train_probe_ex``: selects the builtin ``sum`` as
# the attribute function, with a one-dimensional attribute.
@train_probe_ex.named_config
def sum_distances():
    # Collapse the attribute vector into a single number by summing its entries.
    attr_func = sum
    # NOTE(review): presumably 1 because ``sum`` returns one scalar per input
    # vector -- confirm against the code that consumes ``attr_dim``.
    attr_dim = 1
    # NOTE(review): the return value is discarded; presumably this call only
    # marks the assignments above as used so linters do not flag them, with the
    # config framework reading the local variables itself -- confirm.
    locals()

src/reward_preprocessing/scripts/gen_dots_and_dists.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,22 @@ def generate_simple_trajectories(
3131
obs_list = []
3232
infos_list = []
3333
avg_distances = []
34-
for i in range(num_transitions):
34+
for i in range(num_transitions + 1):
3535
data, avg_distance, distances = generate_transition(
3636
number_pairs, circle_radius, size, colors, weights
3737
)
3838
obs_list.append(data)
3939
infos_list.append({"distances": distances})
4040
avg_distances.append(avg_distance)
4141

42-
# Duplicate last observation, since there is always a final next_obs.
43-
obs_list.append(obs_list[-1].copy())
42+
# Drop the first element of the infos list, since the first observation shouldn't
43+
# come with an info dict (see flatten_trajectories_with_rew_double_info in
44+
# common/utils.py)
45+
infos_list = infos_list[1:]
46+
47+
# Drop the last element of avg_distances, since the last observation is the next_obs
48+
# of the final transition, and rewards are associated with obs, not next_obs
49+
avg_distances = avg_distances[:-1]
4450

4551
condensed = {
4652
"obs": np.array(obs_list).astype(np.uint8),

0 commit comments

Comments
 (0)