Skip to content

Commit 5e68a35

Browse files
authored
Merge pull request #34 from HumanCompatibleAI/autoreset-pref-comp
Preference Comparisons with Updated AutoResetWrapper
2 parents d4ae733 + ab7de7d commit 5e68a35

File tree

4 files changed

+120
-21
lines changed

4 files changed

+120
-21
lines changed

requirements.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ numpy==1.21.2
88
gym==0.21
99
tqdm==4.62.2
1010
wandb==0.12.1
11-
seals==0.1.5
11+
# seals==0.1.5
12+
# Temporarily use this commit in seals for updates to AutoResetWrapper. Switch back to official version once seals releases a new version.
13+
git+https://github.com/HumanCompatibleAI/seals.git@de298732cda150b18af699e6816fcf42f2fc674f
1214
torch-lucent==0.1.8
1315
jupyter==1.0.0
1416
git+https://github.com/ejnnr/mazelab.git@3042551

src/reward_preprocessing/procgen.py

+48-20
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,32 @@ def supported_procgen_env(gym_spec: gym.envs.registration.EnvSpec) -> bool:
1414

1515

1616
def make_auto_reset_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
17-
env = AutoResetWrapper(gym.make(procgen_env_id, **make_env_kwargs))
17+
"""Make procgen with auto reset. Final observation is not fixed.
18+
19+
That means the final observation will be a duplicate of the second to last."""
20+
env = AutoResetWrapper(
21+
gym.make(procgen_env_id, **make_env_kwargs), discard_terminal_observation=False
22+
)
23+
return env
24+
25+
26+
def make_fin_obs_auto_reset_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
27+
"""Make procgen with auto reset and fixed final observation."""
28+
# The order of the wrappers matters here. Final obs wrapper must be applied first,
29+
# then auto reset wrapper. This is because the final obs wrapper depends on the
30+
# done signal, in order to fix the final observation of an episode. The auto reset
31+
# wrapper will reset the done signal to False for the original episode end.
32+
env = AutoResetWrapper(
33+
ProcgenFinalObsWrapper(
34+
gym.make(procgen_env_id, **make_env_kwargs),
35+
),
36+
discard_terminal_observation=False,
37+
)
1838
return env
1939

2040

2141
def make_fin_obs_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
42+
"""Make procgen with fixed final observation."""
2243
env = ProcgenFinalObsWrapper(gym.make(procgen_env_id, **make_env_kwargs))
2344
return env
2445

@@ -37,29 +58,36 @@ def local_name_fin_obs(gym_spec: gym.envs.registration.EnvSpec) -> str:
3758
return "-".join(split_str + [version])
3859

3960

61+
def local_name_fin_obs_autoreset(gym_spec: gym.envs.registration.EnvSpec) -> str:
62+
split_str = gym_spec.id.split("-")
63+
version = split_str[-1]
64+
split_str[-1] = "final-obs-autoreset"
65+
return "-".join(split_str + [version])
66+
67+
4068
def register_procgen_envs(
4169
gym_procgen_env_specs: Iterable[gym.envs.registration.EnvSpec],
4270
) -> None:
4371

44-
for gym_spec in gym_procgen_env_specs:
45-
gym.register(
46-
id=local_name_autoreset(gym_spec),
47-
entry_point="reward_preprocessing.procgen:make_auto_reset_procgen",
48-
max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
49-
kwargs=dict(procgen_env_id=gym_spec.id),
50-
)
51-
52-
# There are no envs that have both autoreset and final obs wrappers.
53-
# fin-obs would only affect the terminal_observation in the info dict, if it were
54-
# to be wrapped by an AutoResetWrapper. Since, at the moment, we don't use the
55-
# terminal_observation in the info dict, there is no point to combining them.
56-
for gym_spec in gym_procgen_env_specs:
57-
gym.register(
58-
id=local_name_fin_obs(gym_spec),
59-
entry_point="reward_preprocessing.procgen:make_fin_obs_procgen",
60-
max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
61-
kwargs=dict(procgen_env_id=gym_spec.id),
62-
)
72+
to_register = [
73+
# Auto reset with original final observation behavior.
74+
(local_name_autoreset, "reward_preprocessing.procgen:make_auto_reset_procgen"),
75+
# Variable-length procgen with fixed final observation.
76+
(local_name_fin_obs, "reward_preprocessing.procgen:make_fin_obs_procgen"),
77+
# Fixed-length procgen with fixed final observation.
78+
(
79+
local_name_fin_obs_autoreset,
80+
"reward_preprocessing.procgen:make_fin_obs_auto_reset_procgen",
81+
),
82+
]
83+
for (local_name_fn, entry_point) in to_register:
84+
for gym_spec in gym_procgen_env_specs:
85+
gym.envs.registration.register(
86+
id=local_name_fn(gym_spec),
87+
entry_point=entry_point,
88+
max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
89+
kwargs=dict(procgen_env_id=gym_spec.id),
90+
)
6391

6492

6593
class ProcgenFinalObsWrapper(gym.Wrapper):

src/reward_preprocessing/scripts/train_pref_comparison.py

+5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ def coinrun():
2222
)
2323
common = dict(
2424
env_name="procgen:procgen-coinrun-autoreset-v0",
25+
# Limit the length of episodes. When using autoreset this is necessary since
26+
# episodes never end.
27+
# This should probably be set to 1000 at maximum, since there is the hard-coded
28+
# timeout after 1000 steps in all procgen environments.
29+
max_episode_steps=1000,
2530
num_vec=256, # Goal Misgeneralization paper uses 64 envs for each of 4 workers.
2631
env_make_kwargs=dict(num_levels=100_000, distribution_mode="hard"),
2732
)

tests/test_autoreset_wrapper.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Tests for AutoResetWrapper.
2+
3+
Taken from seals/tests/test_wrappers.py"""
4+
from seals import util
5+
from seals.testing import envs
6+
7+
8+
def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2):
9+
"""This test also exists in seals. The advantage of also having it here is that
10+
if we decide to update our version of seals this test will show us whether there
11+
were any changes in the parts of seals that we care about.
12+
13+
Check that AutoResetWrapper returns correct values from step and reset.
14+
The AutoResetWrapper pads the trajectory with an extra transition containing the
15+
terminal observations.
16+
Also check that calls to .reset() do not interfere with automatic resets.
17+
Due to the padding the number of steps counted inside the environment and the number
18+
of steps performed outside the environment, i.e. the number of actions performed,
19+
will differ. This test checks that this difference is consistent.
20+
"""
21+
env = util.AutoResetWrapper(
22+
envs.CountingEnv(episode_length=episode_length),
23+
discard_terminal_observation=False,
24+
)
25+
26+
for _ in range(n_manual_reset):
27+
obs = env.reset()
28+
assert obs == 0
29+
30+
# We count the number of episodes, so we can sanity check the padding.
31+
num_episodes = 0
32+
next_episode_end = episode_length
33+
for t in range(1, n_steps + 1):
34+
act = env.action_space.sample()
35+
obs, rew, done, info = env.step(act)
36+
37+
# AutoResetWrapper overrides all done signals.
38+
assert done is False
39+
40+
if t == next_episode_end:
41+
# Unlike the AutoResetWrapper that discards terminal observations,
42+
# here the final observation is returned directly, and is not stored
43+
# in the info dict.
44+
# Due to padding, for every episode the final observation is offset from
45+
# the outer step by one.
46+
assert obs == (t - num_episodes) / (num_episodes + 1)
47+
assert rew == episode_length * 10
48+
if t == next_episode_end + 1:
49+
num_episodes += 1
50+
# Because the final step returned the final observation, the initial
51+
# obs of the next episode is returned in this additional step.
52+
assert obs == 0
53+
# Consequently, the next episode end is one step later, so it is
54+
# episode_length steps from now.
55+
next_episode_end = t + episode_length
56+
57+
# Reward of the 'reset transition' is fixed to be 0.
58+
assert rew == 0
59+
60+
# Sanity check padding. Padding should be 1 for each past episode.
61+
assert (
62+
next_episode_end
63+
== (num_episodes + 1) * episode_length + num_episodes
64+
)

0 commit comments

Comments
 (0)