Skip to content

Commit ff58363

Browse files
[BugFix] Discovery obs (#137)
* amend * amend * amend
1 parent 73bb583 commit ff58363

File tree

2 files changed

+46
-41
lines changed

2 files changed

+46
-41
lines changed

tests/test_scenarios/test_discovery.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ def setup_env(
2424
self.env.seed(0)
2525

2626
@pytest.mark.parametrize("n_agents", [1, 4])
27-
def test_heuristic(self, n_agents, n_steps=50, n_envs=4):
28-
self.setup_env(n_agents=n_agents, n_envs=n_envs)
27+
@pytest.mark.parametrize("agent_lidar", [True, False])
28+
def test_heuristic(self, n_agents, agent_lidar, n_steps=50, n_envs=4):
29+
self.setup_env(n_agents=n_agents, n_envs=n_envs, use_agent_lidar=agent_lidar)
2930
policy = discovery.HeuristicPolicy(True)
3031

3132
obs = self.env.reset()

vmas/scenarios/discovery.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
2828
self._min_dist_between_entities = kwargs.pop("min_dist_between_entities", 0.2)
2929
self._lidar_range = kwargs.pop("lidar_range", 0.35)
3030
self._covering_range = kwargs.pop("covering_range", 0.25)
31+
self.use_agent_lidar = kwargs.pop("use_agent_lidar", False)
3132
self._agents_per_target = kwargs.pop("agents_per_target", 2)
3233
self.targets_respawn = kwargs.pop("targets_respawn", True)
3334
self.shared_reward = kwargs.pop("shared_reward", False)
@@ -57,9 +58,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
5758
)
5859

5960
# Add agents
60-
# entity_filter_agents: Callable[[Entity], bool] = lambda e: e.name.startswith(
61-
# "agent"
62-
# )
61+
entity_filter_agents: Callable[[Entity], bool] = lambda e: e.name.startswith(
62+
"agent"
63+
)
6364
entity_filter_targets: Callable[[Entity], bool] = lambda e: e.name.startswith(
6465
"target"
6566
)
@@ -69,24 +70,32 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
6970
name=f"agent_{i}",
7071
collide=True,
7172
shape=Sphere(radius=self.agent_radius),
72-
sensors=[
73-
# Lidar(
74-
# world,
75-
# angle_start=0.05,
76-
# angle_end=2 * torch.pi + 0.05,
77-
# n_rays=12,
78-
# max_range=self._lidar_range,
79-
# entity_filter=entity_filter_agents,
80-
# render_color=Color.BLUE,
81-
# ),
82-
Lidar(
83-
world,
84-
n_rays=15,
85-
max_range=self._lidar_range,
86-
entity_filter=entity_filter_targets,
87-
render_color=Color.GREEN,
88-
),
89-
],
73+
sensors=(
74+
[
75+
Lidar(
76+
world,
77+
n_rays=15,
78+
max_range=self._lidar_range,
79+
entity_filter=entity_filter_targets,
80+
render_color=Color.GREEN,
81+
)
82+
]
83+
+ (
84+
[
85+
Lidar(
86+
world,
87+
angle_start=0.05,
88+
angle_end=2 * torch.pi + 0.05,
89+
n_rays=12,
90+
max_range=self._lidar_range,
91+
entity_filter=entity_filter_agents,
92+
render_color=Color.BLUE,
93+
)
94+
]
95+
if self.use_agent_lidar
96+
else []
97+
)
98+
),
9099
)
91100
agent.collision_rew = torch.zeros(batch_dim, device=device)
92101
agent.covering_reward = agent.collision_rew.clone()
@@ -230,15 +239,9 @@ def agent_reward(self, agent):
230239

231240
def observation(self, agent: Agent):
232241
lidar_1_measures = agent.sensors[0].measure()
233-
# lidar_2_measures = agent.sensors[1].measure()
234242
return torch.cat(
235-
[
236-
agent.state.pos,
237-
agent.state.vel,
238-
agent.state.pos,
239-
lidar_1_measures,
240-
# lidar_2_measures,
241-
],
243+
[agent.state.pos, agent.state.vel, lidar_1_measures]
244+
+ ([agent.sensors[1].measure()] if self.use_agent_lidar else []),
242245
dim=-1,
243246
)
244247

@@ -317,24 +320,25 @@ def compute_action(self, observation: torch.Tensor, u_range: float) -> torch.Ten
317320
closest_point_on_circ_normal *= 0.1
318321
des_pos = closest_point_on_circ + closest_point_on_circ_normal
319322

320-
# Move away from other agents within visibility range
321-
lidar_agents = observation[:, 4:16]
322-
agent_visible = torch.any(lidar_agents < 0.15, dim=1)
323-
_, agent_dir_index = torch.min(lidar_agents, dim=1)
324-
agent_dir = agent_dir_index / lidar_agents.shape[1] * 2 * torch.pi
325-
agent_vec = torch.stack([torch.cos(agent_dir), torch.sin(agent_dir)], dim=1)
326-
des_pos_agent = current_pos - agent_vec * 0.1
327-
des_pos[agent_visible] = des_pos_agent[agent_visible]
328-
329323
# Move towards targets within visibility range
330-
lidar_targets = observation[:, 16:28]
324+
lidar_targets = observation[:, 4:19]
331325
target_visible = torch.any(lidar_targets < 0.3, dim=1)
332326
_, target_dir_index = torch.min(lidar_targets, dim=1)
333327
target_dir = target_dir_index / lidar_targets.shape[1] * 2 * torch.pi
334328
target_vec = torch.stack([torch.cos(target_dir), torch.sin(target_dir)], dim=1)
335329
des_pos_target = current_pos + target_vec * 0.1
336330
des_pos[target_visible] = des_pos_target[target_visible]
337331

332+
if observation.shape[-1] > 19:
333+
# Move away from other agents within visibility range
334+
lidar_agents = observation[:, 19:31]
335+
agent_visible = torch.any(lidar_agents < 0.15, dim=1)
336+
_, agent_dir_index = torch.min(lidar_agents, dim=1)
337+
agent_dir = agent_dir_index / lidar_agents.shape[1] * 2 * torch.pi
338+
agent_vec = torch.stack([torch.cos(agent_dir), torch.sin(agent_dir)], dim=1)
339+
des_pos_agent = current_pos - agent_vec * 0.1
340+
des_pos[agent_visible] = des_pos_agent[agent_visible]
341+
338342
action = torch.clamp(
339343
(des_pos - current_pos) * 10,
340344
min=-u_range,

0 commit comments

Comments (0)