@@ -29,18 +29,33 @@ def _check_obs_type(obss, obs_shapes, dict_space, return_numpy):
29
29
assert isinstance (
30
30
obss , dict
31
31
), f"Expected dictionary of observations, got { type (obss )} "
32
- obss = list (obss .values ())
32
+ for k , obs in obss .items ():
33
+ obs_shape = obs_shapes [k ]
34
+ assert (
35
+ obs .shape == obs_shape
36
+ ), f"Expected shape { obs_shape } , got { obs .shape } "
37
+ if return_numpy :
38
+ assert isinstance (
39
+ obs , np .ndarray
40
+ ), f"Expected numpy array, got { type (obs )} "
41
+ else :
42
+ assert isinstance (
43
+ obs , Tensor
44
+ ), f"Expected torch tensor, got { type (obs )} "
33
45
else :
34
46
assert isinstance (
35
47
obss , list
36
48
), f"Expected list of observations, got { type (obss )} "
37
- for o , shape in zip (obss , obs_shapes ):
38
- if return_numpy :
39
- assert isinstance (o , np .ndarray ), f"Expected numpy array, got { type (o )} "
40
- assert o .shape == shape , f"Expected shape { shape } , got { o .shape } "
41
- else :
42
- assert isinstance (o , Tensor ), f"Expected torch tensor, got { type (o )} "
43
- assert o .shape == shape , f"Expected shape { shape } , got { o .shape } "
49
+ for obs , shape in zip (obss , obs_shapes ):
50
+ assert obs .shape == shape , f"Expected shape { shape } , got { obs .shape } "
51
+ if return_numpy :
52
+ assert isinstance (
53
+ obs , np .ndarray
54
+ ), f"Expected numpy array, got { type (obs )} "
55
+ else :
56
+ assert isinstance (
57
+ obs , Tensor
58
+ ), f"Expected torch tensor, got { type (obs )} "
44
59
45
60
46
61
@pytest .mark .parametrize ("scenario" , TEST_SCENARIOS )
@@ -74,9 +89,9 @@ def test_gym_wrapper(
74
89
assert isinstance (
75
90
env .action_space , gym .spaces .Dict
76
91
), "Expected Dict action space"
77
- obs_shapes = [
78
- obs_space .shape for obs_space in env .observation_space .spaces .values ()
79
- ]
92
+ obs_shapes = {
93
+ k : obs_space .shape for k , obs_space in env .observation_space .spaces .items ()
94
+ }
80
95
else :
81
96
assert isinstance (
82
97
env .observation_space , gym .spaces .Tuple
0 commit comments