100
100
101
101
102
102
@pytest .mark .skipif (not _has_gym , reason = "no gym library found" )
103
- @pytest .mark .parametrize (
104
- "env_name" ,
105
- [
106
- PONG_VERSIONED ,
107
- # PENDULUM_VERSIONED,
108
- HALFCHEETAH_VERSIONED ,
109
- ],
110
- )
111
- @pytest .mark .parametrize ("frame_skip" , [1 , 3 ])
112
- @pytest .mark .parametrize (
113
- "from_pixels,pixels_only" ,
114
- [
115
- [False , False ],
116
- [True , True ],
117
- [True , False ],
118
- ],
119
- )
120
103
class TestGym :
104
+ @pytest .mark .parametrize (
105
+ "env_name" ,
106
+ [
107
+ PONG_VERSIONED ,
108
+ # PENDULUM_VERSIONED,
109
+ HALFCHEETAH_VERSIONED ,
110
+ ],
111
+ )
112
+ @pytest .mark .parametrize ("frame_skip" , [1 , 3 ])
113
+ @pytest .mark .parametrize (
114
+ "from_pixels,pixels_only" ,
115
+ [
116
+ [False , False ],
117
+ [True , True ],
118
+ [True , False ],
119
+ ],
120
+ )
121
121
def test_gym (self , env_name , frame_skip , from_pixels , pixels_only ):
122
122
if env_name == PONG_VERSIONED and not from_pixels :
123
123
# raise pytest.skip("already pixel")
@@ -176,6 +176,23 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only):
176
176
assert final_seed0 == final_seed2
177
177
assert_allclose_td (tdrollout [0 ], rollout2 , rtol = RTOL , atol = ATOL )
178
178
179
+ @pytest .mark .parametrize (
180
+ "env_name" ,
181
+ [
182
+ PONG_VERSIONED ,
183
+ # PENDULUM_VERSIONED,
184
+ HALFCHEETAH_VERSIONED ,
185
+ ],
186
+ )
187
+ @pytest .mark .parametrize ("frame_skip" , [1 , 3 ])
188
+ @pytest .mark .parametrize (
189
+ "from_pixels,pixels_only" ,
190
+ [
191
+ [False , False ],
192
+ [True , True ],
193
+ [True , False ],
194
+ ],
195
+ )
179
196
def test_gym_fake_td (self , env_name , frame_skip , from_pixels , pixels_only ):
180
197
if env_name == PONG_VERSIONED and not from_pixels :
181
198
# raise pytest.skip("already pixel")
@@ -195,6 +212,37 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
195
212
)
196
213
check_env_specs (env )
197
214
215
+ def test_info_reader (self ):
216
+ try :
217
+ import gym_super_mario_bros as mario_gym
218
+ except ImportError as err :
219
+ try :
220
+ import gym
221
+
222
+ # with 0.26 we must have installed gym_super_mario_bros
223
+ # Since we capture the skips as errors, we raise a skip in this case
224
+ # Otherwise, we just return
225
+ if (
226
+ version .parse ("0.26.0" )
227
+ <= version .parse (gym .__version__ )
228
+ < version .parse ("0.27.0" )
229
+ ):
230
+ raise pytest .skip (f"no super mario bros: error=\n { err } " )
231
+ except ImportError :
232
+ pass
233
+ return
234
+
235
+ env = mario_gym .make ("SuperMarioBros-v0" , apply_api_compatibility = True )
236
+ env = GymWrapper (env )
237
+
238
+ def info_reader (info , tensordict ):
239
+ assert isinstance (info , dict ) # failed before bugfix
240
+
241
+ env .info_dict_reader = info_reader
242
+ env .reset ()
243
+ env .rand_step ()
244
+ env .rollout (3 )
245
+
198
246
199
247
@implement_for ("gym" , None , "0.26" )
200
248
def _make_gym_environment (env_name ): # noqa: F811
0 commit comments