# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Multi-node distributed data collection with submitit in contexts where jobs can't launch other jobs.

The default configuration will ask for 8 nodes with 1 GPU each and 32 procs / node.

It should reach a collection speed of roughly 15-25K fps, or better depending
on the cluster specs.

The logic of the script is the following: we create a `main()` function that
executes our code (in this case just a data collection, but in practice a training
loop should be present).

Since this `main()` function cannot launch sub-jobs by design, we launch the script
from the jump host and pass the slurm specs to submitit.

*Note*:

  Although we don't go into much detail about this in this script, the specs of the
  training node and the specs of the inference nodes can differ (look at the
  DEFAULT_SLURM_CONF and DEFAULT_SLURM_CONF_MAIN dictionaries below).

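As an illustration, launching this script from the jump host could look like the
following (the script name and partition are placeholders, not prescribed by the
library):

    python delayed_submitit_collection.py --partition <my_partition> --num_jobs 8
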
| 25 | +""" |
import time
from argparse import ArgumentParser

import tqdm
from torchrl.collectors.distributed import submitit_delayed_launcher

from torchrl.collectors.distributed.default_configs import (
    DEFAULT_SLURM_CONF,
    DEFAULT_SLURM_CONF_MAIN,
)
from torchrl.collectors.distributed.generic import DistributedDataCollector
from torchrl.envs import EnvCreator

parser = ArgumentParser()
parser.add_argument("--partition", "-p", help="slurm partition to use")
parser.add_argument("--num_jobs", type=int, default=8, help="Number of jobs")
parser.add_argument("--tcp_port", type=int, default=1234, help="TCP port")
parser.add_argument(
    "--num_workers", type=int, default=8, help="Number of workers per node"
)
parser.add_argument(
    "--gpus_per_node",
    "--gpus-per-node",
    "-G",
    type=int,
    default=1,
    help="Number of GPUs per node. If greater than 0, the backend used will be NCCL.",
)
parser.add_argument(
    "--cpus_per_task",
    "--cpus-per-task",
    "-c",
    type=int,
    default=32,
    help="Number of CPUs per node.",
)
parser.add_argument(
    "--sync", action="store_true", help="Use --sync to collect data synchronously."
)
parser.add_argument(
    "--frames_per_batch",
    "--frames-per-batch",
    default=4000,
    type=int,
    help="Number of frames in each batch of data. Must be "
    "divisible by the product of nodes and workers if sync, by the number of "
    "workers otherwise.",
)
parser.add_argument(
    "--total_frames",
    "--total-frames",
    default=10_000_000,
    type=int,
    help="Total number of frames collected by the collector.",
)
parser.add_argument(
    "--time",
    "-t",
    default="1:00:00",
    help="Slurm time limit for the nodes",
)

args = parser.parse_args()

slurm_gpus_per_node = args.gpus_per_node
slurm_time = args.time

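# DEFAULT_SLURM_CONF holds the slurm parameters submitit uses for the collection
# (inference) nodes, while DEFAULT_SLURM_CONF_MAIN is used for the main (training)
# node; this is why the two can be given different specs. Here we only override a
# few fields with the command-line arguments.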
DEFAULT_SLURM_CONF["slurm_gpus_per_node"] = slurm_gpus_per_node
DEFAULT_SLURM_CONF["slurm_time"] = slurm_time
DEFAULT_SLURM_CONF["slurm_cpus_per_task"] = args.cpus_per_task
DEFAULT_SLURM_CONF["slurm_partition"] = args.partition
DEFAULT_SLURM_CONF_MAIN["slurm_partition"] = args.partition
DEFAULT_SLURM_CONF_MAIN["slurm_time"] = slurm_time

num_jobs = args.num_jobs
tcp_port = args.tcp_port
num_workers = args.num_workers
sync = args.sync
total_frames = args.total_frames
frames_per_batch = args.frames_per_batch


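# Because `main()` cannot spawn sub-jobs by design, `submitit_delayed_launcher`
# pre-launches the `num_jobs` collection jobs through submitit when the script is
# run from the jump host. The distributed backend is NCCL when GPUs are requested
# and gloo otherwise.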
@submitit_delayed_launcher(
    num_jobs=num_jobs,
    backend="nccl" if slurm_gpus_per_node else "gloo",
    tcpport=tcp_port,
)
def main():
    from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
    from torchrl.collectors.collectors import RandomPolicy
    from torchrl.data import BoundedTensorSpec
    from torchrl.envs.libs.gym import GymEnv

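    # With a single worker per node we use a plain SyncDataCollector, otherwise a
    # MultiSyncDataCollector; the two expose their device argument under different
    # names ("device" vs "devices"), hence device_str below. The RandomPolicy over a
    # dummy bounded spec stands in for the trained policy of a real training loop.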
    collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
    device_str = "device" if num_workers == 1 else "devices"
    collector = DistributedDataCollector(
        [EnvCreator(lambda: GymEnv("ALE/Pong-v5"))] * num_jobs,
        policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))),
        launcher="submitit_delayed",
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        tcp_port=tcp_port,
        collector_class=collector_class,
        num_workers_per_collector=args.num_workers,
        collector_kwargs={device_str: "cuda:0" if slurm_gpus_per_node else "cpu"},
        storing_device="cuda:0" if slurm_gpus_per_node else "cpu",
        backend="nccl" if slurm_gpus_per_node else "gloo",
        sync=sync,
    )
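    # Simple throughput measurement: the first 10 batches are treated as warm-up
    # (job scheduling, env instantiation) and excluded from the fps estimate.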
    counter = 0
    pbar = tqdm.tqdm(total=collector.total_frames)
    for i, data in enumerate(collector):
        pbar.update(data.numel())
        pbar.set_description(f"data shape: {data.shape}, data device: {data.device}")
        if i >= 10:
            counter += data.numel()
        if i == 10:
            t0 = time.time()
    t1 = time.time()
    print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
    collector.shutdown()
    exit()


if __name__ == "__main__":
    main()