Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove depenedncy on bsuite to support Python 3.12 #178

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,4 @@ optax
pytest
wandb
pytest-cov
bsuite
tqdm
3 changes: 2 additions & 1 deletion pax/agents/ppo/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import gym
import jax.numpy as jnp
from bsuite.utils import gym_wrapper
from dm_env import Environment, TimeStep, specs

from . import gym_wrapper


class BatchedEnvs(Environment):
"""Batches a number of cartpole environments that are called sequentially.
Expand Down
211 changes: 211 additions & 0 deletions pax/agents/ppo/gym_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""bsuite adapter for OpenAI gym run-loops."""

from typing import Any, Dict, Optional, Tuple, Union

import dm_env
from dm_env import specs
import gym
from gym import spaces
import numpy as np

# OpenAI gym step format = obs, reward, is_finished, other_info
_GymTimestep = Tuple[np.ndarray, float, bool, Dict[str, Any]]


class GymFromDMEnv(gym.Env):
"""A wrapper that converts a dm_env.Environment to an OpenAI gym.Env."""

metadata = {"render.modes": ["human", "rgb_array"]}

def __init__(self, env: dm_env.Environment):
self._env = env # type: dm_env.Environment
self._last_observation = None # type: Optional[np.ndarray]
self.viewer = None
self.game_over = False # Needed for Dopamine agents.

def step(self, action: int) -> _GymTimestep:
timestep = self._env.step(action)
self._last_observation = timestep.observation
reward = timestep.reward or 0.0
if timestep.last():
self.game_over = True
return timestep.observation, reward, timestep.last(), {}

def reset(self) -> np.ndarray:
self.game_over = False
timestep = self._env.reset()
self._last_observation = timestep.observation
return timestep.observation

def render(self, mode: str = "rgb_array") -> Union[np.ndarray, bool]:
if self._last_observation is None:
raise ValueError(
"Environment not ready to render. Call reset() first."
)

if mode == "rgb_array":
return self._last_observation

if mode == "human":
if self.viewer is None:
# pylint: disable=import-outside-toplevel
# pylint: disable=g-import-not-at-top
from gym.envs.classic_control import rendering

self.viewer = rendering.SimpleImageViewer()
self.viewer.imshow(self._last_observation)
return self.viewer.isopen

@property
def action_space(self) -> spaces.Discrete:
action_spec = self._env.action_spec() # type: specs.DiscreteArray
return spaces.Discrete(action_spec.num_values)

@property
def observation_space(self) -> spaces.Box:
obs_spec = self._env.observation_spec() # type: specs.Array
if isinstance(obs_spec, specs.BoundedArray):
return spaces.Box(
low=float(obs_spec.minimum),
high=float(obs_spec.maximum),
shape=obs_spec.shape,
dtype=obs_spec.dtype,
)
return spaces.Box(
low=-float("inf"),
high=float("inf"),
shape=obs_spec.shape,
dtype=obs_spec.dtype,
)

@property
def reward_range(self) -> Tuple[float, float]:
reward_spec = self._env.reward_spec()
if isinstance(reward_spec, specs.BoundedArray):
return reward_spec.minimum, reward_spec.maximum
return -float("inf"), float("inf")

def __getattr__(self, attr):
"""Delegate attribute access to underlying environment."""
return getattr(self._env, attr)


def space2spec(space: gym.Space, name: str = None):
"""Converts an OpenAI Gym space to a dm_env spec or nested structure of specs.

Box, MultiBinary and MultiDiscrete Gym spaces are converted to BoundedArray
specs. Discrete OpenAI spaces are converted to DiscreteArray specs. Tuple and
Dict spaces are recursively converted to tuples and dictionaries of specs.

Args:
space: The Gym space to convert.
name: Optional name to apply to all return spec(s).

Returns:
A dm_env spec or nested structure of specs, corresponding to the input
space.
"""
if isinstance(space, spaces.Discrete):
return specs.DiscreteArray(
num_values=space.n, dtype=space.dtype, name=name
)

elif isinstance(space, spaces.Box):
return specs.BoundedArray(
shape=space.shape,
dtype=space.dtype,
minimum=space.low,
maximum=space.high,
name=name,
)

elif isinstance(space, spaces.MultiBinary):
return specs.BoundedArray(
shape=space.shape,
dtype=space.dtype,
minimum=0.0,
maximum=1.0,
name=name,
)

elif isinstance(space, spaces.MultiDiscrete):
return specs.BoundedArray(
shape=space.shape,
dtype=space.dtype,
minimum=np.zeros(space.shape),
maximum=space.nvec,
name=name,
)

elif isinstance(space, spaces.Tuple):
return tuple(space2spec(s, name) for s in space.spaces)

elif isinstance(space, spaces.Dict):
return {
key: space2spec(value, name) for key, value in space.spaces.items()
}

else:
raise ValueError("Unexpected gym space: {}".format(space))


class DMEnvFromGym(dm_env.Environment):
"""A wrapper to convert an OpenAI Gym environment to a dm_env.Environment."""

def __init__(self, gym_env: gym.Env):
self.gym_env = gym_env
# Convert gym action and observation spaces to dm_env specs.
self._observation_spec = space2spec(
self.gym_env.observation_space, name="observations"
)
self._action_spec = space2spec(
self.gym_env.action_space, name="actions"
)
self._reset_next_step = True

def reset(self) -> dm_env.TimeStep:
self._reset_next_step = False
observation = self.gym_env.reset()
return dm_env.restart(observation)

def step(self, action: int) -> dm_env.TimeStep:
if self._reset_next_step:
return self.reset()

# Convert the gym step result to a dm_env TimeStep.
observation, reward, done, info = self.gym_env.step(action)
self._reset_next_step = done

if done:
is_truncated = info.get("TimeLimit.truncated", False)
if is_truncated:
return dm_env.truncation(reward, observation)
else:
return dm_env.termination(reward, observation)
else:
return dm_env.transition(reward, observation)

def close(self):
self.gym_env.close()

def observation_spec(self):
return self._observation_spec

def action_spec(self):
return self._action_spec
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,4 @@ optax
pytest
wandb
pytest-cov
bsuite
tqdm
Loading