Skip to content

Commit 4a5b53e

Browse files
[Feature] Local seed (#154)
* local seed * local seed * local seed * local seed
1 parent 06f29d4 commit 4a5b53e

File tree

2 files changed

+178
-23
lines changed

2 files changed

+178
-23
lines changed

tests/test_vmas.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,21 @@ def test_vmas_differentiable(scenario, n_steps=10, n_envs=10):
302302

303303
loss = obs[-1].mean() + rews[-1].mean()
304304
grad = torch.autograd.grad(loss, first_action)
305+
306+
307+
def test_seeding():
    """Check that environment seeding and global torch seeding do not interfere.

    First part: re-seeding the environment with the same value yields the same
    first observation entry after a reset, also when the global torch seed is
    changed in between.
    Second part: seeding and resetting the environment does not change the
    value drawn next from the global torch generator.
    """
    env = make_env(scenario="balance", num_envs=2, seed=0)

    # Reference value: one entry of the first agent's observation after reset.
    env.seed(0)
    reference_obs = env.reset()[0][0, 0]

    # Same environment seed -> same value.
    env.seed(0)
    assert reference_obs == env.reset()[0][0, 0]

    # Changing the global torch seed must not alter the environment's draw.
    env.seed(0)
    torch.manual_seed(1)
    assert reference_obs == env.reset()[0][0, 0]

    # Reference draw from the global torch generator.
    torch.manual_seed(0)
    reference_draw = torch.randn(1)

    # Environment seeding/reset must not consume the global torch stream.
    torch.manual_seed(0)
    env.seed(1)
    env.reset()
    assert reference_draw == torch.randn(1)

vmas/simulator/environment/environment.py

Lines changed: 160 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
# Copyright (c) 2022-2024.
1+
# Copyright (c) 2022-2025.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
4+
import contextlib
45
import math
56
import random
67
from ctypes import byref
@@ -26,14 +27,41 @@
2627
)
2728

2829

29-
# environment for all agents in the multiagent world
30-
# currently code assumes that no agents will be created/destroyed at runtime!
30+
@contextlib.contextmanager
def local_seed(vmas_random_state):
    """Run the wrapped code on the random states stored in ``vmas_random_state``.

    On entry, the current global states of torch (as returned by
    ``torch.random.get_rng_state``), NumPy's legacy global generator and
    Python's ``random`` module are saved and replaced with the three states
    held in ``vmas_random_state``. On exit, the states reached by the wrapped
    code are written back into ``vmas_random_state`` and the caller's states
    are reinstated.

    Both exit steps run in ``finally`` clauses, so the caller's global random
    states are restored even when the wrapped code raises. Only the generators
    listed above are handled; no other generator is touched by this function.

    Because this is built with ``contextlib.contextmanager``, it can be used
    both in a ``with`` statement and as a function decorator.

    Args:
        vmas_random_state: Mutable sequence of three items
            ``[torch_state, numpy_state, python_state]``. It is updated in
            place with the states reached at the end of the wrapped code.

    Yields:
        None
    """
    # Remember the caller's global states so they can be reinstated on exit.
    torch_state = torch.random.get_rng_state()
    np_state = np.random.get_state()
    py_state = random.getstate()

    try:
        # Swap in the stored states.
        torch.random.set_rng_state(vmas_random_state[0])
        np.random.set_state(vmas_random_state[1])
        random.setstate(vmas_random_state[2])
        try:
            yield
        finally:
            # Persist how far the stored streams advanced. This only runs once
            # the swap above fully succeeded, so a failed swap never
            # overwrites the stored states with a partial mix.
            vmas_random_state[0] = torch.random.get_rng_state()
            vmas_random_state[1] = np.random.get_state()
            vmas_random_state[2] = random.getstate()
    finally:
        # Always hand the caller's states back, including on exceptions.
        torch.random.set_rng_state(torch_state)
        np.random.set_state(np_state)
        random.setstate(py_state)
47+
48+
3149
class Environment(TorchVectorizedObject):
50+
"""
51+
The VMAS environment
52+
"""
53+
3254
metadata = {
3355
"render.modes": ["human", "rgb_array"],
3456
"runtime.vectorized": True,
3557
}
58+
vmas_random_state = [
59+
torch.random.get_rng_state(),
60+
np.random.get_state(),
61+
random.getstate(),
62+
]
3663

64+
@local_seed(vmas_random_state)
3765
def __init__(
3866
self,
3967
scenario: BaseScenario,
@@ -68,7 +96,7 @@ def __init__(
6896
self.grad_enabled = grad_enabled
6997
self.terminated_truncated = terminated_truncated
7098

71-
observations = self.reset(seed=seed)
99+
observations = self._reset(seed=seed)
72100

73101
# configure spaces
74102
self.multidiscrete_actions = multidiscrete_actions
@@ -81,6 +109,7 @@ def __init__(
81109
self.visible_display = None
82110
self.text_lines = None
83111

112+
@local_seed(vmas_random_state)
84113
def reset(
85114
self,
86115
seed: Optional[int] = None,
@@ -92,21 +121,112 @@ def reset(
92121
Resets the environment in a vectorized way
93122
Returns observations for all envs and agents
94123
"""
124+
return self._reset(
125+
seed=seed,
126+
return_observations=return_observations,
127+
return_info=return_info,
128+
return_dones=return_dones,
129+
)
130+
131+
@local_seed(vmas_random_state)
def reset_at(
    self,
    index: int,
    return_observations: bool = True,
    return_info: bool = False,
    return_dones: bool = False,
):
    """
    Resets the environment at index
    Returns observations for all agents in that environment

    Public entry point: the work is done by ``_reset_at``, which is called
    here inside the ``local_seed(vmas_random_state)`` context.
    """
    what_to_return = {
        "return_observations": return_observations,
        "return_info": return_info,
        "return_dones": return_dones,
    }
    return self._reset_at(index=index, **what_to_return)
149+
150+
@local_seed(vmas_random_state)
def get_from_scenario(
    self,
    get_observations: bool,
    get_rewards: bool,
    get_infos: bool,
    get_dones: bool,
    dict_agent_names: Optional[bool] = None,
):
    """
    Get the environment data from the scenario

    Public entry point: forwards every argument by keyword to
    ``_get_from_scenario``, which is called here inside the
    ``local_seed(vmas_random_state)`` context.

    Args:
        get_observations (bool): whether to return the observations
        get_rewards (bool): whether to return the rewards
        get_infos (bool): whether to return the infos
        get_dones (bool): whether to return the dones
        dict_agent_names (bool, optional): whether to return the information in a dictionary with agent names as keys
            or in a list

    Returns:
        The agents' data

    """
    requested = {
        "get_observations": get_observations,
        "get_rewards": get_rewards,
        "get_infos": get_infos,
        "get_dones": get_dones,
        "dict_agent_names": dict_agent_names,
    }
    return self._get_from_scenario(**requested)
181+
182+
@local_seed(vmas_random_state)
def seed(self, seed=None):
    """
    Sets the seed for the environment

    Public entry point: calls ``_seed`` inside the
    ``local_seed(vmas_random_state)`` context and returns its result.

    Args:
        seed (int, optional): Seed for the environment. Defaults to None.

    """
    applied = self._seed(seed=seed)
    return applied
191+
192+
@local_seed(vmas_random_state)
def done(self):
    """
    Get the done flags for the scenario.

    Public entry point: calls ``_done`` inside the
    ``local_seed(vmas_random_state)`` context and returns its result.

    Returns:
        Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)

    """
    done_flags = self._done()
    return done_flags
202+
203+
def _reset(
    self,
    seed: Optional[int] = None,
    return_observations: bool = True,
    return_info: bool = False,
    return_dones: bool = False,
):
    """
    Resets the environment in a vectorized way
    Returns observations for all envs and agents

    Internal implementation: it calls the underscore-prefixed ``_seed`` and
    ``_get_from_scenario`` directly rather than their public counterparts.

    Args:
        seed (int, optional): if not None, passed to ``_seed`` before the reset
        return_observations (bool): forwarded as ``get_observations``
        return_info (bool): forwarded as ``get_infos``
        return_dones (bool): forwarded as ``get_dones``

    Returns:
        The list produced by ``_get_from_scenario`` (rewards are never
        requested), or its only element when that list has exactly one entry
    """
    if seed is not None:
        self._seed(seed)

    # reset world (env_index=None is passed for the full, vectorized reset)
    self.scenario.env_reset_world_at(env_index=None)
    self.steps = torch.zeros(self.num_envs, device=self.device)

    outputs = self._get_from_scenario(
        get_observations=return_observations,
        get_infos=return_info,
        get_rewards=False,
        get_dones=return_dones,
    )

    # Unwrap a single requested item; otherwise hand back the list as is.
    if outputs and len(outputs) == 1:
        return outputs[0]
    return outputs
108228

109-
def reset_at(
229+
def _reset_at(
110230
self,
111231
index: int,
112232
return_observations: bool = True,
@@ -121,7 +241,7 @@ def reset_at(
121241
self.scenario.env_reset_world_at(index)
122242
self.steps[index] = 0
123243

124-
result = self.get_from_scenario(
244+
result = self._get_from_scenario(
125245
get_observations=return_observations,
126246
get_infos=return_info,
127247
get_rewards=False,
@@ -130,7 +250,7 @@ def reset_at(
130250

131251
return result[0] if result and len(result) == 1 else result
132252

133-
def get_from_scenario(
253+
def _get_from_scenario(
134254
self,
135255
get_observations: bool,
136256
get_rewards: bool,
@@ -178,35 +298,41 @@ def get_from_scenario(
178298

179299
if self.terminated_truncated:
180300
if get_dones:
181-
terminated, truncated = self.done()
301+
terminated, truncated = self._done()
182302
result = [obs, rewards, terminated, truncated, infos]
183303
else:
184304
if get_dones:
185-
dones = self.done()
305+
dones = self._done()
186306
result = [obs, rewards, dones, infos]
187307

188308
return [data for data in result if data is not None]
189309

190-
def seed(self, seed=None):
310+
def _seed(self, seed=None):
    """
    Sets the seed for the environment

    Seeds, in this order, torch (``torch.manual_seed``), NumPy's global
    generator (``np.random.seed``) and Python's ``random`` module.

    Args:
        seed (int, optional): Seed for the environment. Defaults to None,
            in which case 0 is used.

    Returns:
        A one-element list containing the seed that was applied

    """
    seed = 0 if seed is None else seed
    for apply_seed in (torch.manual_seed, np.random.seed, random.seed):
        apply_seed(seed)
    return [seed]
197323

324+
@local_seed(vmas_random_state)
198325
def step(self, actions: Union[List, Dict]):
199326
"""Performs a vectorized step on all sub environments using `actions`.
327+
200328
Args:
201-
actions: Is a list on len 'self.n_agents' of which each element is a torch.Tensor of shape
202-
'(self.num_envs, action_size_of_agent)'.
329+
actions: Is a list on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs, action_size_of_agent)'.
330+
203331
Returns:
204-
obs: List on len 'self.n_agents' of which each element is a torch.Tensor
205-
of shape '(self.num_envs, obs_size_of_agent)'
332+
obs: List on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs, obs_size_of_agent)'
206333
rewards: List on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs)'
207334
dones: Tensor of len 'self.num_envs' of which each element is a bool
208-
infos : List on len 'self.n_agents' of which each element is a dictionary for which each key is a metric
209-
and the value is a tensor of shape '(self.num_envs, metric_size_per_agent)'
335+
infos: List on len 'self.n_agents' of which each element is a dictionary for which each key is a metric and the value is a tensor of shape '(self.num_envs, metric_size_per_agent)'
210336
211337
Examples:
212338
>>> import vmas
@@ -222,6 +348,7 @@ def step(self, actions: Union[List, Dict]):
222348
>>> obs = env.reset()
223349
>>> for _ in range(10):
224350
... obs, rews, dones, info = env.step(env.get_random_actions())
351+
225352
"""
226353
if isinstance(actions, Dict):
227354
actions_dict = actions
@@ -269,14 +396,21 @@ def step(self, actions: Union[List, Dict]):
269396

270397
self.steps += 1
271398

272-
return self.get_from_scenario(
399+
return self._get_from_scenario(
273400
get_observations=True,
274401
get_infos=True,
275402
get_rewards=True,
276403
get_dones=True,
277404
)
278405

279-
def done(self):
406+
def _done(self):
407+
"""
408+
Get the done flags for the scenario.
409+
410+
Returns:
411+
Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)
412+
413+
"""
280414
terminated = self.scenario.done().clone()
281415

282416
if self.max_steps is not None:
@@ -387,6 +521,7 @@ def get_agent_observation_space(self, agent: Agent, obs: AGENT_OBS_TYPE):
387521
f"Invalid type of observation {obs} for agent {agent.name}"
388522
)
389523

524+
@local_seed(vmas_random_state)
390525
def get_random_action(self, agent: Agent) -> torch.Tensor:
391526
"""Returns a random action for the given agent.
392527
@@ -447,7 +582,7 @@ def get_random_action(self, agent: Agent) -> torch.Tensor:
447582
return action
448583

449584
def get_random_actions(self) -> Sequence[torch.Tensor]:
450-
"""Returns random actions for all agents that you can feed to :class:`step`
585+
"""Returns random actions for all agents that you can feed to :meth:`step`
451586
452587
Returns:
453588
Sequence[torch.tensor]: the random actions for the agents
@@ -612,6 +747,7 @@ def _set_action(self, action, agent):
612747
)
613748
agent.action.c += noise
614749

750+
@local_seed(vmas_random_state)
615751
def render(
616752
self,
617753
mode="human",
@@ -635,15 +771,15 @@ def render(
635771
Render function for environment using pyglet
636772
637773
On servers use mode="rgb_array" and set
774+
638775
```
639776
export DISPLAY=':99.0'
640777
Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 &
641778
```
642779
643780
:param mode: One of human or rgb_array
644781
:param env_index: Index of the environment to render
645-
:param agent_index_focus: If specified the camera will stay on the agent with this index.
646-
If None, the camera will stay in the center and zoom out to contain all agents
782+
:param agent_index_focus: If specified the camera will stay on the agent with this index. If None, the camera will stay in the center and zoom out to contain all agents
647783
:param visualize_when_rgb: Also run human visualization when mode=="rgb_array"
648784
:param plot_position_function: A function to plot under the rendering.
649785
The function takes a numpy array with shape (n_points, 2), which represents a set of x,y values to evaluate f over and plot it
@@ -657,6 +793,7 @@ def render(
657793
:param plot_position_function_cmap_range: The range of the cmap in case plot_position_function outputs a single value
658794
:param plot_position_function_cmap_alpha: The alpha of the cmap in case plot_position_function outputs a single value
659795
:return: Rgb array or None, depending on the mode
796+
660797
"""
661798
self._check_batch_index(env_index)
662799
assert (

0 commit comments

Comments
 (0)