24
24
from torchrl .envs import MultiThreadedEnv , ObservationNorm
25
25
from torchrl .envs .batched_envs import ParallelEnv , SerialEnv
26
26
from torchrl .envs .libs .envpool import _has_envpool
27
- from torchrl .envs .libs .gym import _has_gym , GymEnv
27
+ from torchrl .envs .libs .gym import _has_gym , gym_backend , GymEnv
28
28
from torchrl .envs .transforms import (
29
29
Compose ,
30
30
RewardClipping ,
35
35
# Specified for test_utils.py
36
36
__version__ = "0.3"
37
37
38
- # Default versions of the environments.
39
- CARTPOLE_VERSIONED = "CartPole-v1"
40
- HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
41
- PENDULUM_VERSIONED = "Pendulum-v1"
42
- PONG_VERSIONED = "ALE/Pong-v5"
38
+
39
+ def CARTPOLE_VERSIONED ():
40
+ # load gym
41
+ if gym_backend () is not None :
42
+ _set_gym_environments ()
43
+ return _CARTPOLE_VERSIONED
44
+
45
+
46
+ def HALFCHEETAH_VERSIONED ():
47
+ # load gym
48
+ if gym_backend () is not None :
49
+ _set_gym_environments ()
50
+ return _HALFCHEETAH_VERSIONED
51
+
52
+
53
+ def PONG_VERSIONED ():
54
+ # load gym
55
+ if gym_backend () is not None :
56
+ _set_gym_environments ()
57
+ return _PONG_VERSIONED
58
+
59
+
60
+ def PENDULUM_VERSIONED ():
61
+ # load gym
62
+ if gym_backend () is not None :
63
+ _set_gym_environments ()
64
+ return _PENDULUM_VERSIONED
65
+
66
+
67
+ def _set_gym_environments ():
68
+ global _CARTPOLE_VERSIONED , _HALFCHEETAH_VERSIONED , _PENDULUM_VERSIONED , _PONG_VERSIONED
69
+
70
+ _CARTPOLE_VERSIONED = None
71
+ _HALFCHEETAH_VERSIONED = None
72
+ _PENDULUM_VERSIONED = None
73
+ _PONG_VERSIONED = None
43
74
44
75
45
76
@implement_for ("gym" , None , "0.21.0" )
46
77
def _set_gym_environments (): # noqa: F811
47
- global CARTPOLE_VERSIONED , HALFCHEETAH_VERSIONED , PENDULUM_VERSIONED , PONG_VERSIONED
78
+ global _CARTPOLE_VERSIONED , _HALFCHEETAH_VERSIONED , _PENDULUM_VERSIONED , _PONG_VERSIONED
48
79
49
- CARTPOLE_VERSIONED = "CartPole-v0"
50
- HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
51
- PENDULUM_VERSIONED = "Pendulum-v0"
52
- PONG_VERSIONED = "Pong-v4"
80
+ _CARTPOLE_VERSIONED = "CartPole-v0"
81
+ _HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
82
+ _PENDULUM_VERSIONED = "Pendulum-v0"
83
+ _PONG_VERSIONED = "Pong-v4"
53
84
54
85
55
86
@implement_for ("gym" , "0.21.0" , None )
56
87
def _set_gym_environments (): # noqa: F811
57
- global CARTPOLE_VERSIONED , HALFCHEETAH_VERSIONED , PENDULUM_VERSIONED , PONG_VERSIONED
88
+ global _CARTPOLE_VERSIONED , _HALFCHEETAH_VERSIONED , _PENDULUM_VERSIONED , _PONG_VERSIONED
58
89
59
- CARTPOLE_VERSIONED = "CartPole-v1"
60
- HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
61
- PENDULUM_VERSIONED = "Pendulum-v1"
62
- PONG_VERSIONED = "ALE/Pong-v5"
90
+ _CARTPOLE_VERSIONED = "CartPole-v1"
91
+ _HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
92
+ _PENDULUM_VERSIONED = "Pendulum-v1"
93
+ _PONG_VERSIONED = "ALE/Pong-v5"
63
94
64
95
65
96
@implement_for ("gymnasium" )
66
97
def _set_gym_environments (): # noqa: F811
67
- global CARTPOLE_VERSIONED , HALFCHEETAH_VERSIONED , PENDULUM_VERSIONED , PONG_VERSIONED
98
+ global _CARTPOLE_VERSIONED , _HALFCHEETAH_VERSIONED , _PENDULUM_VERSIONED , _PONG_VERSIONED
68
99
69
- CARTPOLE_VERSIONED = "CartPole-v1"
70
- HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
71
- PENDULUM_VERSIONED = "Pendulum-v1"
72
- PONG_VERSIONED = "ALE/Pong-v5"
100
+ _CARTPOLE_VERSIONED = "CartPole-v1"
101
+ _HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
102
+ _PENDULUM_VERSIONED = "Pendulum-v1"
103
+ _PONG_VERSIONED = "ALE/Pong-v5"
73
104
74
105
75
106
if _has_gym :
@@ -171,7 +202,7 @@ def create_env_fn():
171
202
return GymEnv (env_name , frame_skip = frame_skip , device = device )
172
203
173
204
else :
174
- if env_name == PONG_VERSIONED :
205
+ if env_name == PONG_VERSIONED () :
175
206
176
207
def create_env_fn ():
177
208
base_env = GymEnv (env_name , frame_skip = frame_skip , device = device )
@@ -250,7 +281,7 @@ def _make_multithreaded_env(
250
281
251
282
torch .manual_seed (0 )
252
283
multithreaded_kwargs = (
253
- {"frame_skip" : frame_skip } if env_name == PONG_VERSIONED else {}
284
+ {"frame_skip" : frame_skip } if env_name == PONG_VERSIONED () else {}
254
285
)
255
286
env_multithread = MultiThreadedEnv (
256
287
N ,
@@ -274,7 +305,7 @@ def _make_multithreaded_env(
274
305
275
306
def get_transform_out (env_name , transformed_in , obs_key = None ):
276
307
277
- if env_name == PONG_VERSIONED :
308
+ if env_name == PONG_VERSIONED () :
278
309
if obs_key is None :
279
310
obs_key = "pixels"
280
311
0 commit comments