|
1 |
| -# Copyright (c) 2022-2024. |
| 1 | +# Copyright (c) 2022-2025. |
2 | 2 | # ProrokLab (https://www.proroklab.org/)
|
3 | 3 | # All rights reserved.
|
| 4 | +import contextlib |
| 5 | +import functools |
4 | 6 | import math
|
5 | 7 | import random
|
6 | 8 | from ctypes import byref
|
|
26 | 28 | )
|
27 | 29 |
|
28 | 30 |
|
29 |
| -# environment for all agents in the multiagent world |
30 |
| -# currently code assumes that no agents will be created/destroyed at runtime! |
| 31 | +@contextlib.contextmanager |
| 32 | +def local_seed(vmas_random_state): |
| 33 | + torch_state = torch.random.get_rng_state() |
| 34 | + np_state = np.random.get_state() |
| 35 | + py_state = random.getstate() |
| 36 | + |
| 37 | + torch.random.set_rng_state(vmas_random_state[0]) |
| 38 | + np.random.set_state(vmas_random_state[1]) |
| 39 | + random.setstate(vmas_random_state[2]) |
| 40 | + yield |
| 41 | + vmas_random_state[0] = torch.random.get_rng_state() |
| 42 | + vmas_random_state[1] = np.random.get_state() |
| 43 | + vmas_random_state[2] = random.getstate() |
| 44 | + |
| 45 | + torch.random.set_rng_state(torch_state) |
| 46 | + np.random.set_state(np_state) |
| 47 | + random.setstate(py_state) |
| 48 | + |
| 49 | + |
| 50 | +def apply_local_seed(cls): |
| 51 | + """Applies the local seed to all the functions.""" |
| 52 | + for attr_name, attr_value in cls.__dict__.items(): |
| 53 | + if callable(attr_value): |
| 54 | + wrapped = attr_value # Keep reference to original method |
| 55 | + |
| 56 | + @functools.wraps(wrapped) |
| 57 | + def wrapper(self, *args, _wrapped=wrapped, **kwargs): |
| 58 | + with local_seed(cls.vmas_random_state): |
| 59 | + return _wrapped(self, *args, **kwargs) |
| 60 | + |
| 61 | + setattr(cls, attr_name, wrapper) |
| 62 | + return cls |
| 63 | + |
| 64 | + |
| 65 | +@apply_local_seed |
31 | 66 | class Environment(TorchVectorizedObject):
|
32 | 67 | metadata = {
|
33 | 68 | "render.modes": ["human", "rgb_array"],
|
34 | 69 | "runtime.vectorized": True,
|
35 | 70 | }
|
| 71 | + vmas_random_state = [ |
| 72 | + torch.random.get_rng_state(), |
| 73 | + np.random.get_state(), |
| 74 | + random.getstate(), |
| 75 | + ] |
36 | 76 |
|
37 | 77 | def __init__(
|
38 | 78 | self,
|
|
0 commit comments