
Commit da88aad

[Feature] torch.distributed collectors (#934)
1 parent: ee58306

28 files changed: +3939 -62 lines

Diff for: README.md

+7 -2
@@ -277,8 +277,10 @@ And it is `functorch` and `torch.compile` compatible!
 ```
 </details>

-- multiprocess [data collectors](torchrl/collectors/collectors.py)<sup>(2)</sup> that work synchronously or asynchronously.
-Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised
+- multiprocess and distributed [data collectors](torchrl/collectors/collectors.py)<sup>(2)</sup>
+  that work synchronously or asynchronously.
+  Through the use of TensorDict, TorchRL's training loops are made very similar
+  to regular training loops in supervised
 learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
 <details>
 <summary>Code</summary>
@@ -302,6 +304,9 @@ And it is `functorch` and `torch.compile` compatible!
 ```
 </details>

+  Check our [distributed collector examples](examples/distributed/collectors) to
+  learn more about ultra-fast data collection with TorchRL.
+
 - efficient<sup>(2)</sup> and generic<sup>(1)</sup> [replay buffers](torchrl/data/replay_buffers/replay_buffers.py) with modularized storage:
 <details>
 <summary>Code</summary>
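For context on the bullet above, here is a minimal sketch of the multiprocess collection loop provided by the single-node classes (this is not the collapsed `Code` snippet from the README); the class names come from `torchrl.collectors`, while the environment name and frame counts are illustrative placeholders:

```python
from torchrl.collectors import MultiSyncDataCollector
from torchrl.collectors.collectors import RandomPolicy
from torchrl.data import BoundedTensorSpec
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv

if __name__ == "__main__":
    # One environment constructor per worker process (placeholder environment).
    make_env = EnvCreator(lambda: GymEnv("Pendulum-v1"))
    policy = RandomPolicy(BoundedTensorSpec(-1.0, 1.0, shape=(1,)))

    collector = MultiSyncDataCollector(
        [make_env, make_env],      # two collection workers
        policy=policy,
        frames_per_batch=200,      # frames delivered at each iteration
        total_frames=2_000,        # stop once this many frames have been collected
    )
    for batch in collector:
        # `batch` is a TensorDict of stacked transitions; a training step would go here.
        print(batch.shape)
    collector.shutdown()
```

The distributed collectors added by this commit are iterated the same way (see the example script below), so an existing training loop can swap collectors with minimal changes.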

Diff for: docs/source/reference/collectors.rst

+49 -5
@@ -11,7 +11,7 @@ TorchRL's data collectors accept two main arguments: an environment (or a list of
 environment constructors) and a policy. They will iteratively execute an environment
 step and a policy query over a defined number of steps before delivering a stack of
 the data collected to the user. Environments will be reset whenever they reach a done
-state, and/or after a predifined number of steps.
+state, and/or after a predefined number of steps.

 Because data collection is a potentially compute heavy process, it is crucial to
 configure the execution hyperparameters appropriately.
@@ -21,7 +21,7 @@ class will execute the data collection on the training worker. The :obj:`MultiSy
 will split the workload across a number of workers and aggregate the results that
 will be delivered to the training worker. Finally, the :obj:`MultiaSyncDataCollector` will
 execute the data collection on several workers and deliver the first batch of results
-that it can gather. This execution will occur continuously and concomittantly with
+that it can gather. This execution will occur continuously and concomitantly with
 the training of the networks: this implies that the weights of the policy that
 is used for the data collection may slightly lag the configuration of the policy
 on the training worker. Therefore, although this class may be the fastest to collect
@@ -35,7 +35,7 @@ by setting `update_at_each_batch=True` in the constructor.
 The second parameter to consider (in the remote settings) is the device where the
 data will be collected and the device where the environment and policy operations
 will be executed. For instance, a policy executed on CPU may be slower than one
-executed on CUDA. When multiple inference workers run concomittantly, dispatching
+executed on CUDA. When multiple inference workers run concomitantly, dispatching
 the compute workload across the available devices may speed up the collection or
 avoid OOM errors. Finally, the choice of the batch size and passing device (ie the
 device where the data will be stored while waiting to be passed to the collection
@@ -58,8 +58,8 @@ Besides those compute parameters, users may choose to configure the following pa
 - reset_when_done: whether environments should be reset when reaching a done state.


-Data collectors
----------------
+Single node data collectors
+---------------------------
 .. currentmodule:: torchrl.collectors.collectors

 .. autosummary::
@@ -73,6 +73,50 @@ Data collectors
     aSyncDataCollector


+Distributed data collectors
+---------------------------
+TorchRL provides a set of distributed data collectors. These tools support
+multiple backends (``'gloo'``, ``'nccl'``, ``'mpi'`` with the :class:`~.DistributedDataCollector`
+or PyTorch RPC with :class:`~.RPCDataCollector`) and launchers (``'ray'``,
+``submitit`` or ``torch.multiprocessing``).
+They can be efficiently used in synchronous or asynchronous mode, on a single
+node or across multiple nodes.
+
+*Resources*: Find examples for these collectors in the
+`dedicated folder <https://github.com/pytorch/rl/examples/distributed/collectors>`_.
+
+.. note::
+  *Choosing the sub-collector*: All distributed collectors support the various single machine collectors.
+  One may wonder whether to use a :class:`MultiSyncDataCollector` or a :class:`torchrl.envs.ParallelEnv`
+  instead. In general, multiprocessed collectors have a lower IO footprint than
+  parallel environments, which need to communicate at each step. Yet, the model specs
+  play a role in the opposite direction, since using parallel environments will
+  result in a faster execution of the policy (and/or transforms), as these
+  operations will be vectorized.
+
+.. note::
+  *Choosing the device of a collector (or a parallel environment)*: Sharing data
+  among processes is achieved via shared-memory buffers with parallel environments
+  and multiprocessed environments executed on CPU. Depending on the capabilities
+  of the machine being used, this may be prohibitively slow compared to sharing
+  data on GPU, which is natively supported by CUDA drivers.
+  In practice, this means that using the ``device="cpu"`` keyword argument when
+  building a parallel environment or collector can result in a slower collection
+  than using ``device="cuda"`` when available.
+
+
+.. currentmodule:: torchrl.collectors.distributed
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    DistributedDataCollector
+    RPCDataCollector
+    DistributedSyncDataCollector
+    submitit_delayed_launcher
+
 Helper functions
 ----------------
Diff for: docs/source/reference/trainers.rst

+4-4
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,10 @@ Loggers
234234
:template: rl_template_fun.rst
235235

236236
Logger
237-
CSVLogger
238-
MLFlowLogger
239-
TensorboardLogger
240-
WandbLogger
237+
csv.CSVLogger
238+
mlflow.MLFlowLogger
239+
tensorboard.TensorboardLogger
240+
wandb.WandbLogger
241241
get_logger
242242
generate_exp_name
243243

Diff for: examples/distributed/collectors/README.md

+12
@@ -0,0 +1,12 @@
# Distributed data collection examples

If your algorithm is bound by the data collection speed, you may consider using
a distributed data collector to make your training faster.
TorchRL offers a set of distributed data collectors that you can use
to increase the collection speed tenfold or more.

These examples are divided into a single-machine and a multi-node series.

Refer to the [documentation](https://pytorch.org/rl/reference/collectors.html)
for more insight on what you can expect to do
and how these tools should be used.
Diff for: (new example script; file path not captured)

+151
@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Multi-node distributed data collection with submitit in contexts where jobs can't launch other jobs.

The default configuration will ask for 8 nodes with 1 GPU each and 32 procs / node.

It should reach a collection speed of roughly 15-25K fps, or better depending
on the cluster specs.

The logic of the script is the following: we create a `main()` function that
executes our code (in this case just a data collection, but in practice a training
loop should be present).

Since this `main()` function cannot launch sub-jobs by design, we launch the script
from the jump host and pass the slurm specs to submitit.

*Note*:

  Although we don't go into much detail in this script, the specs of the training
  node and the specs of the inference nodes can differ (look at the DEFAULT_SLURM_CONF
  and DEFAULT_SLURM_CONF_MAIN dictionaries below).

"""
import time
from argparse import ArgumentParser

import tqdm

from torchrl.collectors.distributed import submitit_delayed_launcher

from torchrl.collectors.distributed.default_configs import (
    DEFAULT_SLURM_CONF,
    DEFAULT_SLURM_CONF_MAIN,
)
from torchrl.collectors.distributed.generic import DistributedDataCollector
from torchrl.envs import EnvCreator

parser = ArgumentParser()
parser.add_argument("--partition", "-p", help="slurm partition to use")
parser.add_argument("--num_jobs", type=int, default=8, help="Number of jobs")
parser.add_argument("--tcp_port", type=int, default=1234, help="TCP port")
parser.add_argument(
    "--num_workers", type=int, default=8, help="Number of workers per node"
)
parser.add_argument(
    "--gpus_per_node",
    "--gpus-per-node",
    "-G",
    type=int,
    default=1,
    help="Number of GPUs per node. If greater than 0, the backend used will be NCCL.",
)
parser.add_argument(
    "--cpus_per_task",
    "--cpus-per-task",
    "-c",
    type=int,
    default=32,
    help="Number of CPUs per node.",
)
parser.add_argument(
    "--sync", action="store_true", help="Use --sync to collect data synchronously."
)
parser.add_argument(
    "--frames_per_batch",
    "--frames-per-batch",
    default=4000,
    type=int,
    help="Number of frames in each batch of data. Must be "
    "divisible by the product of nodes and workers if sync, by the number of "
    "workers otherwise.",
)
parser.add_argument(
    "--total_frames",
    "--total-frames",
    default=10_000_000,
    type=int,
    help="Total number of frames collected by the collector.",
)
parser.add_argument(
    "--time",
    "-t",
    default="1:00:00",
    help="Timeout for the nodes",
)

args = parser.parse_args()

slurm_gpus_per_node = args.gpus_per_node
slurm_time = args.time

DEFAULT_SLURM_CONF["slurm_gpus_per_node"] = slurm_gpus_per_node
DEFAULT_SLURM_CONF["slurm_time"] = slurm_time
DEFAULT_SLURM_CONF["slurm_cpus_per_task"] = args.cpus_per_task
DEFAULT_SLURM_CONF["slurm_partition"] = args.partition
DEFAULT_SLURM_CONF_MAIN["slurm_partition"] = args.partition
DEFAULT_SLURM_CONF_MAIN["slurm_time"] = slurm_time

num_jobs = args.num_jobs
tcp_port = args.tcp_port
num_workers = args.num_workers
sync = args.sync
total_frames = args.total_frames
frames_per_batch = args.frames_per_batch


@submitit_delayed_launcher(
    num_jobs=num_jobs,
    backend="nccl" if slurm_gpus_per_node else "gloo",
    tcpport=tcp_port,
)
def main():
    from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
    from torchrl.collectors.collectors import RandomPolicy
    from torchrl.data import BoundedTensorSpec
    from torchrl.envs.libs.gym import GymEnv

    collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
    device_str = "device" if num_workers == 1 else "devices"
    collector = DistributedDataCollector(
        [EnvCreator(lambda: GymEnv("ALE/Pong-v5"))] * num_jobs,
        policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))),
        launcher="submitit_delayed",
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        tcp_port=tcp_port,
        collector_class=collector_class,
        num_workers_per_collector=args.num_workers,
        collector_kwargs={device_str: "cuda:0" if slurm_gpus_per_node else "cpu"},
        storing_device="cuda:0" if slurm_gpus_per_node else "cpu",
        backend="nccl" if slurm_gpus_per_node else "gloo",
        sync=sync,
    )
    counter = 0
    pbar = tqdm.tqdm(total=collector.total_frames)
    for i, data in enumerate(collector):
        pbar.update(data.numel())
        pbar.set_description(f"data shape: {data.shape}, data device: {data.device}")
        if i >= 10:
            counter += data.numel()
        if i == 10:
            t0 = time.time()
    t1 = time.time()
    print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
    collector.shutdown()
    exit()


if __name__ == "__main__":
    main()
