@@ -24,8 +24,15 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
24
24
self .n_agents = kwargs .pop ("n_agents" , 4 )
25
25
self .collisions = kwargs .pop ("collisions" , True )
26
26
27
- self .x_semidim = kwargs .pop ("x_semidim" , None )
28
- self .y_semidim = kwargs .pop ("y_semidim" , None )
27
+ self .world_spawning_x = kwargs .pop (
28
+ "world_spawning_x" , 1
29
+ ) # X-coordinate limit for entities spawning
30
+ self .world_spawning_y = kwargs .pop (
31
+ "world_spawning_y" , 1
32
+ ) # Y-coordinate limit for entities spawning
33
+ self .enforce_bounds = kwargs .pop (
34
+ "enforce_bounds" , False
35
+ ) # If False, the world is unlimited; else, constrained by world_spawning_x and world_spawning_y.
29
36
30
37
self .agents_with_same_goal = kwargs .pop ("agents_with_same_goal" , 1 )
31
38
self .split_goals = kwargs .pop ("split_goals" , False )
@@ -43,9 +50,15 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
43
50
ScenarioUtils .check_kwargs_consumed (kwargs )
44
51
45
52
self .min_distance_between_entities = self .agent_radius * 2 + 0.05
46
- self .world_semidim = 1
47
53
self .min_collision_distance = 0.005
48
54
55
+ if self .enforce_bounds :
56
+ self .x_semidim = self .world_spawning_x
57
+ self .y_semidim = self .world_spawning_y
58
+ else :
59
+ self .x_semidim = None
60
+ self .y_semidim = None
61
+
49
62
assert 1 <= self .agents_with_same_goal <= self .n_agents
50
63
if self .agents_with_same_goal > 1 :
51
64
assert (
@@ -135,8 +148,8 @@ def reset_world_at(self, env_index: int = None):
135
148
self .world ,
136
149
env_index ,
137
150
self .min_distance_between_entities ,
138
- (- self .world_semidim , self .world_semidim ),
139
- (- self .world_semidim , self .world_semidim ),
151
+ (- self .world_spawning_x , self .world_spawning_x ),
152
+ (- self .world_spawning_y , self .world_spawning_y ),
140
153
)
141
154
142
155
occupied_positions = torch .stack (
@@ -152,8 +165,8 @@ def reset_world_at(self, env_index: int = None):
152
165
env_index = env_index ,
153
166
world = self .world ,
154
167
min_dist_between_entities = self .min_distance_between_entities ,
155
- x_bounds = (- self .world_semidim , self .world_semidim ),
156
- y_bounds = (- self .world_semidim , self .world_semidim ),
168
+ x_bounds = (- self .world_spawning_x , self .world_spawning_x ),
169
+ y_bounds = (- self .world_spawning_y , self .world_spawning_y ),
157
170
)
158
171
goal_poses .append (position .squeeze (1 ))
159
172
occupied_positions = torch .cat ([occupied_positions , position ], dim = 1 )
0 commit comments