Skip to content

Commit 4f00025

Browse files
committed
[Feature] Async environments
ghstack-source-id: 0a70ce0 Pull Request resolved: #2864
1 parent 70f5c06 commit 4f00025

File tree

6 files changed

+1118
-2
lines changed

6 files changed

+1118
-2
lines changed

docs/source/reference/envs.rst

+81
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,87 @@ etc.), but one can not use an arbitrary TorchRL environment, as it is possible w
427427
ParallelEnv
428428
EnvCreator
429429

430+
Async environments
431+
------------------
432+
433+
Asynchronous environments allow for parallel execution of multiple environments, which can significantly speed up the
434+
data collection process in reinforcement learning.
435+
436+
The `AsyncEnvPool` class and its subclasses provide a flexible interface for managing these environments using different
437+
backends, such as threading and multiprocessing.
438+
439+
The `AsyncEnvPool` class serves as a base class for asynchronous environment pools, providing a common interface for
440+
managing multiple environments concurrently. It supports different backends for parallel execution, such as threading
441+
and multiprocessing, and provides methods for asynchronous stepping and resetting of environments.
442+
443+
Contrary to :class:`~torchrl.envs.ParallelEnv`, :class:`~torchrl.envs.AsyncEnvPool` and its subclasses permit the
444+
execution of a given set of sub-environments while another task performed, allowing for complex asynchronous jobs to be
445+
run at the same time. For instance, it is possible to execute some environments while the policy is running based on
446+
the output of others.
447+
448+
This family of classes is particularly interesting when dealing with environments that have a high (and/or variable)
449+
latency.
450+
451+
.. note:: This class and its subclasses should work when nested in with :class:`~torchrl.envs.TransformedEnv` and
452+
batched environments, but users won't currently be able to use the async features of the base environment when
453+
it's nested in these classes. One should prefer nested transformed envs within an `AsyncEnvPool` instead.
454+
If this is not possible, please raise an issue.
455+
456+
Classes
457+
~~~~~~~
458+
459+
- :class:`~torchrl.envs.AsyncEnvPool`: A base class for asynchronous environment pools. It determines the backend
460+
implementation to use based on the provided arguments and manages the lifecycle of the environments.
461+
- :class:`~torchrl.envs.ProcessorAsyncEnvPool`: An implementation of :class:`~torchrl.envs.AsyncEnvPool` using
462+
multiprocessing for parallel execution of environments. This class manages a pool of environments, each running in
463+
its own process, and provides methods for asynchronous stepping and resetting of environments using inter-process
464+
communication. It is automatically instantiated when `"multiprocessing"` is passed as a backend during the
465+
:class:`~torchrl.envs.AsyncEnvPool` instantiation.
466+
- :class:`~torchrl.envs.ThreadingAsyncEnvPool`: An implementation of :class:`~torchrl.envs.AsyncEnvPool` using
467+
threading for parallel execution of environments. This class manages a pool of environments, each running in its own
468+
thread, and provides methods for asynchronous stepping and resetting of environments using a thread pool executor.
469+
It is automatically instantiated when `"threading"` is passed as a backend during the
470+
:class:`~torchrl.envs.AsyncEnvPool` instantiation.
471+
472+
Example
473+
~~~~~~~
474+
475+
>>> from functools import partial
476+
>>> from torchrl.envs import AsyncEnvPool, GymEnv
477+
>>> import torch
478+
>>> # Choose backend
479+
>>> backend = "threading"
480+
>>> env = AsyncEnvPool(
481+
>>> [partial(GymEnv, "Pendulum-v1"), partial(GymEnv, "CartPole-v1")],
482+
>>> stack="lazy",
483+
>>> backend=backend
484+
>>> )
485+
>>> # Execute a synchronous reset
486+
>>> reset = env.reset()
487+
>>> print(reset)
488+
>>> # Execute a synchronous step
489+
>>> s = env.rand_step(reset)
490+
>>> print(s)
491+
>>> # Execute an asynchronous step in env 0
492+
>>> s0 = s[0]
493+
>>> s0["action"] = torch.randn(1).clamp(-1, 1)
494+
>>> s0["env_index"] = 0
495+
>>> env.async_step_send(s0)
496+
>>> # Receive data
497+
>>> s0_result = env.async_step_recv()
498+
>>> print('result', s0_result)
499+
>>> # Close env
500+
>>> env.close()
501+
502+
503+
.. autosummary::
504+
:toctree: generated/
505+
:template: rl_template.rst
506+
507+
AsyncEnvPool
508+
ProcessorAsyncEnvPool
509+
ThreadingAsyncEnvPool
510+
430511

431512
Custom native TorchRL environments
432513
----------------------------------

test/test_env.py

+63
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
4343
from torchrl.data.tensor_specs import Categorical, Composite, NonTensor, Unbounded
4444
from torchrl.envs import (
45+
AsyncEnvPool,
4546
CatFrames,
4647
CatTensors,
4748
ChessEnv,
@@ -4996,6 +4997,68 @@ def policy(td):
49964997
assert "done" in r
49974998

49984999

5000+
class TestAsyncEnvPool:
5001+
def make_env(self, *, makers, backend):
5002+
return AsyncEnvPool(makers, backend=backend)
5003+
5004+
@pytest.fixture(scope="module")
5005+
def make_envs(self):
5006+
yield [
5007+
partial(CountingEnv),
5008+
partial(CountingEnv),
5009+
partial(CountingEnv),
5010+
partial(CountingEnv),
5011+
]
5012+
5013+
@pytest.mark.parametrize("backend", ["multiprocessing", "threading"])
5014+
def test_specs(self, backend, make_envs):
5015+
env = self.make_env(makers=make_envs, backend=backend)
5016+
assert env.batch_size == (4,)
5017+
try:
5018+
r = env.reset()
5019+
assert r.shape == env.shape
5020+
s = env.rand_step(r)
5021+
assert s.shape == env.shape
5022+
env.check_env_specs(break_when_any_done="both")
5023+
finally:
5024+
env._maybe_shutdown()
5025+
5026+
@pytest.mark.parametrize("backend", ["multiprocessing", "threading"])
5027+
@pytest.mark.parametrize("min_get", [None, 1, 2])
5028+
@set_capture_non_tensor_stack(False)
5029+
def test_async_reset_and_step(self, backend, make_envs, min_get):
5030+
env = self.make_env(makers=make_envs, backend=backend)
5031+
try:
5032+
env.async_reset_send(
5033+
TensorDict(
5034+
{env._env_idx_key: NonTensorStack(*range(env.batch_size.numel()))},
5035+
batch_size=env.batch_size,
5036+
)
5037+
)
5038+
r = env.async_reset_recv(min_get=min_get)
5039+
if min_get is not None:
5040+
assert r.numel() >= min_get
5041+
assert env._env_idx_key in r
5042+
# take an action
5043+
r.set("action", torch.ones(r.shape + (1,)))
5044+
env.async_step_send(r)
5045+
s = env.async_step_recv(min_get=min_get)
5046+
if min_get is not None:
5047+
assert s.numel() >= min_get
5048+
assert env._env_idx_key in s
5049+
finally:
5050+
env._maybe_shutdown()
5051+
5052+
@pytest.mark.parametrize("backend", ["multiprocessing", "threading"])
5053+
def test_async_transformed(self, backend, make_envs):
5054+
base_env = self.make_env(makers=make_envs, backend=backend)
5055+
try:
5056+
env = TransformedEnv(base_env, StepCounter())
5057+
env.check_env_specs(break_when_any_done="both")
5058+
finally:
5059+
base_env._maybe_shutdown()
5060+
5061+
49995062
if __name__ == "__main__":
50005063
args, unknown = argparse.ArgumentParser().parse_known_args()
50015064
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_tensordictmodules.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def test_stateful(self, safe, spec_type, lazy):
187187

188188
# test bounds
189189
if not safe and spec_type == "bounded":
190-
assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any()
190+
assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any(), td.get("out")
191191
elif safe and spec_type == "bounded":
192192
assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all()
193193

torchrl/envs/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from .async_envs import AsyncEnvPool, ProcessorAsyncEnvPool, ThreadingAsyncEnvPool
67
from .batched_envs import ParallelEnv, SerialEnv
78
from .common import EnvBase, EnvMetaData, make_tensordict
89
from .custom import ChessEnv, LLMEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
@@ -135,6 +136,9 @@
135136
"VecNormV2",
136137
"AutoResetEnv",
137138
"AutoResetTransform",
139+
"AsyncEnvPool",
140+
"ProcessorAsyncEnvPool",
141+
"ThreadingAsyncEnvPool",
138142
"BatchSizeTransform",
139143
"BinarizeReward",
140144
"BraxEnv",

0 commit comments

Comments
 (0)