diff --git a/requirements.txt b/requirements.txt
index a0d9395..bf7ac18 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,9 @@ numpy==1.21.2
 gym==0.21
 tqdm==4.62.2
 wandb==0.12.1
-seals==0.1.5
+# seals==0.1.5
+# Temporarily use this commit of seals for updates to AutoResetWrapper. Switch back to the official version once seals releases a new version.
+git+https://github.com/HumanCompatibleAI/seals.git@de298732cda150b18af699e6816fcf42f2fc674f
 torch-lucent==0.1.8
 jupyter==1.0.0
 git+https://github.com/ejnnr/mazelab.git@3042551
diff --git a/src/reward_preprocessing/procgen.py b/src/reward_preprocessing/procgen.py
index 1e4a0f6..731392c 100644
--- a/src/reward_preprocessing/procgen.py
+++ b/src/reward_preprocessing/procgen.py
@@ -14,11 +14,32 @@ def supported_procgen_env(gym_spec: gym.envs.registration.EnvSpec) -> bool:
 
 
 def make_auto_reset_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
-    env = AutoResetWrapper(gym.make(procgen_env_id, **make_env_kwargs))
+    """Make procgen env with auto reset. The final observation is not fixed.
+
+    That means the final observation will be a duplicate of the second-to-last."""
+    env = AutoResetWrapper(
+        gym.make(procgen_env_id, **make_env_kwargs), discard_terminal_observation=False
+    )
+    return env
+
+
+def make_fin_obs_auto_reset_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
+    """Make procgen env with auto reset and fixed final observation."""
+    # The order of the wrappers matters here. The final obs wrapper must be
+    # applied first, then the auto reset wrapper. This is because the final obs
+    # wrapper depends on the done signal in order to fix the final observation
+    # of an episode, and the auto reset wrapper resets done to False at the
+    # original episode end.
+    env = AutoResetWrapper(
+        ProcgenFinalObsWrapper(
+            gym.make(procgen_env_id, **make_env_kwargs),
+        ),
+        discard_terminal_observation=False,
+    )
     return env
 
 
 def make_fin_obs_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
+    """Make procgen env with fixed final observation."""
     env = ProcgenFinalObsWrapper(gym.make(procgen_env_id, **make_env_kwargs))
     return env
@@ -37,29 +58,36 @@ def local_name_fin_obs(gym_spec: gym.envs.registration.EnvSpec) -> str:
     return "-".join(split_str + [version])
 
 
+def local_name_fin_obs_autoreset(gym_spec: gym.envs.registration.EnvSpec) -> str:
+    split_str = gym_spec.id.split("-")
+    version = split_str[-1]
+    split_str[-1] = "final-obs-autoreset"
+    return "-".join(split_str + [version])
+
+
 def register_procgen_envs(
     gym_procgen_env_specs: Iterable[gym.envs.registration.EnvSpec],
 ) -> None:
-    for gym_spec in gym_procgen_env_specs:
-        gym.register(
-            id=local_name_autoreset(gym_spec),
-            entry_point="reward_preprocessing.procgen:make_auto_reset_procgen",
-            max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
-            kwargs=dict(procgen_env_id=gym_spec.id),
-        )
-
-    # There are no envs that have both autoreset and final obs wrappers.
-    # fin-obs would only affect the terminal_observation in the info dict, if it were
-    # to be wrapped by an AutoResetWrapper. Since, at the moment, we don't use the
-    # terminal_observation in the info dict, there is no point to combining them.
-    for gym_spec in gym_procgen_env_specs:
-        gym.register(
-            id=local_name_fin_obs(gym_spec),
-            entry_point="reward_preprocessing.procgen:make_fin_obs_procgen",
-            max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
-            kwargs=dict(procgen_env_id=gym_spec.id),
-        )
+    to_register = [
+        # Auto reset with original final observation behavior.
+        (local_name_autoreset, "reward_preprocessing.procgen:make_auto_reset_procgen"),
+        # Variable-length procgen with fixed final observation.
+        (local_name_fin_obs, "reward_preprocessing.procgen:make_fin_obs_procgen"),
+        # Fixed-length procgen with fixed final observation.
+        (
+            local_name_fin_obs_autoreset,
+            "reward_preprocessing.procgen:make_fin_obs_auto_reset_procgen",
+        ),
+    ]
+    for local_name_fn, entry_point in to_register:
+        for gym_spec in gym_procgen_env_specs:
+            gym.envs.registration.register(
+                id=local_name_fn(gym_spec),
+                entry_point=entry_point,
+                max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
+                kwargs=dict(procgen_env_id=gym_spec.id),
+            )
 
 
 class ProcgenFinalObsWrapper(gym.Wrapper):
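To make the AutoResetWrapper change above concrete: with discard_terminal_observation=False, the wrapper pads each episode with an extra transition instead of dropping the terminal observation. A minimal sketch of the behavior at an episode boundary, using seals' CountingEnv test helper rather than a real procgen env (this assumes the pinned seals commit above behaves as exercised by the test added at the end of this diff):

    from seals import util
    from seals.testing import envs

    env = util.AutoResetWrapper(
        envs.CountingEnv(episode_length=3),
        discard_terminal_observation=False,
    )
    obs = env.reset()  # obs == 0
    for _ in range(3):
        obs, rew, done, info = env.step(env.action_space.sample())
    # At the inner episode end, the terminal observation is returned directly
    # (obs == 3) and done is overridden to False.
    obs, rew, done, info = env.step(env.action_space.sample())
    # The padded 'reset transition': obs == 0 is the initial observation of
    # the next episode, and rew is fixed to 0.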
diff --git a/src/reward_preprocessing/scripts/train_pref_comparison.py b/src/reward_preprocessing/scripts/train_pref_comparison.py
index 90fdd9c..3140b18 100644
--- a/src/reward_preprocessing/scripts/train_pref_comparison.py
+++ b/src/reward_preprocessing/scripts/train_pref_comparison.py
@@ -22,6 +22,11 @@ def coinrun():
     )
     common = dict(
         env_name="procgen:procgen-coinrun-autoreset-v0",
+        # Limit the length of episodes. This is necessary when using autoreset,
+        # since episodes otherwise never end.
+        # It should probably be set to at most 1000, since all procgen
+        # environments have a hard-coded timeout after 1000 steps.
+        max_episode_steps=1000,
         num_vec=256,  # Goal Misgeneralization paper uses 64 envs for each of 4 workers.
         env_make_kwargs=dict(num_levels=100_000, distribution_mode="hard"),
     )
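For reference, a hypothetical usage sketch of one of the newly registered variants. The id below is derived by local_name_fin_obs_autoreset from the procgen-coinrun-v0 spec; the sketch assumes register_procgen_envs has already been called for the supported procgen specs (e.g. at package import time) and the gym 0.21 API:

    import gym

    # Hypothetical: assumes register_procgen_envs() has already run, so the
    # id below exists in gym's registry.
    env = gym.make(
        "procgen-coinrun-final-obs-autoreset-v0",
        num_levels=100_000,
        distribution_mode="hard",
    )
    obs = env.reset()
    # Registration passes max_episode_steps, so gym wraps the env in a
    # TimeLimit, which is what ends the otherwise endless autoreset episodes.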
diff --git a/tests/test_autoreset_wrapper.py b/tests/test_autoreset_wrapper.py
new file mode 100644
index 0000000..74e2ab5
--- /dev/null
+++ b/tests/test_autoreset_wrapper.py
@@ -0,0 +1,64 @@
+"""Tests for AutoResetWrapper.
+
+Taken from seals/tests/test_wrappers.py"""
+from seals import util
+from seals.testing import envs
+
+
+def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2):
+    """This test also exists in seals. The advantage of also having it here
+    is that if we decide to update our pinned version of seals, this test
+    will show us whether the parts of seals we care about have changed.
+
+    Check that AutoResetWrapper returns correct values from step and reset
+    when it pads the trajectory with an extra transition containing the
+    terminal observation.
+    Also check that calls to .reset() do not interfere with automatic resets.
+    Due to the padding, the number of steps counted inside the environment
+    and the number of actions performed outside it will differ. This test
+    checks that this difference is consistent.
+    """
+    env = util.AutoResetWrapper(
+        envs.CountingEnv(episode_length=episode_length),
+        discard_terminal_observation=False,
+    )
+
+    for _ in range(n_manual_reset):
+        obs = env.reset()
+        assert obs == 0
+
+    # We count the number of episodes, so we can sanity-check the padding.
+    num_episodes = 0
+    next_episode_end = episode_length
+    for t in range(1, n_steps + 1):
+        act = env.action_space.sample()
+        obs, rew, done, info = env.step(act)
+
+        # AutoResetWrapper overrides all done signals.
+        assert done is False
+
+        if t == next_episode_end:
+            # Unlike the AutoResetWrapper that discards terminal observations,
+            # here the final observation is returned directly, rather than
+            # stored in the info dict.
+            # Due to the padding, the final observation is offset from the
+            # outer step count by one for every past episode.
+            assert obs == (t - num_episodes) / (num_episodes + 1)
+            assert rew == episode_length * 10
+        if t == next_episode_end + 1:
+            num_episodes += 1
+            # Because the final step returned the final observation, the initial
+            # obs of the next episode is returned in this additional step.
+            assert obs == 0
+            # Consequently, the next episode end is one step later, so it is
+            # episode_length steps from now.
+            next_episode_end = t + episode_length
+
+            # Reward of the 'reset transition' is fixed to be 0.
+            assert rew == 0
+
+        # Sanity-check the padding. It should be 1 for each past episode.
+        assert (
+            next_episode_end
+            == (num_episodes + 1) * episode_length + num_episodes
+        )
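To spell out the padding arithmetic behind the final assert: each completed episode consumes episode_length + 1 outer steps, one of which is the padded reset transition, so with episode_length = 3 the episode ends land at outer steps 3, 7, 11, and so on. A worked check of the invariant:

    # next_episode_end == (num_episodes + 1) * episode_length + num_episodes,
    # for episode_length = 3:
    #   after 0 past episodes the next end is (0 + 1) * 3 + 0 == 3
    #   after 1 past episode  the next end is (1 + 1) * 3 + 1 == 7
    #   after 2 past episodes the next end is (2 + 1) * 3 + 2 == 11
    for num_episodes, end in enumerate([3, 7, 11]):
        assert end == (num_episodes + 1) * 3 + num_episodes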