Skip to content

Commit a3f6969

Browse files
[Feature] Added support for vector-based rewards from environments in MO-Gymnasium (#992)
Co-authored-by: vmoens <[email protected]>
1 parent 24abc75 commit a3f6969

File tree

7 files changed

+125
-24
lines changed

7 files changed

+125
-24
lines changed

.circleci/unittest/linux/scripts/setup_env.sh

+2
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ if [[ $OSTYPE != 'darwin'* ]]; then
115115
fi
116116
echo "installing gymnasium"
117117
pip install "gymnasium[atari,accept-rom-license]"
118+
pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
118119
else
119120
pip install "gymnasium[atari,accept-rom-license]"
121+
pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
120122
fi

.circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh

+1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ do
141141
else
142142
pip install gymnasium[atari]
143143
fi
144+
pip install mo-gymnasium
144145

145146
$DIR/run_test.sh
146147

.circleci/unittest/linux_stable/scripts/setup_env.sh

+2
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ if [[ $OSTYPE != 'darwin'* ]]; then
115115
fi
116116
echo "installing gymnasium"
117117
pip install "gymnasium[atari,accept-rom-license]"
118+
pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
118119
else
119120
pip install "gymnasium[atari,accept-rom-license]"
121+
pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
120122
fi

docs/source/reference/envs.rst

+2
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@ the following function will return ``1`` when queried:
481481
dm_control.DMControlWrapper
482482
gym.GymEnv
483483
gym.GymWrapper
484+
gym.MOGymEnv
485+
gym.MOGymWrapper
484486
gym.set_gym_backend
485487
gym.gym_backend
486488
habitat.HabitatEnv

test/test_libs.py

+55-18
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55
import argparse
6+
import importlib
7+
68
import time
79
from sys import platform
810
from typing import Optional, Union
@@ -37,7 +39,14 @@
3739
)
3840
from torchrl.envs.libs.brax import _has_brax, BraxEnv
3941
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper
40-
from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper
42+
from torchrl.envs.libs.gym import (
43+
_has_gym,
44+
_is_from_pixels,
45+
GymEnv,
46+
GymWrapper,
47+
MOGymEnv,
48+
MOGymWrapper,
49+
)
4150
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
4251
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
4352
from torchrl.envs.libs.openml import OpenMLEnv
@@ -46,24 +55,12 @@
4655
from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnvWrapper, SerialEnv
4756
from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator
4857

49-
D4RL_ERR = None
50-
try:
51-
import d4rl # noqa
58+
_has_d4rl = importlib.util.find_spec("d4rl") is not None
5259

53-
_has_d4rl = True
54-
except Exception as err:
55-
# many things can wrong when importing d4rl :(
56-
_has_d4rl = False
57-
D4RL_ERR = err
60+
_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
5861

59-
SKLEARN_ERR = None
60-
try:
61-
import sklearn # noqa
62+
_has_sklearn = importlib.util.find_spec("sklearn") is not None
6263

63-
_has_sklearn = True
64-
except ModuleNotFoundError as err:
65-
_has_sklearn = False
66-
SKLEARN_ERR = err
6764

6865
if _has_gym:
6966
try:
@@ -212,6 +209,46 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
212209
)
213210
check_env_specs(env)
214211

212+
@pytest.mark.parametrize("frame_skip", [1, 3])
213+
@pytest.mark.parametrize(
214+
"from_pixels,pixels_only",
215+
[
216+
[False, False],
217+
[True, True],
218+
[True, False],
219+
],
220+
)
221+
@pytest.mark.parametrize("wrapper", [True, False])
222+
def test_mo(self, frame_skip, from_pixels, pixels_only, wrapper):
223+
if importlib.util.find_spec("gymnasium") is not None and not _has_mo:
224+
raise pytest.skip("mo-gym not found")
225+
else:
226+
# avoid skipping, which we consider as errors in the gym CI
227+
return
228+
229+
def make_env():
230+
import mo_gymnasium
231+
232+
if wrapper:
233+
return MOGymWrapper(
234+
mo_gymnasium.make("minecart-v0"),
235+
frame_skip=frame_skip,
236+
from_pixels=from_pixels,
237+
pixels_only=pixels_only,
238+
)
239+
else:
240+
return MOGymEnv(
241+
"minecart-v0",
242+
frame_skip=frame_skip,
243+
from_pixels=from_pixels,
244+
pixels_only=pixels_only,
245+
)
246+
247+
env = make_env()
248+
check_env_specs(env)
249+
env = SerialEnv(2, make_env)
250+
check_env_specs(env)
251+
215252
def test_info_reader(self):
216253
try:
217254
import gym_super_mario_bros as mario_gym
@@ -1240,7 +1277,7 @@ def make_vmas():
12401277
assert env.rollout(max_steps=3).device == devices[1 - first]
12411278

12421279

1243-
@pytest.mark.skipif(not _has_d4rl, reason=f"D4RL not found: {D4RL_ERR}")
1280+
@pytest.mark.skipif(not _has_d4rl, reason="D4RL not found")
12441281
class TestD4RL:
12451282
@pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
12461283
def test_terminate_on_end(self, task):
@@ -1333,7 +1370,7 @@ def test_d4rl_iteration(self, task, split_trajs):
13331370
print(f"completed test after {time.time()-t0}s")
13341371

13351372

1336-
@pytest.mark.skipif(not _has_sklearn, reason=f"Scikit-learn not found: {SKLEARN_ERR}")
1373+
@pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found")
13371374
@pytest.mark.parametrize(
13381375
"dataset",
13391376
[

test/test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def test_set_gym_environments_no_version_gymnasium_found():
255255

256256
# this version of gymnasium does not exist in implement_for
257257
# therefore, set_gym_backend will not set anything and raise an ImportError.
258-
msg = f"could not set anything related to gym backed {gymnasium_name} with version={gymnasium_version}."
258+
msg = f"could not set anything related to gym backend {gymnasium_name} with version={gymnasium_version}."
259259
with pytest.raises(ImportError, match=msg) as exc_info:
260260
with set_gym_backend(gymnasium):
261261
_utils_internal._set_gym_environments()

torchrl/envs/libs/gym.py

+62-5
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
if not _has_gym:
4141
_has_gym = importlib.util.find_spec("gymnasium") is not None
4242

43+
_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
44+
4345

4446
class set_gym_backend(_DecoratorContextManager):
4547
"""Sets the gym-backend to a certain value.
@@ -106,7 +108,8 @@ def _call(self):
106108
found_setter = True
107109
if not found_setter:
108110
raise ImportError(
109-
f"could not set anything related to gym backed {self.backend.__name__} with version={self.backend.__version__}."
111+
f"could not set anything related to gym backend "
112+
f"{self.backend.__name__} with version={self.backend.__version__}."
110113
)
111114

112115
def __enter__(self):
@@ -527,10 +530,17 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
527530
else:
528531
observation_spec = CompositeSpec(observation=observation_spec)
529532
self.observation_spec = observation_spec
530-
self.reward_spec = UnboundedContinuousTensorSpec(
531-
shape=[1],
532-
device=self.device,
533-
)
533+
if hasattr(env, "reward_space") and env.reward_space is not None:
534+
self.reward_spec = _gym_to_torchrl_spec_transform(
535+
env.reward_space,
536+
device=self.device,
537+
categorical_action_encoding=self._categorical_action_encoding,
538+
)
539+
else:
540+
self.reward_spec = UnboundedContinuousTensorSpec(
541+
shape=[1],
542+
device=self.device,
543+
)
534544

535545
def _init_env(self):
536546
self.reset()
@@ -671,3 +681,50 @@ def _check_kwargs(self, kwargs: Dict):
671681

672682
def __repr__(self) -> str:
673683
return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"
684+
685+
686+
class MOGymWrapper(GymWrapper):
687+
"""FARAMA MO-Gymnasium environment wrapper.
688+
689+
Examples:
690+
>>> import mo_gymnasium as mo_gym
691+
>>> env = MOGymWrapper(mo_gym.make('minecart-v0'), frame_skip=4)
692+
>>> td = env.rand_step()
693+
>>> print(td)
694+
>>> print(env.available_envs)
695+
696+
"""
697+
698+
git_url = "https://github.com/Farama-Foundation/MO-Gymnasium"
699+
libname = "mo-gymnasium"
700+
701+
_make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs)
702+
703+
704+
class MOGymEnv(GymEnv):
705+
"""FARAMA MO-Gymnasium environment wrapper.
706+
707+
Examples:
708+
>>> env = MOGymEnv(env_name="minecart-v0", frame_skip=4)
709+
>>> td = env.rand_step()
710+
>>> print(td)
711+
>>> print(env.available_envs)
712+
713+
"""
714+
715+
git_url = "https://github.com/Farama-Foundation/MO-Gymnasium"
716+
libname = "mo-gymnasium"
717+
718+
@property
719+
def lib(self) -> ModuleType:
720+
if _has_mo:
721+
import mo_gymnasium as mo_gym
722+
723+
return mo_gym
724+
else:
725+
try:
726+
import mo_gymnasium # noqa: F401
727+
except ImportError as err:
728+
raise ImportError("MO-gymnasium not found, check installation") from err
729+
730+
_make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs)

0 commit comments

Comments
 (0)