|
3 | 3 | # This source code is licensed under the MIT license found in the
|
4 | 4 | # LICENSE file in the root directory of this source tree.
|
5 | 5 | import argparse
|
| 6 | +import importlib |
| 7 | + |
6 | 8 | import time
|
7 | 9 | from sys import platform
|
8 | 10 | from typing import Optional, Union
|
|
37 | 39 | )
|
38 | 40 | from torchrl.envs.libs.brax import _has_brax, BraxEnv
|
39 | 41 | from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper
|
40 |
| -from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper |
| 42 | +from torchrl.envs.libs.gym import ( |
| 43 | + _has_gym, |
| 44 | + _is_from_pixels, |
| 45 | + GymEnv, |
| 46 | + GymWrapper, |
| 47 | + MOGymEnv, |
| 48 | + MOGymWrapper, |
| 49 | +) |
41 | 50 | from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
|
42 | 51 | from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
|
43 | 52 | from torchrl.envs.libs.openml import OpenMLEnv
|
|
46 | 55 | from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnvWrapper, SerialEnv
|
47 | 56 | from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator
|
48 | 57 |
|
49 |
| -D4RL_ERR = None |
50 |
| -try: |
51 |
| - import d4rl # noqa |
| 58 | +_has_d4rl = importlib.util.find_spec("d4rl") is not None |
52 | 59 |
|
53 |
| - _has_d4rl = True |
54 |
| -except Exception as err: |
55 |
| - # many things can wrong when importing d4rl :( |
56 |
| - _has_d4rl = False |
57 |
| - D4RL_ERR = err |
| 60 | +_has_mo = importlib.util.find_spec("mo_gymnasium") is not None |
58 | 61 |
|
59 |
| -SKLEARN_ERR = None |
60 |
| -try: |
61 |
| - import sklearn # noqa |
| 62 | +_has_sklearn = importlib.util.find_spec("sklearn") is not None |
62 | 63 |
|
63 |
| - _has_sklearn = True |
64 |
| -except ModuleNotFoundError as err: |
65 |
| - _has_sklearn = False |
66 |
| - SKLEARN_ERR = err |
67 | 64 |
|
68 | 65 | if _has_gym:
|
69 | 66 | try:
|
@@ -212,6 +209,46 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
|
212 | 209 | )
|
213 | 210 | check_env_specs(env)
|
214 | 211 |
|
@pytest.mark.parametrize("frame_skip", [1, 3])
@pytest.mark.parametrize(
    "from_pixels,pixels_only",
    [
        [False, False],
        [True, True],
        [True, False],
    ],
)
@pytest.mark.parametrize("wrapper", [True, False])
def test_mo(self, frame_skip, from_pixels, pixels_only, wrapper):
    """Check the specs of the multi-objective ``minecart-v0`` environment.

    The environment is built either by wrapping ``mo_gymnasium.make(...)``
    in ``MOGymWrapper`` (``wrapper=True``) or by name through ``MOGymEnv``
    (``wrapper=False``). ``check_env_specs`` is run on a single instance and
    on a ``SerialEnv`` holding two of them.

    Args:
        frame_skip: forwarded to the environment constructor.
        from_pixels: forwarded to the environment constructor.
        pixels_only: forwarded to the environment constructor.
        wrapper: selects ``MOGymWrapper`` (True) or ``MOGymEnv`` (False).
    """
    if not _has_mo:
        # Without mo_gymnasium there is nothing to test. Report a skip only
        # when gymnasium itself is installed.
        if importlib.util.find_spec("gymnasium") is not None:
            pytest.skip("mo-gym not found")
        # avoid skipping, which we consider as errors in the gym CI
        return

    def make_env():
        # Imported lazily: this line is only reached when mo_gymnasium exists.
        import mo_gymnasium

        if wrapper:
            return MOGymWrapper(
                mo_gymnasium.make("minecart-v0"),
                frame_skip=frame_skip,
                from_pixels=from_pixels,
                pixels_only=pixels_only,
            )
        return MOGymEnv(
            "minecart-v0",
            frame_skip=frame_skip,
            from_pixels=from_pixels,
            pixels_only=pixels_only,
        )

    env = make_env()
    check_env_specs(env)
    env = SerialEnv(2, make_env)
    check_env_specs(env)
| 251 | + |
215 | 252 | def test_info_reader(self):
|
216 | 253 | try:
|
217 | 254 | import gym_super_mario_bros as mario_gym
|
@@ -1240,7 +1277,7 @@ def make_vmas():
|
1240 | 1277 | assert env.rollout(max_steps=3).device == devices[1 - first]
|
1241 | 1278 |
|
1242 | 1279 |
|
1243 |
| -@pytest.mark.skipif(not _has_d4rl, reason=f"D4RL not found: {D4RL_ERR}") |
| 1280 | +@pytest.mark.skipif(not _has_d4rl, reason="D4RL not found") |
1244 | 1281 | class TestD4RL:
|
1245 | 1282 | @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
|
1246 | 1283 | def test_terminate_on_end(self, task):
|
@@ -1333,7 +1370,7 @@ def test_d4rl_iteration(self, task, split_trajs):
|
1333 | 1370 | print(f"completed test after {time.time()-t0}s")
|
1334 | 1371 |
|
1335 | 1372 |
|
1336 |
| -@pytest.mark.skipif(not _has_sklearn, reason=f"Scikit-learn not found: {SKLEARN_ERR}") |
| 1373 | +@pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") |
1337 | 1374 | @pytest.mark.parametrize(
|
1338 | 1375 | "dataset",
|
1339 | 1376 | [
|
|
0 commit comments