Skip to content

Commit 2a54638

Browse files
committed
gym wrapper tests for dict spaces check obs shapes matching obs key
1 parent 6d44f7d commit 2a54638

File tree

3 files changed

+32
-17
lines changed

3 files changed

+32
-17
lines changed

tests/test_wrappers/test_gym_wrapper.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,33 @@ def _check_obs_type(obss, obs_shapes, dict_space, return_numpy):
2929
assert isinstance(
3030
obss, dict
3131
), f"Expected dictionary of observations, got {type(obss)}"
32-
obss = list(obss.values())
32+
for k, obs in obss.items():
33+
obs_shape = obs_shapes[k]
34+
assert (
35+
obs.shape == obs_shape
36+
), f"Expected shape {obs_shape}, got {obs.shape}"
37+
if return_numpy:
38+
assert isinstance(
39+
obs, np.ndarray
40+
), f"Expected numpy array, got {type(obs)}"
41+
else:
42+
assert isinstance(
43+
obs, Tensor
44+
), f"Expected torch tensor, got {type(obs)}"
3345
else:
3446
assert isinstance(
3547
obss, list
3648
), f"Expected list of observations, got {type(obss)}"
37-
for o, shape in zip(obss, obs_shapes):
38-
if return_numpy:
39-
assert isinstance(o, np.ndarray), f"Expected numpy array, got {type(o)}"
40-
assert o.shape == shape, f"Expected shape {shape}, got {o.shape}"
41-
else:
42-
assert isinstance(o, Tensor), f"Expected torch tensor, got {type(o)}"
43-
assert o.shape == shape, f"Expected shape {shape}, got {o.shape}"
49+
for obs, shape in zip(obss, obs_shapes):
50+
assert obs.shape == shape, f"Expected shape {shape}, got {obs.shape}"
51+
if return_numpy:
52+
assert isinstance(
53+
obs, np.ndarray
54+
), f"Expected numpy array, got {type(obs)}"
55+
else:
56+
assert isinstance(
57+
obs, Tensor
58+
), f"Expected torch tensor, got {type(obs)}"
4459

4560

4661
@pytest.mark.parametrize("scenario", TEST_SCENARIOS)
@@ -74,9 +89,9 @@ def test_gym_wrapper(
7489
assert isinstance(
7590
env.action_space, gym.spaces.Dict
7691
), "Expected Dict action space"
77-
obs_shapes = [
78-
obs_space.shape for obs_space in env.observation_space.spaces.values()
79-
]
92+
obs_shapes = {
93+
k: obs_space.shape for k, obs_space in env.observation_space.spaces.items()
94+
}
8095
else:
8196
assert isinstance(
8297
env.observation_space, gym.spaces.Tuple

tests/test_wrappers/test_gymnasium_vec_wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def test_gymnasium_wrapper(
4949
assert isinstance(
5050
env.action_space, gym.spaces.Dict
5151
), "Expected Dict action space"
52-
obs_shapes = [
53-
obs_space.shape for obs_space in env.observation_space.spaces.values()
54-
]
52+
obs_shapes = {
53+
k: obs_space.shape for k, obs_space in env.observation_space.spaces.items()
54+
}
5555
else:
5656
assert isinstance(
5757
env.observation_space, gym.spaces.Tuple

tests/test_wrappers/test_gymnasium_wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ def test_gymnasium_wrapper(
4242
assert isinstance(
4343
env.action_space, gym.spaces.Dict
4444
), "Expected Dict action space"
45-
obs_shapes = [
46-
obs_space.shape for obs_space in env.observation_space.spaces.values()
47-
]
45+
obs_shapes = {
46+
k: obs_space.shape for k, obs_space in env.observation_space.spaces.items()
47+
}
4848
else:
4949
assert isinstance(
5050
env.observation_space, gym.spaces.Tuple

0 commit comments

Comments
 (0)