# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Using the SyncDataCollector with Different Device Combinations
==============================================================

TorchRL's SyncDataCollector allows you to specify the devices on which different components of the data collection
process are executed. This example demonstrates how to use the collector with various device combinations.

Understanding Device Precedence
-------------------------------

When creating a SyncDataCollector, you can specify the devices for the environment (env_device), policy (policy_device),
and data collection (device). The device argument serves as a default value for any unspecified devices. However, if you
provide env_device or policy_device, they take precedence over the device argument for their respective components.

For example:

- If you set device="cuda", all components will be executed on the CUDA device unless you specify otherwise.
- If you set env_device="cpu" and device="cuda", the environment will be executed on the CPU, while the policy and data
  collection will be executed on the CUDA device.

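The sketch below illustrates these precedence rules; it is not part of the script below, and the
policy name is a placeholder::

    collector = SyncDataCollector(
        make_env(ENV_NAME),
        my_policy,             # assumed to already live on CUDA
        frames_per_batch=128,
        env_device="cpu",      # overrides `device` for the environment
        policy_device="cuda",  # overrides `device` for the policy
        device="cuda",         # default for anything left unspecified
    )
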
Keeping Policy Parameters in Sync
---------------------------------

When using a policy with buffers or other attributes that are not automatically updated when moving the policy's
parameters to a different device, it's essential to keep the policy's parameters in sync between the main workspace and
the collector.

To do this, call update_policy_weights_() anytime the policy's parameters (and buffers!) are updated. This ensures that
the policy used by the collector has the same parameters as the policy in the main workspace.

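In a training loop this typically looks like the following sketch (the replay buffer and the
optimization step are placeholders)::

    for data in collector:
        replay_buffer.extend(data)
        ...  # optimization steps that modify the policy parameters
        collector.update_policy_weights_()  # push the new parameters/buffers to the collector
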
Example Use Cases
-----------------

This script demonstrates the SyncDataCollector with the following device combinations:

- Collector on CUDA
- Collector on CPU
- Mixed collector: policy on CUDA, env untouched (i.e., unmarked CPU, env.device == None)
- Mixed collector: policy on CUDA, env on CPU (env.device == "cpu")
- Mixed collector: all on CUDA, except env on CPU.

For each configuration, we run a DQN algorithm and check that it converges.
By following this example, you can learn how to use the SyncDataCollector with different device combinations and ensure
that your policy's parameters are kept in sync.

"""

import logging
import time

import torch.cuda
import torch.nn as nn
import torch.optim as optim

from tensordict.nn import TensorDictSequential as TDSeq

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import Compose, GymEnv, RewardSum, StepCounter, TransformedEnv
from torchrl.modules import EGreedyModule, QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate


logging.basicConfig(level=logging.INFO)
my_logger = logging.getLogger(__name__)

ENV_NAME = "CartPole-v1"

INIT_RND_STEPS = 5_120
FRAMES_PER_BATCH = 128
BUFFER_SIZE = 100_000

GAMMA = 0.98
OPTIM_STEPS = 10
BATCH_SIZE = 128

SOFTU_EPS = 0.99
LR = 0.02


class Net(nn.Module):
    def __init__(self, obs_size: int, n_actions: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x):
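        # During collection the env yields single, unbatched observations; add a batch
        # dimension for the MLP and strip it again afterwards.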
        orig_shape_unbatched = len(x.shape) == 1
        if orig_shape_unbatched:
            x = x.unsqueeze(0)

        out = self.net(x)

        if orig_shape_unbatched:
            out = out.squeeze(0)
        return out


def make_env(env_name: str):
    return TransformedEnv(GymEnv(env_name), Compose(StepCounter(), RewardSum()))


if __name__ == "__main__":

    for env_device, policy_device, device in (
        (None, None, "cuda"),
        (None, None, "cpu"),
        (None, "cuda", None),
        ("cpu", "cuda", None),
        ("cpu", None, "cuda"),
        # These configs don't run because the collector needs to be told that the policy is on CUDA.
        # This is not needed for the env, whose specs are associated with a device, so its data can be
        # transferred automatically. The policy, in general, has no spec indicating what its input and
        # output devices are, so this must be told to the collector.
        # (None, None, None),
        # ("cpu", None, None),
    ):
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        env = make_env(ENV_NAME)
        env.set_seed(0)

        n_obs = env.observation_spec["observation"].shape[-1]
        n_act = env.action_spec.shape[-1]

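        # The Q-network and the replay-buffer storage below always live on CUDA;
        # only the collector-side devices change from one configuration to the next.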
        net = Net(n_obs, n_act).to(device="cuda:0")
        agent = QValueActor(net, spec=env.action_spec.to("cuda:0"))

        # policy_explore has buffers on CPU - we will need to call collector.update_policy_weights_()
        # to sync them during data collection.
        policy_explore = EGreedyModule(env.action_spec)
        agent_explore = TDSeq(agent, policy_explore)

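        # env_device / policy_device left as None fall back to `device`
        # (see the precedence rules in the module docstring).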
        collector = SyncDataCollector(
            env,
            agent_explore,
            frames_per_batch=FRAMES_PER_BATCH,
            init_random_frames=INIT_RND_STEPS,
            device=device,
            env_device=env_device,
            policy_device=policy_device,
        )
        exp_buffer = ReplayBuffer(
            storage=LazyTensorStorage(BUFFER_SIZE, device="cuda:0")
        )

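        # delay_value=True makes DQNLoss keep a separate set of target parameters,
        # which SoftUpdate refreshes through Polyak averaging.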
        loss = DQNLoss(
            value_network=agent, action_space=env.action_spec, delay_value=True
        )
        loss.make_value_estimator(gamma=GAMMA)
        target_updater = SoftUpdate(loss, eps=SOFTU_EPS)
        optimizer = optim.Adam(loss.parameters(), lr=LR)

        total_count = 0
        total_episodes = 0
        t0 = time.time()
        for i, data in enumerate(collector):
            # Check the data devices
            if device is None:
                assert data["action"].device == torch.device("cuda:0")
                assert data["observation"].device == torch.device("cpu")
                assert data["done"].device == torch.device("cpu")
            elif device == "cpu":
                assert data["action"].device == torch.device("cpu")
                assert data["observation"].device == torch.device("cpu")
                assert data["done"].device == torch.device("cpu")
            else:
                assert data["action"].device == torch.device("cuda:0")
                assert data["observation"].device == torch.device("cuda:0")
                assert data["done"].device == torch.device("cuda:0")

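            # Store the batch and track the longest episode / best return seen so far.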
            exp_buffer.extend(data)
            max_length = exp_buffer["next", "step_count"].max()
            max_reward = exp_buffer["next", "episode_reward"].max()
            if len(exp_buffer) > INIT_RND_STEPS:
                for _ in range(OPTIM_STEPS):
                    optimizer.zero_grad()
                    sample = exp_buffer.sample(batch_size=BATCH_SIZE)

                    loss_vals = loss(sample)
                    loss_vals["loss"].backward()
                    optimizer.step()

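                # After the optimization steps, anneal the epsilon-greedy exploration
                # and softly update the target network.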
                agent_explore[1].step(data.numel())
                target_updater.step()

                total_count += data.numel()
                total_episodes += data["next", "done"].sum()

                if i % 10 == 0:
                    my_logger.info(
                        f"Step: {i}, max. count / epi reward: {max_length} / {max_reward}."
                    )
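            # Sync the collector's copy of the policy (parameters and buffers, e.g. the
            # exploration epsilon) with the one trained above.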
            collector.update_policy_weights_()
            if max_length > 200:
                t1 = time.time()
                my_logger.info(f"SOLVED in {t1 - t0}s!! MaxLen: {max_length}!")
                my_logger.info(f"With {max_reward} Reward!")
                my_logger.info(f"In {total_episodes} Episodes!")
                my_logger.info(f"Using devices {(env_device, policy_device, device)}")
                break
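        # This `else` belongs to the for loop: it runs only if the collector is exhausted
        # without the task being solved.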
        else:
            raise RuntimeError(
                f"Failed to converge with config {(env_device, policy_device, device)}"
            )