Skip to content

Commit 22c2faa

Browse files
committed
Add helper script
1 parent ed1eb92 commit 22c2faa

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ python -m reward_preprocessing.interpret print_config
9494
- `policies`: RL policies for training experts with train_rl.
9595
- `preprocessing` Reward preprocessing / reward shaping code.
9696
- `scripts`: All scripts that are not the main scripts of the projects. Helpers and scripts that produce artifacts that are used by the main script. Everything here should either be an executable file or a config for one.
97-
- `helpers`: Helper scripts that are bash executables.
97+
- `helpers`: Helper scripts that are bash executables or python scripts that are not full sacred experiments.
9898
- `trainers`: Our additions to the suite of reward learning algorithms available in imitation. Currently this contains the trainer for training reward nets with supervised learning.
9999
- `vis`: Visualization code for interpreting reward functions.
100100
- `interpret.py`: The main script that provides the functionality for this project.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Fix saved trajectory format from that one time that I saved them wrong."""
2+
import numpy as np
3+
4+
path = "/home/pavel/out/interpret/expert-rollouts/procgen-gm/005/fixed-coin_1000.2k.npz"
5+
data = np.load(path, allow_pickle=True)
6+
7+
# Observations need to be fixed
8+
observations = data["obs"]
9+
10+
indices = data["indices"]
11+
traj_list = []
12+
for i in range(len(indices)):
13+
if i == 0:
14+
start = 0
15+
else:
16+
start = indices[i - 1]
17+
end = indices[i]
18+
# + 1 because we also want to include the last next_obs
19+
obs = observations[start : end + 1]
20+
traj_list.append(obs)
21+
# Also add the last trajectory
22+
traj_list.append(observations[indices[-1] :])
23+
24+
# Concatenate them together, duplicates and all
25+
new_observations = np.concatenate(traj_list, axis=0)
26+
27+
# Sanity check
28+
assert (
29+
np.cumsum([len(traj) - 1 for traj in traj_list[:-1]]) == np.array(indices)
30+
).all()
31+
32+
new_dict = {
33+
"obs": new_observations,
34+
"acts": data["acts"],
35+
"infos": data["infos"],
36+
"terminal": data["terminal"],
37+
"rews": data["rews"],
38+
"indices": data["indices"],
39+
}
40+
41+
# Update path name
42+
split = path.split(".")
43+
split[-2] += "_fixed"
44+
save_path = ".".join(split)
45+
46+
# Save fixed data
47+
with open(save_path, "wb") as f:
48+
np.savez_compressed(f, **new_dict)

0 commit comments

Comments
 (0)