Skip to content

Commit 24abc75

Browse files
xmaplesvmoens
andauthored
[BugFix] extract the info dict from a list (#1131)
Co-authored-by: vmoens <[email protected]>
1 parent 47579aa commit 24abc75

File tree

3 files changed

+67
-18
lines changed

3 files changed

+67
-18
lines changed

.circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh

+1
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ do
9595
echo "Testing gym version: ${GYM_VERSION}"
9696
pip3 install 'gym[accept-rom-license]'==$GYM_VERSION
9797
pip3 install 'gym[atari]'==$GYM_VERSION
98+
pip3 install gym-super-mario-bros
9899
$DIR/run_test.sh
99100

100101
# delete the conda copy

test/test_libs.py

+65-17
Original file line numberDiff line numberDiff line change
@@ -100,24 +100,24 @@
100100

101101

102102
@pytest.mark.skipif(not _has_gym, reason="no gym library found")
103-
@pytest.mark.parametrize(
104-
"env_name",
105-
[
106-
PONG_VERSIONED,
107-
# PENDULUM_VERSIONED,
108-
HALFCHEETAH_VERSIONED,
109-
],
110-
)
111-
@pytest.mark.parametrize("frame_skip", [1, 3])
112-
@pytest.mark.parametrize(
113-
"from_pixels,pixels_only",
114-
[
115-
[False, False],
116-
[True, True],
117-
[True, False],
118-
],
119-
)
120103
class TestGym:
104+
@pytest.mark.parametrize(
105+
"env_name",
106+
[
107+
PONG_VERSIONED,
108+
# PENDULUM_VERSIONED,
109+
HALFCHEETAH_VERSIONED,
110+
],
111+
)
112+
@pytest.mark.parametrize("frame_skip", [1, 3])
113+
@pytest.mark.parametrize(
114+
"from_pixels,pixels_only",
115+
[
116+
[False, False],
117+
[True, True],
118+
[True, False],
119+
],
120+
)
121121
def test_gym(self, env_name, frame_skip, from_pixels, pixels_only):
122122
if env_name == PONG_VERSIONED and not from_pixels:
123123
# raise pytest.skip("already pixel")
@@ -176,6 +176,23 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only):
176176
assert final_seed0 == final_seed2
177177
assert_allclose_td(tdrollout[0], rollout2, rtol=RTOL, atol=ATOL)
178178

179+
@pytest.mark.parametrize(
180+
"env_name",
181+
[
182+
PONG_VERSIONED,
183+
# PENDULUM_VERSIONED,
184+
HALFCHEETAH_VERSIONED,
185+
],
186+
)
187+
@pytest.mark.parametrize("frame_skip", [1, 3])
188+
@pytest.mark.parametrize(
189+
"from_pixels,pixels_only",
190+
[
191+
[False, False],
192+
[True, True],
193+
[True, False],
194+
],
195+
)
179196
def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
180197
if env_name == PONG_VERSIONED and not from_pixels:
181198
# raise pytest.skip("already pixel")
@@ -195,6 +212,37 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
195212
)
196213
check_env_specs(env)
197214

215+
def test_info_reader(self):
216+
try:
217+
import gym_super_mario_bros as mario_gym
218+
except ImportError as err:
219+
try:
220+
import gym
221+
222+
# with 0.26 we must have installed gym_super_mario_bros
223+
# Since we capture the skips as errors, we raise a skip in this case
224+
# Otherwise, we just return
225+
if (
226+
version.parse("0.26.0")
227+
<= version.parse(gym.__version__)
228+
< version.parse("0.27.0")
229+
):
230+
raise pytest.skip(f"no super mario bros: error=\n{err}")
231+
except ImportError:
232+
pass
233+
return
234+
235+
env = mario_gym.make("SuperMarioBros-v0", apply_api_compatibility=True)
236+
env = GymWrapper(env)
237+
238+
def info_reader(info, tensordict):
239+
assert isinstance(info, dict) # failed before bugfix
240+
241+
env.info_dict_reader = info_reader
242+
env.reset()
243+
env.rand_step()
244+
env.rollout(3)
245+
198246

199247
@implement_for("gym", None, "0.26")
200248
def _make_gym_environment(env_name): # noqa: F811

torchrl/envs/gym_like.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _reset(
247247
obs, *other = self._output_transform(reset_data)
248248
info = None
249249
if len(other) == 1:
250-
info = other
250+
info = other[0]
251251

252252
tensordict_out = TensorDict(
253253
source=self.read_obs(obs),

0 commit comments

Comments
 (0)