Preference Comparisons with Updated AutoResetWrapper #34

Merged: 17 commits merged on Jan 26, 2023
4 changes: 3 additions & 1 deletion requirements.txt
@@ -8,7 +8,9 @@ numpy==1.21.2
gym==0.21
tqdm==4.62.2
wandb==0.12.1
seals==0.1.5
# seals==0.1.5
# Temporarily use this commit in seals for updates to AutoResetWrapper. Switch back to official version once seals releases a new version.
git+https://github.com/HumanCompatibleAI/seals.git@de298732cda150b18af699e6816fcf42f2fc674f
torch-lucent==0.1.8
jupyter==1.0.0
git+https://github.com/ejnnr/mazelab.git@3042551
68 changes: 48 additions & 20 deletions src/reward_preprocessing/procgen.py
@@ -14,11 +14,32 @@ def supported_procgen_env(gym_spec: gym.envs.registration.EnvSpec) -> bool:


def make_auto_reset_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
    env = AutoResetWrapper(gym.make(procgen_env_id, **make_env_kwargs))
    """Make procgen env with auto reset. The final observation is not fixed.

    That means the final observation will be a duplicate of the second-to-last
    one."""
    env = AutoResetWrapper(
        gym.make(procgen_env_id, **make_env_kwargs), discard_terminal_observation=False
    )
    return env
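For reference, a minimal sketch of how an env built this way behaves at an episode boundary (the env id is illustrative; the step semantics mirror the padding behavior exercised in tests/test_autoreset_wrapper.py below):

env = make_auto_reset_procgen("procgen-coinrun-v0")
obs = env.reset()
for _ in range(1000):
    obs, rew, done, info = env.step(env.action_space.sample())
    # The wrapper swallows episode ends, so done is always False here.
    assert done is False
# At an episode boundary the wrapper first returns the last observation with
# the terminal reward, then pads with one extra transition that returns the
# new episode's initial observation and a reward of 0.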


def make_fin_obs_auto_reset_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
"""Make procgen with auto reset and fixed final observation."""
# The order of the wrappers matters here. Final obs wrapper must be applied first,
# then auto reset wrapper. This is because the final obs wrapper depends on the
# done signal, on order to fix the final observation of an episode. The auto reset
# wrapper will reset the done signal to False for the original episode end.
env = AutoResetWrapper(
ProcgenFinalObsWrapper(
gym.make(procgen_env_id, **make_env_kwargs),
),
discard_terminal_observation=False,
)
return env
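As a quick structural check (a sketch; the env id is illustrative, and gym.Wrapper exposes the wrapped env via .env), the nesting can be verified directly, with the final-obs fix sitting below the wrapper that clears the done signal:

env = make_fin_obs_auto_reset_procgen("procgen-coinrun-v0")
# Outermost wrapper handles the auto reset; the final observation is fixed
# underneath, before AutoResetWrapper overwrites done.
assert isinstance(env, AutoResetWrapper)
assert isinstance(env.env, ProcgenFinalObsWrapper)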


def make_fin_obs_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
"""Make procgen with fixed final observation."""
env = ProcgenFinalObsWrapper(gym.make(procgen_env_id, **make_env_kwargs))
return env

@@ -37,29 +58,36 @@ def local_name_fin_obs(gym_spec: gym.envs.registration.EnvSpec) -> str:
    return "-".join(split_str + [version])


def local_name_fin_obs_autoreset(gym_spec: gym.envs.registration.EnvSpec) -> str:
    split_str = gym_spec.id.split("-")
    version = split_str[-1]
    split_str[-1] = "final-obs-autoreset"
    return "-".join(split_str + [version])


def register_procgen_envs(
    gym_procgen_env_specs: Iterable[gym.envs.registration.EnvSpec],
) -> None:

    for gym_spec in gym_procgen_env_specs:
        gym.register(
            id=local_name_autoreset(gym_spec),
            entry_point="reward_preprocessing.procgen:make_auto_reset_procgen",
            max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
            kwargs=dict(procgen_env_id=gym_spec.id),
        )

    # There are no envs that have both autoreset and final obs wrappers.
    # fin-obs would only affect the terminal_observation in the info dict if it
    # were wrapped by an AutoResetWrapper. Since, at the moment, we don't use
    # the terminal_observation in the info dict, there is no point in combining
    # them.
    for gym_spec in gym_procgen_env_specs:
        gym.register(
            id=local_name_fin_obs(gym_spec),
            entry_point="reward_preprocessing.procgen:make_fin_obs_procgen",
            max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
            kwargs=dict(procgen_env_id=gym_spec.id),
        )
    to_register = [
        # Auto reset with original final observation behavior.
        (local_name_autoreset, "reward_preprocessing.procgen:make_auto_reset_procgen"),
        # Variable-length procgen with fixed final observation.
        (local_name_fin_obs, "reward_preprocessing.procgen:make_fin_obs_procgen"),
        # Fixed-length procgen with fixed final observation.
        (
            local_name_fin_obs_autoreset,
            "reward_preprocessing.procgen:make_fin_obs_auto_reset_procgen",
        ),
    ]
    for (local_name_fn, entry_point) in to_register:
        for gym_spec in gym_procgen_env_specs:
            gym.envs.registration.register(
                id=local_name_fn(gym_spec),
                entry_point=entry_point,
                max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
                kwargs=dict(procgen_env_id=gym_spec.id),
            )
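Assuming register_procgen_envs has been called on the supported procgen specs, each env then gets three local variants, named by the local_name_* helpers above (coinrun shown as an example; a sketch):

import gym

# Auto reset, original final observation behavior:
env_a = gym.make("procgen-coinrun-autoreset-v0")
# Fixed final observation, variable-length episodes:
env_b = gym.make("procgen-coinrun-final-obs-v0")
# Fixed final observation plus auto reset, i.e. fixed-length episodes:
env_c = gym.make("procgen-coinrun-final-obs-autoreset-v0")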


class ProcgenFinalObsWrapper(gym.Wrapper):
5 changes: 5 additions & 0 deletions src/reward_preprocessing/scripts/train_pref_comparison.py
@@ -22,6 +22,11 @@ def coinrun():
    )
    common = dict(
        env_name="procgen:procgen-coinrun-autoreset-v0",
        # Limit the length of episodes. When using autoreset this is necessary,
        # since episodes otherwise never end.
        # This should probably be set to at most 1000, since all procgen
        # environments have a hard-coded timeout after 1000 steps.
        max_episode_steps=1000,
        num_vec=256,  # Goal Misgeneralization paper uses 64 envs for each of 4 workers.
        env_make_kwargs=dict(num_levels=100_000, distribution_mode="hard"),
    )
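Since the autoreset env never emits done on its own, the episode cut-off must come from a time limit; below is a sketch of the intended effect using gym's TimeLimit wrapper (exactly how max_episode_steps is plumbed through downstream is an assumption here):

import gym
from gym.wrappers import TimeLimit

env = TimeLimit(
    gym.make("procgen:procgen-coinrun-autoreset-v0"), max_episode_steps=1000
)
obs = env.reset()
done = False
for _ in range(1000):
    obs, rew, done, info = env.step(env.action_space.sample())
# The wrapped env never sets done itself, so only the time limit ends the
# (now fixed-length) episode.
assert done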
64 changes: 64 additions & 0 deletions tests/test_autoreset_wrapper.py
@@ -0,0 +1,64 @@
"""Tests for AutoResetWrapper.

Taken from seals/tests/test_wrappers.py"""
from seals import util
from seals.testing import envs


def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2):
    """This test also exists in seals. The advantage of also having it here is
    that if we update our pinned version of seals, this test will show us
    whether the parts of seals that we care about have changed.

    Check that an AutoResetWrapper that pads the trajectory with an extra
    transition containing the terminal observation returns correct values from
    step and reset. Also check that calls to .reset() do not interfere with
    automatic resets. Due to the padding, the number of steps counted inside
    the environment and the number of steps performed outside the environment,
    i.e. the number of actions performed, will differ. This test checks that
    this difference is consistent.
    """
    env = util.AutoResetWrapper(
        envs.CountingEnv(episode_length=episode_length),
        discard_terminal_observation=False,
    )

    for _ in range(n_manual_reset):
        obs = env.reset()
        assert obs == 0

    # We count the number of episodes, so we can sanity check the padding.
    num_episodes = 0
    next_episode_end = episode_length
    for t in range(1, n_steps + 1):
        act = env.action_space.sample()
        obs, rew, done, info = env.step(act)

        # AutoResetWrapper overrides all done signals.
        assert done is False

        if t == next_episode_end:
            # Unlike the AutoResetWrapper that discards terminal observations,
            # here the final observation is returned directly, and is not
            # stored in the info dict.
            # Due to padding, for every episode the final observation is offset
            # from the outer step by one.
            assert obs == (t - num_episodes) / (num_episodes + 1)
            assert rew == episode_length * 10
        if t == next_episode_end + 1:
            num_episodes += 1
            # Because the final step returned the final observation, the
            # initial obs of the next episode is returned in this additional
            # step.
            assert obs == 0
            # Consequently, the next episode end is one step later, so it is
            # episode_length steps from now.
            next_episode_end = t + episode_length

            # Reward of the 'reset transition' is fixed to be 0.
            assert rew == 0

            # Sanity check padding. Padding should be 1 for each past episode.
            assert (
                next_episode_end
                == (num_episodes + 1) * episode_length + num_episodes
            )
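As a concrete trace of the padding arithmetic for episode_length=3 (only values asserted above are shown; outer step t counts actions taken, and the intermediate obs values follow CountingEnv's counting behavior):

# t:    1   2   3   4   5   6   7   8  ...
# obs:  1   2   3   0   1   2   3   0  ...
# done: F   F   F   F   F   F   F   F  ...
# Steps t = 3 and t = 7 return the terminal observation with
# rew == episode_length * 10; steps t = 4 and t = 8 are the pad steps,
# returning the new episode's initial observation with rew == 0.
# Each completed episode inserts one pad step, so episode ends drift by one
# outer step per episode, giving
# next_episode_end == (num_episodes + 1) * episode_length + num_episodes.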