
Commit ab3d103

Merge pull request #31 from HumanCompatibleAI/imitation-final-obs
Final observation and CnnRewardNet
2 parents 33b7979 + e4caa7f commit ab3d103

7 files changed: +211 -5 lines changed

requirements.txt (+2 -2)

@@ -2,7 +2,7 @@
 --extra-index-url https://download.pytorch.org/whl/cu116
 torch==1.12.1
 torchvision==0.13.1
-stable-baselines3==1.6.0
+stable-baselines3==1.6.1
 sacred==0.8.2
 numpy==1.21.2
 gym==0.21
@@ -12,7 +12,7 @@ seals==0.1.5
 torch-lucent==0.1.8
 jupyter==1.0.0
 git+https://github.com/ejnnr/mazelab.git@3042551
-git+https://github.com/HumanCompatibleAI/imitation.git@91c66b7377
+git+https://github.com/HumanCompatibleAI/imitation.git@40a2a559706e50bf60d7cc388a2c36dd0d4e8619
 # This version includes some fixes that are not in the newest pip version
 git+https://github.com/openai/gym3.git@4c38246
 # This commit on the branch final-obs of my (PavelCz) fork of procgen includes

src/reward_preprocessing/__init__.py (+15)

@@ -0,0 +1,15 @@
+"""Register custom environments"""
+
+import gym
+import procgen  # noqa: F401
+
+import reward_preprocessing.procgen as rmi_procgen  # noqa: I001
+
+# Procgen
+
+# note that procgen was imported to add procgen environments to the gym registry
+
+GYM_PROCGEN_ENV_SPECS = list(
+    filter(rmi_procgen.supported_procgen_env, gym.envs.registry.all())
+)
+rmi_procgen.register_procgen_envs(GYM_PROCGEN_ENV_SPECS)
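For context, a small sketch of what the filter keeps (an editor's illustration, not part of the diff; the second id is hypothetical and only demonstrates the three-part check used by supported_procgen_env):

    import re

    def has_three_parts(env_id: str) -> bool:
        # Mirrors the split used by supported_procgen_env in procgen.py below.
        return len(re.split("-|_", env_id)) == 3

    assert has_three_parts("procgen-coinrun-v0")           # kept by the filter
    assert not has_three_parts("procgen-coinrun_aisc-v0")  # hypothetical id with an extra part; dropped

The resulting GYM_PROCGEN_ENV_SPECS list is then handed to register_procgen_envs, which derives the -autoreset and -final-obs variants shown in procgen.py below.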

src/reward_preprocessing/common/utils.py (+1 -1)

@@ -284,7 +284,7 @@ def log_img_wandb(
 
 
 def array_to_image(arr: np.ndarray, scale: int) -> PIL.Image.Image:
-    """Take numpy array on [0,1] scale, return PIL image."""
+    """Take numpy array on [0,1] scale with shape (h,w,c), return PIL image."""
     return Image.fromarray(np.uint8(arr * 255), mode="RGB").resize(
         # PIL expects tuple of (width, height), numpy's dimension 1 is width, and
         # dimension 0 height.
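The new docstring pins down the expected (h, w, c) layout; a quick sanity check of that convention and of the width/height swap mentioned in the comment (editor's sketch, not part of the diff):

    import numpy as np
    from PIL import Image

    arr = np.random.rand(64, 32, 3)  # (h, w, c): height 64, width 32, RGB on [0, 1]
    img = Image.fromarray(np.uint8(arr * 255), mode="RGB")
    assert img.size == (32, 64)  # PIL reports (width, height), hence the swap in resize()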

src/reward_preprocessing/models.py (+37 -1)

@@ -2,7 +2,7 @@
 from typing import Tuple
 
 import gym
-from imitation.rewards.reward_nets import RewardNet
+from imitation.rewards.reward_nets import CnnRewardNet, RewardNet
 import numpy as np
 import torch as th
 
@@ -11,6 +11,42 @@
 logger = logging.getLogger(__name__)
 
 
+class CnnRewardNetWorkaround(CnnRewardNet):
+    """Identical to CnnRewardNet, except that it fixes imitation issue #644 by
+    removing normalize_input_layer from the kwargs.
+    TODO: Reconsider this once the underlying issue is fixed.
+    """
+
+    def __init__(
+        self,
+        observation_space: gym.Space,
+        action_space: gym.Space,
+        use_state: bool = True,
+        use_action: bool = True,
+        use_next_state: bool = False,
+        use_done: bool = False,
+        hwc_format: bool = True,
+        **kwargs,
+    ):
+        normalize = kwargs.pop("normalize_input_layer", None)
+        if normalize is not None:
+            logger.warning(
+                f"normalize_input_layer={normalize} was provided, will be ignored. See "
+                "imitation issue #644"
+            )
+
+        super().__init__(
+            observation_space,
+            action_space,
+            use_state,
+            use_action,
+            use_next_state,
+            use_done,
+            hwc_format,
+            **kwargs,
+        )
+
+
 class MazeRewardNet(RewardNet):
     def __init__(self, size: int, maze_name: str = "EmptyMaze", **kwargs):
         env = gym.make(f"reward_preprocessing/{maze_name}{size}-v0", **kwargs)
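A usage sketch for the workaround (editor's addition, not part of the diff; the spaces are stand-ins roughly matching procgen's 64x64 RGB observations and 15 discrete actions):

    import gym
    import numpy as np

    from reward_preprocessing.models import CnnRewardNetWorkaround

    obs_space = gym.spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
    act_space = gym.spaces.Discrete(15)

    # normalize_input_layer is popped and only logged instead of being forwarded
    # to CnnRewardNet, which mishandles it (imitation issue #644).
    net = CnnRewardNetWorkaround(
        obs_space,
        act_space,
        use_state=True,
        use_action=True,
        normalize_input_layer="RunningNorm",  # ignored, with a warning
    )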

src/reward_preprocessing/policies/utils.py (+8 -1)

@@ -2,7 +2,7 @@
 from typing import Callable, Optional, Union
 
 import gym
-from imitation.data.types import AnyPath, path_to_str
+from imitation.data.types import AnyPath
 import numpy as np
 from stable_baselines3 import PPO
 from stable_baselines3.common.base_class import BaseAlgorithm
@@ -16,6 +16,13 @@
 Policy = Union[gym.Space, PolicyCallable, BaseAlgorithm, BasePolicy]
 
 
+def path_to_str(path: AnyPath) -> str:
+    if isinstance(path, bytes):
+        return path.decode()
+    else:
+        return str(path)
+
+
 def policy_to_callable(
     policy: Policy, deterministic_policy: bool = True
 ) -> PolicyCallable:
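The diff stops importing path_to_str from imitation.data.types (presumably it is no longer exported by the newer pinned imitation commit) and re-implements it locally. Its behavior, illustrated (editor's sketch, not part of the diff):

    from pathlib import Path

    from reward_preprocessing.policies.utils import path_to_str

    assert path_to_str("runs/model.zip") == "runs/model.zip"
    assert path_to_str(b"runs/model.zip") == "runs/model.zip"
    assert path_to_str(Path("runs/model.zip")) == str(Path("runs/model.zip"))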

src/reward_preprocessing/procgen.py (+86)

@@ -0,0 +1,86 @@
+"""Code to register procgen environments to train reward funcs on."""
+
+import re
+from typing import Iterable
+
+import gym
+from seals.util import AutoResetWrapper, get_gym_max_episode_steps
+
+
+def supported_procgen_env(gym_spec: gym.envs.registration.EnvSpec) -> bool:
+    starts_with_procgen = gym_spec.id.startswith("procgen-")
+    three_parts = len(re.split("-|_", gym_spec.id)) == 3
+    return starts_with_procgen and three_parts
+
+
+def make_auto_reset_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
+    env = AutoResetWrapper(gym.make(procgen_env_id, **make_env_kwargs))
+    return env
+
+
+def make_fin_obs_procgen(procgen_env_id: str, **make_env_kwargs) -> gym.Env:
+    env = ProcgenFinalObsWrapper(gym.make(procgen_env_id, **make_env_kwargs))
+    return env
+
+
+def local_name_autoreset(gym_spec: gym.envs.registration.EnvSpec) -> str:
+    split_str = gym_spec.id.split("-")
+    version = split_str[-1]
+    split_str[-1] = "autoreset"
+    return "-".join(split_str + [version])
+
+
+def local_name_fin_obs(gym_spec: gym.envs.registration.EnvSpec) -> str:
+    split_str = gym_spec.id.split("-")
+    version = split_str[-1]
+    split_str[-1] = "final-obs"
+    return "-".join(split_str + [version])
+
+
+def register_procgen_envs(
+    gym_procgen_env_specs: Iterable[gym.envs.registration.EnvSpec],
+) -> None:
+
+    for gym_spec in gym_procgen_env_specs:
+        gym.register(
+            id=local_name_autoreset(gym_spec),
+            entry_point="reward_preprocessing.procgen:make_auto_reset_procgen",
+            max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
+            kwargs=dict(procgen_env_id=gym_spec.id),
+        )
+
+    # There are no envs that have both autoreset and final obs wrappers.
+    # fin-obs would only affect the terminal_observation in the info dict, if it were
+    # to be wrapped by an AutoResetWrapper. Since, at the moment, we don't use the
+    # terminal_observation in the info dict, there is no point to combining them.
+    for gym_spec in gym_procgen_env_specs:
+        gym.register(
+            id=local_name_fin_obs(gym_spec),
+            entry_point="reward_preprocessing.procgen:make_fin_obs_procgen",
+            max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
+            kwargs=dict(procgen_env_id=gym_spec.id),
+        )
+
+
+class ProcgenFinalObsWrapper(gym.Wrapper):
+    """Returns the final observation of gym3 procgen environment, correcting for the
+    fact that Procgen gym environments return the second-to-last observation again
+    instead of the final observation.
+
+    Only works correctly when the 'done' signal coincides with the end of an episode
+    (which is not the case when using e.g. the seals AutoResetWrapper).
+    Requires the use of the PavelCz/procgenAISC fork, which adds the 'final_obs' value.
+
+    Since procgen builds on gym3, it always resets the environment after a terminal
+    state. The final 'obs' returned when done==True will be the obs that was already
+    returned in the previous step. In our fork of procgen, we save the true last
+    observation of the terminated episode in the info dict. This wrapper extracts that
+    obs and returns it.
+    """
+
+    def step(self, action):
+        """When done=True, returns the final_obs from the dict."""
+        obs, rew, done, info = self.env.step(action)
+        if done:
+            obs = info["final_obs"]
+        return obs, rew, done, info
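A minimal rollout sketch for the final-obs variant (editor's addition, not part of the diff). It assumes the package has been imported so the registration in __init__.py above has run, and that the PavelCz procgen fork pinned in requirements.txt is installed so info["final_obs"] exists:

    import gym

    import reward_preprocessing  # noqa: F401  # triggers register_procgen_envs

    # "procgen-coinrun-v0" gains two derived ids via the naming helpers above:
    #   procgen-coinrun-autoreset-v0  (seals AutoResetWrapper)
    #   procgen-coinrun-final-obs-v0  (ProcgenFinalObsWrapper)
    env = gym.make("procgen-coinrun-final-obs-v0")
    obs = env.reset()
    done = False
    while not done:
        obs, rew, done, info = env.step(env.action_space.sample())
    # With the wrapper, `obs` is the episode's true final observation rather than
    # the repeated second-to-last frame that raw procgen returns when done is True.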
@@ -0,0 +1,62 @@
+"""Thin wrapper around imitation's train_preference_comparisons script."""
+from imitation.scripts.config.train_preference_comparisons import (
+    train_preference_comparisons_ex,
+)
+from imitation.scripts.train_preference_comparisons import main_console
+
+from reward_preprocessing.env.maze import use_config
+from reward_preprocessing.models import CnnRewardNetWorkaround
+import reward_preprocessing.policies.base
+
+use_config(train_preference_comparisons_ex)
+
+
+@train_preference_comparisons_ex.named_config
+def coinrun():
+    """Training with preference comparisons on coinrun."""
+    fragment_length = 200
+    total_comparisons = 100_000
+    total_timesteps = 200_000_000
+    train = dict(
+        policy_cls=reward_preprocessing.policies.base.ImpalaPolicy,
+    )
+    common = dict(
+        env_name="procgen:procgen-coinrun-autoreset-v0",
+        num_vec=256,  # Goal Misg paper uses 64 envs for each of 4 workers.
+        env_make_kwargs=dict(num_levels=100_000, distribution_mode="hard"),
+    )
+    rl = dict(
+        batch_size=256 * 256,
+        rl_kwargs=dict(
+            n_epochs=3,
+            ent_coef=0.01,
+            learning_rate=0.0005,
+            batch_size=8192,
+            gamma=0.999,
+            gae_lambda=0.95,
+            clip_range=0.2,
+            max_grad_norm=0.5,
+            vf_coef=0.5,
+            normalize_advantage=True,
+        ),
+    )
+    reward = dict(
+        # Use default CNN reward net, since procgen envs are image-based.
+        # Also, hopefully, CNNs are more interpretable.
+        net_cls=CnnRewardNetWorkaround,
+    )
+    locals()  # make flake8 happy
+
+
+@train_preference_comparisons_ex.named_config
+def fast_procgen():  # Overrides some settings for fast setup for debugging purposes.
+    rl = dict(batch_size=2, rl_kwargs=dict(batch_size=2))
+    common = dict(num_vec=1)
+    total_comparisons = 32
+    fragment_length = 16
+    total_timesteps = 64
+    locals()  # make flake8 happy
+
+
+if __name__ == "__main__":  # pragma: no cover
+    main_console()
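For orientation, the rollout arithmetic implied by the coinrun named config (editor's sketch, not part of the diff; it assumes imitation's rl.batch_size counts environment transitions per PPO update while rl_kwargs.batch_size is SB3's minibatch size):

    num_vec = 256                  # common.num_vec: parallel environments
    batch_size = 256 * 256         # rl.batch_size: transitions per PPO update (65_536)
    minibatch = 8192               # rl_kwargs.batch_size: SB3 minibatch size
    n_epochs = 3

    steps_per_env = batch_size // num_vec               # 256 env steps per update
    minibatches_per_epoch = batch_size // minibatch     # 8 minibatches
    gradient_steps = n_epochs * minibatches_per_epoch   # 24 gradient steps per update
    updates = 200_000_000 // batch_size                 # ~3051 PPO updates over total_timesteps

The fast_procgen named config shrinks these same knobs to near-trivial values so the whole preference-comparison pipeline can be smoke-tested quickly.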
