@@ -28,6 +28,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
28
28
self ._min_dist_between_entities = kwargs .pop ("min_dist_between_entities" , 0.2 )
29
29
self ._lidar_range = kwargs .pop ("lidar_range" , 0.35 )
30
30
self ._covering_range = kwargs .pop ("covering_range" , 0.25 )
31
+ self .use_agent_lidar = kwargs .pop ("use_agent_lidar" , False )
31
32
self ._agents_per_target = kwargs .pop ("agents_per_target" , 2 )
32
33
self .targets_respawn = kwargs .pop ("targets_respawn" , True )
33
34
self .shared_reward = kwargs .pop ("shared_reward" , False )
@@ -57,9 +58,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
57
58
)
58
59
59
60
# Add agents
60
- # entity_filter_agents: Callable[[Entity], bool] = lambda e: e.name.startswith(
61
- # "agent"
62
- # )
61
+ entity_filter_agents : Callable [[Entity ], bool ] = lambda e : e .name .startswith (
62
+ "agent"
63
+ )
63
64
entity_filter_targets : Callable [[Entity ], bool ] = lambda e : e .name .startswith (
64
65
"target"
65
66
)
@@ -69,24 +70,32 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
69
70
name = f"agent_{ i } " ,
70
71
collide = True ,
71
72
shape = Sphere (radius = self .agent_radius ),
72
- sensors = [
73
- # Lidar(
74
- # world,
75
- # angle_start=0.05,
76
- # angle_end=2 * torch.pi + 0.05,
77
- # n_rays=12,
78
- # max_range=self._lidar_range,
79
- # entity_filter=entity_filter_agents,
80
- # render_color=Color.BLUE,
81
- # ),
82
- Lidar (
83
- world ,
84
- n_rays = 15 ,
85
- max_range = self ._lidar_range ,
86
- entity_filter = entity_filter_targets ,
87
- render_color = Color .GREEN ,
88
- ),
89
- ],
73
+ sensors = (
74
+ [
75
+ Lidar (
76
+ world ,
77
+ n_rays = 15 ,
78
+ max_range = self ._lidar_range ,
79
+ entity_filter = entity_filter_targets ,
80
+ render_color = Color .GREEN ,
81
+ )
82
+ ]
83
+ + (
84
+ [
85
+ Lidar (
86
+ world ,
87
+ angle_start = 0.05 ,
88
+ angle_end = 2 * torch .pi + 0.05 ,
89
+ n_rays = 12 ,
90
+ max_range = self ._lidar_range ,
91
+ entity_filter = entity_filter_agents ,
92
+ render_color = Color .BLUE ,
93
+ )
94
+ ]
95
+ if self .use_agent_lidar
96
+ else []
97
+ )
98
+ ),
90
99
)
91
100
agent .collision_rew = torch .zeros (batch_dim , device = device )
92
101
agent .covering_reward = agent .collision_rew .clone ()
@@ -230,15 +239,9 @@ def agent_reward(self, agent):
230
239
231
240
def observation (self , agent : Agent ):
232
241
lidar_1_measures = agent .sensors [0 ].measure ()
233
- # lidar_2_measures = agent.sensors[1].measure()
234
242
return torch .cat (
235
- [
236
- agent .state .pos ,
237
- agent .state .vel ,
238
- agent .state .pos ,
239
- lidar_1_measures ,
240
- # lidar_2_measures,
241
- ],
243
+ [agent .state .pos , agent .state .vel , lidar_1_measures ]
244
+ + ([agent .sensors [1 ].measure ()] if self .use_agent_lidar else []),
242
245
dim = - 1 ,
243
246
)
244
247
@@ -317,24 +320,25 @@ def compute_action(self, observation: torch.Tensor, u_range: float) -> torch.Ten
317
320
closest_point_on_circ_normal *= 0.1
318
321
des_pos = closest_point_on_circ + closest_point_on_circ_normal
319
322
320
- # Move away from other agents within visibility range
321
- lidar_agents = observation [:, 4 :16 ]
322
- agent_visible = torch .any (lidar_agents < 0.15 , dim = 1 )
323
- _ , agent_dir_index = torch .min (lidar_agents , dim = 1 )
324
- agent_dir = agent_dir_index / lidar_agents .shape [1 ] * 2 * torch .pi
325
- agent_vec = torch .stack ([torch .cos (agent_dir ), torch .sin (agent_dir )], dim = 1 )
326
- des_pos_agent = current_pos - agent_vec * 0.1
327
- des_pos [agent_visible ] = des_pos_agent [agent_visible ]
328
-
329
323
# Move towards targets within visibility range
330
- lidar_targets = observation [:, 16 : 28 ]
324
+ lidar_targets = observation [:, 4 : 19 ]
331
325
target_visible = torch .any (lidar_targets < 0.3 , dim = 1 )
332
326
_ , target_dir_index = torch .min (lidar_targets , dim = 1 )
333
327
target_dir = target_dir_index / lidar_targets .shape [1 ] * 2 * torch .pi
334
328
target_vec = torch .stack ([torch .cos (target_dir ), torch .sin (target_dir )], dim = 1 )
335
329
des_pos_target = current_pos + target_vec * 0.1
336
330
des_pos [target_visible ] = des_pos_target [target_visible ]
337
331
332
+ if observation .shape [- 1 ] > 19 :
333
+ # Move away from other agents within visibility range
334
+ lidar_agents = observation [:, 19 :31 ]
335
+ agent_visible = torch .any (lidar_agents < 0.15 , dim = 1 )
336
+ _ , agent_dir_index = torch .min (lidar_agents , dim = 1 )
337
+ agent_dir = agent_dir_index / lidar_agents .shape [1 ] * 2 * torch .pi
338
+ agent_vec = torch .stack ([torch .cos (agent_dir ), torch .sin (agent_dir )], dim = 1 )
339
+ des_pos_agent = current_pos - agent_vec * 0.1
340
+ des_pos [agent_visible ] = des_pos_agent [agent_visible ]
341
+
338
342
action = torch .clamp (
339
343
(des_pos - current_pos ) * 10 ,
340
344
min = - u_range ,
0 commit comments