@@ -14,11 +14,32 @@ def supported_procgen_env(gym_spec: gym.envs.registration.EnvSpec) -> bool:
14
14
15
15
16
16
def make_auto_reset_procgen (procgen_env_id : str , ** make_env_kwargs ) -> gym .Env :
17
- env = AutoResetWrapper (gym .make (procgen_env_id , ** make_env_kwargs ))
17
+ """Make procgen with auto reset. Final observation is not fixed.
18
+
19
+ That means the final observation will be a duplicate of the second to last."""
20
+ env = AutoResetWrapper (
21
+ gym .make (procgen_env_id , ** make_env_kwargs ), discard_terminal_observation = False
22
+ )
23
+ return env
24
+
25
+
26
+ def make_fin_obs_auto_reset_procgen (procgen_env_id : str , ** make_env_kwargs ) -> gym .Env :
27
+ """Make procgen with auto reset and fixed final observation."""
28
+ # The order of the wrappers matters here. Final obs wrapper must be applied first,
29
+ # then auto reset wrapper. This is because the final obs wrapper depends on the
30
+ # done signal, on order to fix the final observation of an episode. The auto reset
31
+ # wrapper will reset the done signal to False for the original episode end.
32
+ env = AutoResetWrapper (
33
+ ProcgenFinalObsWrapper (
34
+ gym .make (procgen_env_id , ** make_env_kwargs ),
35
+ ),
36
+ discard_terminal_observation = False ,
37
+ )
18
38
return env
19
39
20
40
21
41
def make_fin_obs_procgen (procgen_env_id : str , ** make_env_kwargs ) -> gym .Env :
42
+ """Make procgen with fixed final observation."""
22
43
env = ProcgenFinalObsWrapper (gym .make (procgen_env_id , ** make_env_kwargs ))
23
44
return env
24
45
@@ -37,29 +58,36 @@ def local_name_fin_obs(gym_spec: gym.envs.registration.EnvSpec) -> str:
37
58
return "-" .join (split_str + [version ])
38
59
39
60
61
+ def local_name_fin_obs_autoreset (gym_spec : gym .envs .registration .EnvSpec ) -> str :
62
+ split_str = gym_spec .id .split ("-" )
63
+ version = split_str [- 1 ]
64
+ split_str [- 1 ] = "final-obs-autoreset"
65
+ return "-" .join (split_str + [version ])
66
+
67
+
40
68
def register_procgen_envs (
41
69
gym_procgen_env_specs : Iterable [gym .envs .registration .EnvSpec ],
42
70
) -> None :
43
71
44
- for gym_spec in gym_procgen_env_specs :
45
- gym . register (
46
- id = local_name_autoreset ( gym_spec ),
47
- entry_point = "reward_preprocessing.procgen:make_auto_reset_procgen" ,
48
- max_episode_steps = get_gym_max_episode_steps ( gym_spec . id ),
49
- kwargs = dict ( procgen_env_id = gym_spec . id ),
50
- )
51
-
52
- # There are no envs that have both autoreset and final obs wrappers.
53
- # fin-obs would only affect the terminal_observation in the info dict, if it were
54
- # to be wrapped by an AutoResetWrapper. Since, at the moment, we don't use the
55
- # terminal_observation in the info dict, there is no point to combining them.
56
- for gym_spec in gym_procgen_env_specs :
57
- gym .register (
58
- id = local_name_fin_obs (gym_spec ),
59
- entry_point = "reward_preprocessing.procgen:make_fin_obs_procgen" ,
60
- max_episode_steps = get_gym_max_episode_steps (gym_spec .id ),
61
- kwargs = dict (procgen_env_id = gym_spec .id ),
62
- )
72
+ to_register = [
73
+ # Auto reset with original final observation behavior.
74
+ ( local_name_autoreset , "reward_preprocessing.procgen:make_auto_reset_procgen" ),
75
+ # Variable-length procgen with fixed final observation.
76
+ ( local_name_fin_obs , "reward_preprocessing.procgen:make_fin_obs_procgen" ),
77
+ # Fixed-length procgen with fixed final observation.
78
+ (
79
+ local_name_fin_obs_autoreset ,
80
+ "reward_preprocessing.procgen:make_fin_obs_auto_reset_procgen" ,
81
+ ),
82
+ ]
83
+ for ( local_name_fn , entry_point ) in to_register :
84
+ for gym_spec in gym_procgen_env_specs :
85
+ gym . envs . registration .register (
86
+ id = local_name_fn (gym_spec ),
87
+ entry_point = entry_point ,
88
+ max_episode_steps = get_gym_max_episode_steps (gym_spec .id ),
89
+ kwargs = dict (procgen_env_id = gym_spec .id ),
90
+ )
63
91
64
92
65
93
class ProcgenFinalObsWrapper (gym .Wrapper ):
0 commit comments