@@ -43,7 +43,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
43
43
ScenarioUtils .check_kwargs_consumed (kwargs )
44
44
45
45
self .min_distance_between_entities = self .agent_radius * 2 + 0.05
46
- self .world_semidim = 1
46
+ self .world_semidim_x = 1 if self .x_semidim is None else self .x_semidim
47
+ self .world_semidim_y = 1 if self .y_semidim is None else self .y_semidim
47
48
self .min_collision_distance = 0.005
48
49
49
50
assert 1 <= self .agents_with_same_goal <= self .n_agents
@@ -135,8 +136,8 @@ def reset_world_at(self, env_index: int = None):
135
136
self .world ,
136
137
env_index ,
137
138
self .min_distance_between_entities ,
138
- (- self .world_semidim , self .world_semidim ),
139
- (- self .world_semidim , self .world_semidim ),
139
+ (- self .world_semidim_x , self .world_semidim_x ),
140
+ (- self .world_semidim_y , self .world_semidim_y ),
140
141
)
141
142
142
143
occupied_positions = torch .stack (
@@ -152,8 +153,8 @@ def reset_world_at(self, env_index: int = None):
152
153
env_index = env_index ,
153
154
world = self .world ,
154
155
min_dist_between_entities = self .min_distance_between_entities ,
155
- x_bounds = (- self .world_semidim , self .world_semidim ),
156
- y_bounds = (- self .world_semidim , self .world_semidim ),
156
+ x_bounds = (- self .world_semidim_x , self .world_semidim_x ),
157
+ y_bounds = (- self .world_semidim_y , self .world_semidim_y ),
157
158
)
158
159
goal_poses .append (position .squeeze (1 ))
159
160
occupied_positions = torch .cat ([occupied_positions , position ], dim = 1 )
0 commit comments