Skip to content

Commit 6d44f7d

Browse files
committed
small nits
1 parent 821b415 commit 6d44f7d

File tree

4 files changed

+20
-7
lines changed

4 files changed

+20
-7
lines changed

tests/test_wrappers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Copyright (c) 2024.
2+
# ProrokLab (https://www.proroklab.org/)
3+
# All rights reserved.

tests/test_wrappers/test_gym_wrapper.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# Copyright (c) 2024.
2+
# ProrokLab (https://www.proroklab.org/)
3+
# All rights reserved.
4+
15
import gym
26
import numpy as np
37
import pytest
@@ -10,16 +14,13 @@
1014
TEST_SCENARIOS = [
1115
"balance",
1216
"discovery",
13-
"dispersion",
14-
"football",
1517
"give_way",
1618
"joint_passage",
1719
"navigation",
1820
"passage",
19-
"reverse_transport",
20-
"road_traffic",
2121
"transport",
2222
"waterfall",
23+
"simple_world_comm",
2324
]
2425

2526

@@ -93,7 +94,10 @@ def test_gym_wrapper(
9394
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)
9495

9596
for _ in range(max_steps):
96-
actions = env.unwrapped.get_random_actions()
97+
actions = [
98+
env.unwrapped.get_random_action(agent).numpy()
99+
for agent in env.unwrapped.agents
100+
]
97101
obss, rews, done, info = env.step(actions)
98102
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)
99103

tests/test_wrappers/test_gymnasium_vec_wrapper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ def test_gymnasium_wrapper(
6868
), f"Expected info to be a dictionary but got {type(info)}"
6969

7070
for _ in range(max_steps):
71-
actions = env.unwrapped.get_random_actions()
71+
actions = [
72+
env.unwrapped.get_random_action(agent).numpy()
73+
for agent in env.unwrapped.agents
74+
]
7275
obss, rews, terminated, truncated, info = env.step(actions)
7376
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)
7477

tests/test_wrappers/test_gymnasium_wrapper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ def test_gymnasium_wrapper(
6565
), f"Expected info to be a dictionary but got {type(info)}"
6666

6767
for _ in range(max_steps):
68-
actions = env.unwrapped.get_random_actions()
68+
actions = [
69+
env.unwrapped.get_random_action(agent).numpy()
70+
for agent in env.unwrapped.agents
71+
]
6972
obss, rews, terminated, truncated, info = env.step(actions)
7073
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)
7174

0 commit comments

Comments
 (0)