Skip to content

Commit 0e23acd

Browse files
committed
local seed
1 parent 06f29d4 commit 0e23acd

File tree

1 file changed

+43
-3
lines changed

1 file changed

+43
-3
lines changed

vmas/simulator/environment/environment.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
# Copyright (c) 2022-2024.
1+
# Copyright (c) 2022-2025.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
4+
import contextlib
5+
import functools
46
import math
57
import random
68
from ctypes import byref
@@ -26,13 +28,51 @@
2628
)
2729

2830

29-
# environment for all agents in the multiagent world
30-
# currently code assumes that no agents will be created/destroyed at runtime!
31+
@contextlib.contextmanager
32+
def local_seed(vmas_random_state):
33+
torch_state = torch.random.get_rng_state()
34+
np_state = np.random.get_state()
35+
py_state = random.getstate()
36+
37+
torch.random.set_rng_state(vmas_random_state[0])
38+
np.random.set_state(vmas_random_state[1])
39+
random.setstate(vmas_random_state[2])
40+
yield
41+
vmas_random_state[0] = torch.random.get_rng_state()
42+
vmas_random_state[1] = np.random.get_state()
43+
vmas_random_state[2] = random.getstate()
44+
45+
torch.random.set_rng_state(torch_state)
46+
np.random.set_state(np_state)
47+
random.setstate(py_state)
48+
49+
50+
def apply_local_seed(cls):
51+
"""Applies the local seed to all the functions."""
52+
for attr_name, attr_value in cls.__dict__.items():
53+
if callable(attr_value):
54+
wrapped = attr_value # Keep reference to original method
55+
56+
@functools.wraps(wrapped)
57+
def wrapper(self, *args, _wrapped=wrapped, **kwargs):
58+
with local_seed(cls.vmas_random_state):
59+
return _wrapped(self, *args, **kwargs)
60+
61+
setattr(cls, attr_name, wrapper)
62+
return cls
63+
64+
65+
@apply_local_seed
3166
class Environment(TorchVectorizedObject):
3267
metadata = {
3368
"render.modes": ["human", "rgb_array"],
3469
"runtime.vectorized": True,
3570
}
71+
vmas_random_state = [
72+
torch.random.get_rng_state(),
73+
np.random.get_state(),
74+
random.getstate(),
75+
]
3676

3777
def __init__(
3878
self,

0 commit comments

Comments
 (0)