 # Copyright (c) 2022-2024.
 # ProrokLab (https://www.proroklab.org/)
 # All rights reserved.
-import os
+import math
+import random
 import sys
 from pathlib import Path
 
 def scenario_names():
     scenarios = []
     scenarios_folder = Path(__file__).parent.parent / "vmas" / "scenarios"
-    for _, _, filenames in os.walk(scenarios_folder):
-        scenarios += filenames
-    scenarios = [
-        scenario.split(".")[0]
-        for scenario in scenarios
-        if scenario.endswith(".py") and not scenario.startswith("__")
-    ]
+    for path in scenarios_folder.glob("**/*.py"):
+        if path.is_file() and not path.name.startswith("__"):
+            scenarios.append(path.stem)
     return scenarios
 
 
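+# Random MultiDiscrete nvecs used to parametrize the tests below: `count` lists,
+# each between l_min and l_max entries long, with every entry in [n_min, n_max].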
+def random_nvecs(count, l_min=2, l_max=6, n_min=2, n_max=6, seed=0):
+    random.seed(seed)
+    return [
+        [random.randint(n_min, n_max) for _ in range(random.randint(l_min, l_max))]
+        for _ in range(count)
+    ]
+
+
 def test_all_scenarios_included():
     from vmas import debug_scenarios, mpe_scenarios, scenarios
 
@@ -70,6 +75,163 @@ def test_multi_discrete_actions(scenario, num_envs=10, n_steps=10):
         env.step(env.get_random_actions())
 
 
+@pytest.mark.parametrize("scenario", scenario_names())
+@pytest.mark.parametrize("multidiscrete_actions", [True, False])
+def test_discrete_action_nvec(scenario, multidiscrete_actions, num_envs=10, n_steps=5):
+    env = make_env(
+        scenario=scenario,
+        num_envs=num_envs,
+        seed=0,
+        multidiscrete_actions=multidiscrete_actions,
+        continuous_actions=False,
+    )
+    if (
+        type(env.scenario).process_action
+        is not vmas.simulator.scenario.BaseScenario.process_action
+    ):
+        pytest.skip("Scenario uses a custom process_action method.")
+
+    random.seed(0)
+    for agent in env.world.agents:
+        agent.discrete_action_nvec = [
+            random.randint(2, 6) for _ in range(agent.action_size)
+        ]
+    env.action_space = env.get_action_space()
+
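+    # Decompose a flat Discrete action index into its MultiDiscrete components
+    # via place values, e.g. nvec = [3, 4]: flat action 7 -> [7 // 4, 7 % 4] = [1, 3].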
+    def to_multidiscrete(action, nvec):
+        action_multi = []
+        for i in range(len(nvec)):
+            n = math.prod(nvec[i + 1 :])
+            action_multi.append(action // n)
+            action = action % n
+        return torch.stack(action_multi, dim=-1)
+
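+    # Non-silent agents also get a discrete communication action, so the
+    # effective nvec gains one extra dimension of size world.dim_c.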
+    def full_nvec(agent, world):
+        return list(agent.discrete_action_nvec) + (
+            [world.dim_c] if not agent.silent and world.dim_c != 0 else []
+        )
+
+    for _ in range(n_steps):
+        actions = env.get_random_actions()
+
+        # Check that generated actions are in the action space
+        for a_batch, s in zip(actions, env.action_space.spaces):
+            for a in a_batch:
+                assert a.numpy() in s
+
+        env.step(actions)
+
+        if not multidiscrete_actions:
+            actions = [
+                to_multidiscrete(a.squeeze(-1), full_nvec(agent, env.world))
+                for a, agent in zip(actions, env.world.policy_agents)
+            ]
+
+        # Check that discrete action to continuous control mapping is correct.
+        for i_a, agent in enumerate(env.world.policy_agents):
+            for i, n in enumerate(agent.discrete_action_nvec):
+                a = actions[i_a][:, i]
+                u = agent.action.u[:, i]
+                U = agent.action.u_range_tensor[i]
+                k = agent.action.u_multiplier_tensor[i]
+                for aj, uj in zip(a, u):
+                    assert aj in range(
+                        n
+                    ), f"discrete action {aj} not in [0,{n-1}] (n={n}, U={U}, k={k})"
+                    if n % 2 != 0:
+                        assert (
+                            aj != 0 or uj == 0
+                        ), f"discrete action {aj} maps to control {uj} (n={n}, U={U}, k={k})"
+                        assert (aj < 1 or aj > n // 2) or torch.isclose(
+                            uj / k, (2 * U * (aj - 1)) / (n - 1) - U
+                        ), f"discrete action {aj} maps to control {uj} (n={n}, U={U}, k={k})"
+                        assert (aj <= n // 2) or torch.isclose(
+                            uj / k, 2 * U * (aj / (n - 1)) - U
+                        ), f"discrete action {aj} maps to control {uj} (n={n}, U={U}, k={k})"
+                    else:
+                        assert torch.isclose(
+                            uj / k, 2 * U * (aj / (n - 1)) - U
+                        ), f"discrete action {aj} maps to control {uj} (n={n}, U={U}, k={k})"
+
+
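+# Illustration of the MultiDiscrete -> flat Discrete flattening checked below
+# (example values only): nvec = [3, 4, 2] has place values [4 * 2, 2, 1] = [8, 2, 1],
+# so the multi-discrete action [2, 1, 1] flattens to 2 * 8 + 1 * 2 + 1 * 1 = 19.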
+@pytest.mark.parametrize(
+    "nvecs", list(zip(random_nvecs(10, seed=0), random_nvecs(10, seed=42)))
+)
+def test_discrete_action_nvec_discrete_to_multi(
+    nvecs, scenario="transport", num_envs=10, n_steps=5
+):
+    kwargs = {
+        "scenario": scenario,
+        "num_envs": num_envs,
+        "seed": 0,
+        "continuous_actions": False,
+    }
+    env = make_env(**kwargs, multidiscrete_actions=False)
+    env_multi = make_env(**kwargs, multidiscrete_actions=True)
+    if (
+        type(env.scenario).process_action
+        is not vmas.simulator.scenario.BaseScenario.process_action
+    ):
+        pytest.skip("Scenario uses a custom process_action method.")
+
+    def set_nvec(agent, nvec):
+        agent.action_size = len(nvec)
+        agent.discrete_action_nvec = nvec
+        agent.action.action_size = agent.action_size
+
+    random.seed(0)
+    for agent, agent_multi, nvec in zip(
+        env.world.policy_agents, env_multi.world.policy_agents, nvecs
+    ):
+        set_nvec(agent, nvec)
+        set_nvec(agent_multi, nvec)
+    env.action_space = env.get_action_space()
+    env_multi.action_space = env_multi.get_action_space()
+
+    def full_nvec(agent, world):
+        return list(agent.discrete_action_nvec) + (
+            [world.dim_c] if not agent.silent and world.dim_c != 0 else []
+        )
+
+    def full_action_size(agent, world):
+        return len(full_nvec(agent, world))
+
+    for _ in range(n_steps):
+        actions_multi = env_multi.get_random_actions()
+        prodss = [
+            [
+                math.prod(full_nvec(agent, env.world)[i + 1 :])
+                for i in range(full_action_size(agent, env.world))
+            ]
+            for agent in env.world.policy_agents
+        ]
+        # Compute the expected mapping from multi-discrete to discrete
+        actions = [
+            (a_multi * torch.tensor(prods)).sum(dim=1)
+            for a_multi, prods in zip(actions_multi, prodss)
+        ]
+
+        env_multi.step(actions_multi)
+        env.step(actions)
+
+        # Check that both discrete and multi-discrete actions result in the
+        # same control value
+        for agent, agent_multi, action, action_multi in zip(
+            env.world.policy_agents,
+            env_multi.world.policy_agents,
+            actions,
+            actions_multi,
+        ):
+            U = agent.action.u_range_tensor
+            k = agent.action.u_multiplier_tensor
+            for u, u_multi, a, a_multi in zip(
+                agent.action.u, agent_multi.action.u, action, action_multi
+            ):
+                assert torch.allclose(
+                    u, u_multi
+                ), f"{u} != {u_multi} (nvec={agent.discrete_action_nvec}, a={a}, a_multi={a_multi}, U={U}, k={k})"
+
+
 @pytest.mark.parametrize("scenario", scenario_names())
 def test_non_dict_spaces_actions(scenario, num_envs=10, n_steps=10):
     env = make_env(