Skip to content

Commit 6a7dbcf

Browse files
committed
Update (base update)
[ghstack-poisoned]
2 parents de3dac3 + a901064 commit 6a7dbcf

24 files changed

+350
-10
lines changed

.github/unittest/linux_sota/scripts/run_all.sh

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ export SDL_VIDEODRIVER=dummy
7575
export MUJOCO_GL=egl
7676
export PYOPENGL_PLATFORM=egl
7777
export LAZY_LEGACY_OP=False
78+
export COMPOSITE_LP_AGGREGATE=0
7879

7980
conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \
8081
DISPLAY=unix:0.0 \

.github/unittest/linux_sota/scripts/test_sota.py

+6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
from pathlib import Path
88

99
import pytest
10+
from tensordict.nn import composite_lp_aggregate
11+
12+
# Check that we're using the new behavior
13+
assert (
14+
not composite_lp_aggregate()
15+
), "Composite LP must be set to False. Run this test with COMPOSITE_LP_AGGREGATE=0"
1016

1117
commands = {
1218
"dt": """python sota-implementations/decision_transformer/dt.py \

.github/workflows/benchmarks.yml

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ on:
99
workflow_dispatch:
1010

1111
permissions:
12+
id-token: write
1213
deployments: write
1314
contents: write
1415

.github/workflows/docs.yml

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ concurrency:
1818
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
1919
cancel-in-progress: true
2020

21+
permissions:
22+
id-token: write
23+
contents: read
24+
2125
jobs:
2226
build-docs:
2327
strategy:

.github/workflows/lint.yml

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ concurrency:
1515
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
1616
cancel-in-progress: true
1717

18+
permissions:
19+
id-token: write
20+
contents: read
21+
1822
jobs:
1923
python-source-and-configs:
2024
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

.github/workflows/nightly_build.yml

+4
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ concurrency:
2727
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
2828
cancel-in-progress: true
2929

30+
permissions:
31+
id-token: write
32+
contents: read
33+
3034
jobs:
3135
build-wheel-linux:
3236
# Don't run on forked repos.

.github/workflows/test-linux-habitat.yml

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ concurrency:
1515
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
1616
cancel-in-progress: true
1717

18+
permissions:
19+
id-token: write
20+
contents: read
21+
1822
jobs:
1923
tests:
2024
strategy:

.github/workflows/test-linux-libs.yml

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ concurrency:
1515
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
1616
cancel-in-progress: true
1717

18+
permissions:
19+
id-token: write
20+
contents: read
21+
1822
jobs:
1923

2024
unittests-atari-dqn:

.github/workflows/test-linux-rlhf.yml

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ concurrency:
1515
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
1616
cancel-in-progress: true
1717

18+
permissions:
19+
id-token: write
20+
contents: read
21+
1822
jobs:
1923
unittests:
2024
strategy:

.github/workflows/test-linux-sota.yml

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ concurrency:
1818
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
1919
cancel-in-progress: true
2020

21+
permissions:
22+
id-token: write
23+
contents: read
24+
2125
jobs:
2226
tests:
2327
strategy:

.github/workflows/test-linux.yml

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ concurrency:
1818
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
1919
cancel-in-progress: true
2020

21+
permissions:
22+
id-token: write
23+
contents: read
24+
2125
jobs:
2226
tests-cpu:
2327
strategy:

.github/workflows/test-windows-optdepts.yml

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ concurrency:
1515
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
1616
cancel-in-progress: true
1717

18+
permissions:
19+
id-token: write
20+
contents: read
21+
1822
jobs:
1923
unittests-cpu:
2024
uses: pytorch/test-infra/.github/workflows/windows_job.yml@main

.github/workflows/wheels-legacy.yml

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ concurrency:
1313
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
1414
cancel-in-progress: true
1515

16+
permissions:
17+
id-token: write
18+
contents: read
19+
1620
jobs:
1721

1822
build-wheel-windows:

docs/source/reference/envs.rst

+1
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,7 @@ to be able to create this other composition:
848848
SelectTransform
849849
SignTransform
850850
SqueezeTransform
851+
Stack
851852
StepCounter
852853
TargetReturn
853854
TensorDictPrimer
+215
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""
7+
Using the SyncDataCollector with Different Device Combinations
8+
==============================================================
9+
10+
TorchRL's SyncDataCollector allows you to specify the devices on which different components of the data collection
11+
process are executed. This example demonstrates how to use the collector with various device combinations.
12+
13+
14+
Understanding Device Precedence
15+
-------------------------------
16+
17+
When creating a SyncDataCollector, you can specify the devices for the environment (env_device), policy (policy_device),
18+
and data collection (device). The device argument serves as a default value for any unspecified devices. However, if you
19+
provide env_device or policy_device, they take precedence over the device argument for their respective components.
20+
21+
For example:
22+
23+
- If you set device="cuda", all components will be executed on the CUDA device unless you specify otherwise.
24+
- If you set env_device="cpu" and device="cuda", the environment will be executed on the CPU, while the policy and data
25+
collection will be executed on the CUDA device.
26+
27+
Keeping Policy Parameters in Sync
28+
---------------------------------
29+
30+
When using a policy with buffers or other attributes that are not automatically updated when moving the policy's
31+
parameters to a different device, it's essential to keep the policy's parameters in sync between the main workspace and
32+
the collector.
33+
34+
To do this, call update_policy_weights_() anytime the policy's parameters (and buffers!) are updated. This ensures that
35+
the policy used by the collector has the same parameters as the policy in the main workspace.
36+
37+
Example Use Cases
38+
-----------------
39+
40+
This script demonstrates the SyncDataCollector with the following device combinations:
41+
42+
- Collector on CUDA
43+
- Collector on CPU
44+
- Mixed collector: policy on CUDA, env untouched (ie, unmarked CPU, env.device == None)
45+
- Mixed collector: policy on CUDA, env on CPU (env.device == "cpu")
46+
- Mixed collector: all on CUDA, except env on CPU.
47+
48+
For each configuration, we run a DQN algorithm and check that it converges.
49+
By following this example, you can learn how to use the SyncDataCollector with different device combinations and ensure
50+
that your policy's parameters are kept in sync.
51+
52+
"""
53+
54+
import logging
55+
import time
56+
57+
import torch.cuda
58+
import torch.nn as nn
59+
import torch.optim as optim
60+
61+
from tensordict.nn import TensorDictSequential as TDSeq
62+
63+
from torchrl.collectors import SyncDataCollector
64+
from torchrl.data import LazyTensorStorage, ReplayBuffer
65+
from torchrl.envs import Compose, GymEnv, RewardSum, StepCounter, TransformedEnv
66+
from torchrl.modules import EGreedyModule, QValueActor
67+
from torchrl.objectives import DQNLoss, SoftUpdate
68+
69+
70+
logging.basicConfig(level=logging.INFO)
71+
my_logger = logging.getLogger(__name__)
72+
73+
ENV_NAME = "CartPole-v1"
74+
75+
INIT_RND_STEPS = 5_120
76+
FRAMES_PER_BATCH = 128
77+
BUFFER_SIZE = 100_000
78+
79+
GAMMA = 0.98
80+
OPTIM_STEPS = 10
81+
BATCH_SIZE = 128
82+
83+
SOFTU_EPS = 0.99
84+
LR = 0.02
85+
86+
87+
class Net(nn.Module):
88+
def __init__(self, obs_size: int, n_actions: int) -> None:
89+
super().__init__()
90+
self.net = nn.Sequential(
91+
nn.Linear(obs_size, 128),
92+
nn.ReLU(),
93+
nn.Linear(128, n_actions),
94+
)
95+
96+
def forward(self, x):
97+
orig_shape_unbatched = len(x.shape) == 1
98+
if orig_shape_unbatched:
99+
x = x.unsqueeze(0)
100+
101+
out = self.net(x)
102+
103+
if orig_shape_unbatched:
104+
out = out.squeeze(0)
105+
return out
106+
107+
108+
def make_env(env_name: str):
109+
return TransformedEnv(GymEnv(env_name), Compose(StepCounter(), RewardSum()))
110+
111+
112+
if __name__ == "__main__":
113+
114+
for env_device, policy_device, device in (
115+
(None, None, "cuda"),
116+
(None, None, "cpu"),
117+
(None, "cuda", None),
118+
("cpu", "cuda", None),
119+
("cpu", None, "cuda"),
120+
# These configs don't run because the collector needs to know that the policy is on CUDA
121+
# This is not true for the env which has specs that are associated with a device, we can
122+
# automatically transfer the data. The policy does not, in general, have a spec indicating
123+
# what the input and output devices are, so this must be told to the collector.
124+
# (None, None, None),
125+
# ("cpu", None, None),
126+
):
127+
torch.manual_seed(0)
128+
torch.cuda.manual_seed(0)
129+
130+
env = make_env(ENV_NAME)
131+
env.set_seed(0)
132+
133+
n_obs = env.observation_spec["observation"].shape[-1]
134+
n_act = env.action_spec.shape[-1]
135+
136+
net = Net(n_obs, n_act).to(device="cuda:0")
137+
agent = QValueActor(net, spec=env.action_spec.to("cuda:0"))
138+
139+
# policy_explore has buffers on CPU - we will need to call collector.update_policy_weights_()
140+
# to sync them during data collection.
141+
policy_explore = EGreedyModule(env.action_spec)
142+
agent_explore = TDSeq(agent, policy_explore)
143+
144+
collector = SyncDataCollector(
145+
env,
146+
agent_explore,
147+
frames_per_batch=FRAMES_PER_BATCH,
148+
init_random_frames=INIT_RND_STEPS,
149+
device=device,
150+
env_device=env_device,
151+
policy_device=policy_device,
152+
)
153+
exp_buffer = ReplayBuffer(
154+
storage=LazyTensorStorage(BUFFER_SIZE, device="cuda:0")
155+
)
156+
157+
loss = DQNLoss(
158+
value_network=agent, action_space=env.action_spec, delay_value=True
159+
)
160+
loss.make_value_estimator(gamma=GAMMA)
161+
target_updater = SoftUpdate(loss, eps=SOFTU_EPS)
162+
optimizer = optim.Adam(loss.parameters(), lr=LR)
163+
164+
total_count = 0
165+
total_episodes = 0
166+
t0 = time.time()
167+
for i, data in enumerate(collector):
168+
# Check the data devices
169+
if device is None:
170+
assert data["action"].device == torch.device("cuda:0")
171+
assert data["observation"].device == torch.device("cpu")
172+
assert data["done"].device == torch.device("cpu")
173+
elif device == "cpu":
174+
assert data["action"].device == torch.device("cpu")
175+
assert data["observation"].device == torch.device("cpu")
176+
assert data["done"].device == torch.device("cpu")
177+
else:
178+
assert data["action"].device == torch.device("cuda:0")
179+
assert data["observation"].device == torch.device("cuda:0")
180+
assert data["done"].device == torch.device("cuda:0")
181+
182+
exp_buffer.extend(data)
183+
max_length = exp_buffer["next", "step_count"].max()
184+
max_reward = exp_buffer["next", "episode_reward"].max()
185+
if len(exp_buffer) > INIT_RND_STEPS:
186+
for _ in range(OPTIM_STEPS):
187+
optimizer.zero_grad()
188+
sample = exp_buffer.sample(batch_size=BATCH_SIZE)
189+
190+
loss_vals = loss(sample)
191+
loss_vals["loss"].backward()
192+
optimizer.step()
193+
194+
agent_explore[1].step(data.numel())
195+
target_updater.step()
196+
197+
total_count += data.numel()
198+
total_episodes += data["next", "done"].sum()
199+
200+
if i % 10 == 0:
201+
my_logger.info(
202+
f"Step: {i}, max. count / epi reward: {max_length} / {max_reward}."
203+
)
204+
collector.update_policy_weights_()
205+
if max_length > 200:
206+
t1 = time.time()
207+
my_logger.info(f"SOLVED in {t1 - t0}s!! MaxLen: {max_length}!")
208+
my_logger.info(f"With {max_reward} Reward!")
209+
my_logger.info(f"In {total_episodes} Episodes!")
210+
my_logger.info(f"Using devices {(env_device, policy_device, device)}")
211+
break
212+
else:
213+
raise RuntimeError(
214+
f"Failed to converge with config {(env_device, policy_device, device)}"
215+
)

0 commit comments

Comments
 (0)