Skip to content

Commit 146341b

Browse files
committed
[BugFix] Fix broken gym tests (#1980)
1 parent cb9c8c5 commit 146341b

File tree

16 files changed

+711
-553
lines changed

16 files changed

+711
-553
lines changed

.github/unittest/linux/scripts/run_all.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,11 @@ pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_contro
194194
if [ "${CU_VERSION:-}" != cpu ] ; then
195195
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
196196
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
197-
--timeout=120
197+
--timeout=120 --mp_fork_if_no_cuda
198198
else
199199
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
200200
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py \
201-
--timeout=120
201+
--timeout=120 --mp_fork_if_no_cuda
202202
fi
203203

204204
coverage combine

.github/unittest/linux_distributed/scripts/run_test.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ export BATCHED_PIPE_TIMEOUT=60
2323

2424
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200
2525
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
26-
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py --instafail -v --durations 200
26+
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py --instafail -v --durations 200 --mp_fork_if_no_cuda
2727
coverage combine
2828
coverage xml -i

.github/unittest/linux_libs/scripts_gym/batch_scripts.sh

+26-49
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ apt-get update && apt-get install -y git wget libglew-dev libx11-dev x11proto-de
1919
# solves "'extras_require' must be a dictionary"
2020
pip install setuptools==65.3.0
2121

22-
mkdir third_party
23-
cd third_party
24-
git clone https://github.com/vmoens/gym
25-
cd ..
22+
#mkdir -p third_party
23+
#cd third_party
24+
#git clone https://github.com/vmoens/gym
25+
#cd ..
2626

2727
# This version is installed initially (see environment.yml)
2828
for GYM_VERSION in '0.13'
@@ -38,7 +38,7 @@ do
3838

3939
# delete the conda copy
4040
conda deactivate
41-
conda env remove --prefix ./cloned_env
41+
conda env remove --prefix ./cloned_env -y
4242
done
4343

4444
# gym[atari]==0.19 is broken, so we install only gym without dependencies.
@@ -57,7 +57,7 @@ do
5757

5858
# delete the conda copy
5959
conda deactivate
60-
conda env remove --prefix ./cloned_env
60+
conda env remove --prefix ./cloned_env -y
6161
done
6262

6363
# gym[atari]==0.20 installs ale-py==0.8, but this version is not compatible with gym<0.26, so we downgrade it.
@@ -76,7 +76,7 @@ do
7676

7777
# delete the conda copy
7878
conda deactivate
79-
conda env remove --prefix ./cloned_env
79+
conda env remove --prefix ./cloned_env -y
8080
done
8181

8282
for GYM_VERSION in '0.25'
@@ -92,7 +92,7 @@ do
9292

9393
# delete the conda copy
9494
conda deactivate
95-
conda env remove --prefix ./cloned_env
95+
conda env remove --prefix ./cloned_env -y
9696
done
9797

9898
# For this version "gym[accept-rom-license]" is required.
@@ -104,65 +104,42 @@ do
104104
conda activate ./cloned_env
105105

106106
echo "Testing gym version: ${GYM_VERSION}"
107-
pip3 install 'gym[accept-rom-license]'==$GYM_VERSION
108-
pip3 install 'gym[atari]'==$GYM_VERSION
107+
pip3 install 'gym[atari,accept-rom-license]'==$GYM_VERSION
109108
pip3 install gym-super-mario-bros
110109
$DIR/run_test.sh
111110

112111
# delete the conda copy
113112
conda deactivate
114-
conda env remove --prefix ./cloned_env
113+
conda env remove --prefix ./cloned_env -y
115114
done
116115

117116
# For this version "gym[accept-rom-license]" is required.
118-
for GYM_VERSION in '0.27'
117+
for GYM_VERSION in '0.27' '0.28'
119118
do
120119
# Create a copy of the conda env and work with this
121120
conda deactivate
122121
conda create --prefix ./cloned_env --clone ./env -y
123122
conda activate ./cloned_env
124123

125124
echo "Testing gym version: ${GYM_VERSION}"
126-
pip3 install 'gymnasium[accept-rom-license]'==$GYM_VERSION
127-
128-
129-
if [[ $OSTYPE != 'darwin'* ]]; then
130-
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
131-
# rename them
132-
PY_VERSION=$(python --version)
133-
if [[ $PY_VERSION == *"3.7"* ]]; then
134-
wget https://files.pythonhosted.org/packages/ab/fd/6615982d9460df7f476cad265af1378057eee9daaa8e0026de4cedbaffbd/ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
135-
pip install ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
136-
rm ale_py-0.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
137-
elif [[ $PY_VERSION == *"3.8"* ]]; then
138-
wget https://files.pythonhosted.org/packages/0f/8a/feed20571a697588bc4bfef05d6a487429c84f31406a52f8af295a0346a2/ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
139-
pip install ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
140-
rm ale_py-0.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
141-
elif [[ $PY_VERSION == *"3.9"* ]]; then
142-
wget https://files.pythonhosted.org/packages/a0/98/4316c1cedd9934f9a91b6e27a9be126043b4445594b40cfa391c8de2e5e8/ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
143-
pip install ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
144-
rm ale_py-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
145-
elif [[ $PY_VERSION == *"3.10"* ]]; then
146-
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
147-
mv ale_py-0.8.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
148-
pip install ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
149-
rm ale_py-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
150-
elif [[ $PY_VERSION == *"3.11"* ]]; then
151-
wget https://files.pythonhosted.org/packages/60/1b/3adde7f44f79fcc50d0a00a0643255e48024c4c3977359747d149dc43500/ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
152-
mv ale_py-0.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
153-
pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
154-
rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
155-
fi
156-
pip install gymnasium[atari]
157-
else
158-
pip install gymnasium[atari]
159-
fi
160-
pip install mo-gymnasium
161-
pip install gymnasium-robotics
125+
pip3 install 'gymnasium[atari,accept-rom-license,ale-py]'==$GYM_VERSION
162126

163127
$DIR/run_test.sh
164128

165129
# delete the conda copy
166130
conda deactivate
167-
conda env remove --prefix ./cloned_env
131+
conda env remove --prefix ./cloned_env -y
168132
done
133+
134+
# Latest gymnasium
135+
conda deactivate
136+
conda create --prefix ./cloned_env --clone ./env -y
137+
conda activate ./cloned_env
138+
139+
pip3 install 'gymnasium[accept-rom-license,ale-py,atari]' mo-gymnasium gymnasium-robotics -U
140+
141+
$DIR/run_test.sh
142+
143+
# delete the conda copy
144+
conda deactivate
145+
conda env remove --prefix ./cloned_env -y

.github/unittest/linux_libs/scripts_gym/run_test.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_te
2323

2424
export DISPLAY=':99.0'
2525
Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 &
26-
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym and not isaac" --error-for-skips
26+
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 -k "gym and not isaac" --error-for-skips --mp_fork
2727
coverage combine
2828
coverage xml -i

test/_utils_internal.py

+55-24
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from torchrl.envs import MultiThreadedEnv, ObservationNorm
2525
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
2626
from torchrl.envs.libs.envpool import _has_envpool
27-
from torchrl.envs.libs.gym import _has_gym, GymEnv
27+
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv
2828
from torchrl.envs.transforms import (
2929
Compose,
3030
RewardClipping,
@@ -35,41 +35,72 @@
3535
# Specified for test_utils.py
3636
__version__ = "0.3"
3737

38-
# Default versions of the environments.
39-
CARTPOLE_VERSIONED = "CartPole-v1"
40-
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
41-
PENDULUM_VERSIONED = "Pendulum-v1"
42-
PONG_VERSIONED = "ALE/Pong-v5"
38+
39+
def CARTPOLE_VERSIONED():
40+
# load gym
41+
if gym_backend() is not None:
42+
_set_gym_environments()
43+
return _CARTPOLE_VERSIONED
44+
45+
46+
def HALFCHEETAH_VERSIONED():
47+
# load gym
48+
if gym_backend() is not None:
49+
_set_gym_environments()
50+
return _HALFCHEETAH_VERSIONED
51+
52+
53+
def PONG_VERSIONED():
54+
# load gym
55+
if gym_backend() is not None:
56+
_set_gym_environments()
57+
return _PONG_VERSIONED
58+
59+
60+
def PENDULUM_VERSIONED():
61+
# load gym
62+
if gym_backend() is not None:
63+
_set_gym_environments()
64+
return _PENDULUM_VERSIONED
65+
66+
67+
def _set_gym_environments():
68+
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED
69+
70+
_CARTPOLE_VERSIONED = None
71+
_HALFCHEETAH_VERSIONED = None
72+
_PENDULUM_VERSIONED = None
73+
_PONG_VERSIONED = None
4374

4475

4576
@implement_for("gym", None, "0.21.0")
4677
def _set_gym_environments(): # noqa: F811
47-
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED
78+
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED
4879

49-
CARTPOLE_VERSIONED = "CartPole-v0"
50-
HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
51-
PENDULUM_VERSIONED = "Pendulum-v0"
52-
PONG_VERSIONED = "Pong-v4"
80+
_CARTPOLE_VERSIONED = "CartPole-v0"
81+
_HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
82+
_PENDULUM_VERSIONED = "Pendulum-v0"
83+
_PONG_VERSIONED = "Pong-v4"
5384

5485

5586
@implement_for("gym", "0.21.0", None)
5687
def _set_gym_environments(): # noqa: F811
57-
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED
88+
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED
5889

59-
CARTPOLE_VERSIONED = "CartPole-v1"
60-
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
61-
PENDULUM_VERSIONED = "Pendulum-v1"
62-
PONG_VERSIONED = "ALE/Pong-v5"
90+
_CARTPOLE_VERSIONED = "CartPole-v1"
91+
_HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
92+
_PENDULUM_VERSIONED = "Pendulum-v1"
93+
_PONG_VERSIONED = "ALE/Pong-v5"
6394

6495

6596
@implement_for("gymnasium")
6697
def _set_gym_environments(): # noqa: F811
67-
global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED
98+
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED
6899

69-
CARTPOLE_VERSIONED = "CartPole-v1"
70-
HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
71-
PENDULUM_VERSIONED = "Pendulum-v1"
72-
PONG_VERSIONED = "ALE/Pong-v5"
100+
_CARTPOLE_VERSIONED = "CartPole-v1"
101+
_HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
102+
_PENDULUM_VERSIONED = "Pendulum-v1"
103+
_PONG_VERSIONED = "ALE/Pong-v5"
73104

74105

75106
if _has_gym:
@@ -171,7 +202,7 @@ def create_env_fn():
171202
return GymEnv(env_name, frame_skip=frame_skip, device=device)
172203

173204
else:
174-
if env_name == PONG_VERSIONED:
205+
if env_name == PONG_VERSIONED():
175206

176207
def create_env_fn():
177208
base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
@@ -250,7 +281,7 @@ def _make_multithreaded_env(
250281

251282
torch.manual_seed(0)
252283
multithreaded_kwargs = (
253-
{"frame_skip": frame_skip} if env_name == PONG_VERSIONED else {}
284+
{"frame_skip": frame_skip} if env_name == PONG_VERSIONED() else {}
254285
)
255286
env_multithread = MultiThreadedEnv(
256287
N,
@@ -274,7 +305,7 @@ def _make_multithreaded_env(
274305

275306
def get_transform_out(env_name, transformed_in, obs_key=None):
276307

277-
if env_name == PONG_VERSIONED:
308+
if env_name == PONG_VERSIONED():
278309
if obs_key is None:
279310
obs_key = "pixels"
280311

test/conftest.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
65
import os
7-
6+
import sys
87
import time
98
import warnings
109
from collections import defaultdict
1110

1211
import pytest
1312

1413
CALL_TIMES = defaultdict(lambda: 0.0)
14+
IS_OSX = sys.platform == "darwin"
1515

1616

1717
def pytest_sessionfinish(maxprint=50):
@@ -97,6 +97,20 @@ def pytest_addoption(parser):
9797
"--runslow", action="store_true", default=False, help="run slow tests"
9898
)
9999

100+
parser.addoption(
101+
"--mp_fork",
102+
action="store_true",
103+
default=False,
104+
help="Use 'fork' start method for mp dedicated tests.",
105+
)
106+
107+
parser.addoption(
108+
"--mp_fork_if_no_cuda",
109+
action="store_true",
110+
default=False,
111+
help="Use 'fork' start method for mp dedicated tests only if there is no cuda device available.",
112+
)
113+
100114

101115
def pytest_configure(config):
102116
config.addinivalue_line("markers", "slow: mark test as slow to run")
@@ -110,3 +124,11 @@ def pytest_collection_modifyitems(config, items):
110124
for item in items:
111125
if "slow" in item.keywords:
112126
item.add_marker(skip_slow)
127+
128+
129+
@pytest.fixture
130+
def maybe_fork_ParallelEnv(request):
131+
# Feature available from 0.4 only
132+
from torchrl.envs import ParallelEnv
133+
134+
return ParallelEnv

test/smoke_test_deps.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_gym():
4848
assert _has_gym
4949
from _utils_internal import PONG_VERSIONED
5050

51-
env = GymEnv(PONG_VERSIONED)
51+
env = GymEnv(PONG_VERSIONED())
5252
env.reset()
5353

5454

test/test_collector.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def make_env():
598598
# This is currently necessary as the methods in GymWrapper may have mismatching backend
599599
# versions.
600600
with set_gym_backend(gym_backend()):
601-
return TransformedEnv(GymEnv(PONG_VERSIONED, frame_skip=4), StepCounter())
601+
return TransformedEnv(GymEnv(PONG_VERSIONED(), frame_skip=4), StepCounter())
602602

603603
if parallel:
604604
env = ParallelEnv(2, make_env)
@@ -1076,7 +1076,9 @@ def test_collector_vecnorm_envcreator(static_seed):
10761076
from torchrl.envs.libs.gym import GymEnv
10771077

10781078
num_envs = 4
1079-
env_make = EnvCreator(lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED), VecNorm()))
1079+
env_make = EnvCreator(
1080+
lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm())
1081+
)
10801082
env_make = ParallelEnv(num_envs, env_make)
10811083

10821084
policy = RandomPolicy(env_make.action_spec)
@@ -1293,7 +1295,7 @@ def test_collector_output_keys(
12931295

12941296
policy = SafeModule(**policy_kwargs)
12951297

1296-
env_maker = lambda: GymEnv(PENDULUM_VERSIONED)
1298+
env_maker = lambda: GymEnv(PENDULUM_VERSIONED())
12971299

12981300
policy(env_maker().reset())
12991301

@@ -1432,7 +1434,7 @@ class TestAutoWrap:
14321434
def env_maker(self):
14331435
from torchrl.envs.libs.gym import GymEnv
14341436

1435-
return lambda: GymEnv(PENDULUM_VERSIONED)
1437+
return lambda: GymEnv(PENDULUM_VERSIONED())
14361438

14371439
def _create_collector_kwargs(self, env_maker, collector_class, policy):
14381440
collector_kwargs = {

0 commit comments

Comments
 (0)