
Commit 31af2c5

[Feature] pass policy-factory in mp data collectors
ghstack-source-id: bce8abe
Pull Request resolved: #2859
1 parent 595ddb4 commit 31af2c5

File tree

9 files changed: +336 −64 lines changed

+122
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Updating MPS weights in multiprocess/distributed data collectors
================================================================

Overview of the Script
----------------------

This script demonstrates how to update policy weights in TorchRL's multiprocess data collectors.
It uses a custom `MPSRemoteWeightUpdater` class to propagate the weights of a policy network to multiple workers.

Key Features
------------

- Multi-Worker Setup: The script creates two worker processes that collect data from a Gym environment
  ("Pendulum-v1") using a policy network.
- MPS (Metal Performance Shaders) Device: The policy network is placed on an MPS device.
- Custom Weight Updater: The `MPSRemoteWeightUpdater` class is used to update the policy weights across workers. This
  class is necessary because MPS tensors cannot be sent over a pipe due to serialization/pickling issues in PyTorch.

Workaround for MPS Tensor Serialization Issue
---------------------------------------------

In PyTorch, MPS tensors cannot be serialized or pickled, which means they cannot be sent over a pipe or shared between
processes. To work around this issue, the `MPSRemoteWeightUpdater` class sends the policy weights on the CPU device
instead of the MPS device. The local workers then copy the weights from the CPU device to the MPS device.

Script Flow
-----------

1. Initialize the environment, policy network, and collector.
2. Update the policy weights using the MPSRemoteWeightUpdater.
3. Collect data from the environment using the policy network.
4. Zero out the policy weights after a few iterations.
5. Verify that the updated policy weights are being used by checking the actions generated by the policy network.

"""

import tensordict
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import MultiSyncDataCollector, RemoteWeightUpdaterBase

from torchrl.envs.libs.gym import GymEnv


class MPSRemoteWeightUpdater(RemoteWeightUpdaterBase):
    def __init__(self, policy_weights, num_workers):
        # Weights are on the mps device, which cannot be shared across processes
        self.policy_weights = policy_weights.data
        self.num_workers = num_workers

    def _sync_weights_with_worker(
        self, worker_id: int | torch.device, server_weights: TensorDictBase
    ) -> TensorDictBase:
        # Send weights on cpu - the local workers will do the cpu->mps copy
        self.collector.pipes[worker_id].send((server_weights, "update"))
        val, msg = self.collector.pipes[worker_id].recv()
        assert msg == "updated"
        return server_weights

    def _get_server_weights(self) -> TensorDictBase:
        print((self.policy_weights == 0).all())
        return self.policy_weights.cpu()

    def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
        print((server_weights == 0).all())
        return server_weights

    def all_worker_ids(self) -> list[int] | list[torch.device]:
        return list(range(self.num_workers))


if __name__ == "__main__":
    device = "mps"

    def env_maker():
        return GymEnv("Pendulum-v1", device="cpu")

    def policy_factory(device=device):
        return TensorDictModule(
            nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
        ).to(device=device)

    policy = policy_factory()
    policy_weights = tensordict.from_module(policy)

    collector = MultiSyncDataCollector(
        create_env_fn=[env_maker, env_maker],
        policy_factory=policy_factory,
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        init_random_frames=-1,
        reset_at_each_iter=False,
        device=device,
        storing_device="cpu",
        remote_weight_updater=MPSRemoteWeightUpdater(policy_weights, 2),
        # use_buffers=False,
        # cat_results="stack",
    )

    collector.update_policy_weights_()
    try:
        for i, data in enumerate(collector):
            if i == 2:
                print(data)
                assert (data["action"] != 0).any()
                # zero the policy
                policy_weights.data.zero_()
                collector.update_policy_weights_()
            elif i == 3:
                assert (data["action"] == 0).all(), data["action"]
                break
    finally:
        collector.shutdown()
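
The workaround described in the example's docstring boils down to a single device hop: the updater ships the weights on CPU, and each worker copies them into its MPS-resident policy. Below is a rough standalone sketch of that worker-side copy, not TorchRL's internal worker code; it only uses the tensordict calls already seen in the example above and assumes a machine where the MPS backend is available.

from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn

# MPS-resident policy, mirroring the example above.
policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
).to("mps")

# Stand-in for the weights received over the pipe: a CPU copy of the parameters.
cpu_weights = TensorDict.from_module(policy).data.cpu()

# Worker-side step: move the CPU weights back to MPS and copy them into the
# policy's parameters in place. Only CPU tensors ever cross the pipe.
TensorDict.from_module(policy).data.copy_(cpu_weights.to("mps"))

This mirrors what `update_policy_weights_()` triggers through `_sync_weights_with_worker`: CPU tensors travel between processes, and the cpu->mps copy happens locally on each worker.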

test/test_collector.py

+69 −1
@@ -39,7 +39,11 @@
     prod,
     seed_generator,
 )
-from torchrl.collectors import aSyncDataCollector, SyncDataCollector
+from torchrl.collectors import (
+    aSyncDataCollector,
+    RemoteWeightUpdaterBase,
+    SyncDataCollector,
+)
 from torchrl.collectors.collectors import (
     _Interruptor,
     MultiaSyncDataCollector,
@@ -146,6 +150,7 @@
 PYTHON_3_10 = sys.version_info.major == 3 and sys.version_info.minor == 10
 PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+_has_cuda = torch.cuda.is_available()
 
 
 class WrappablePolicy(nn.Module):
@@ -3476,6 +3481,69 @@ def __deepcopy_error__(*args, **kwargs):
     raise RuntimeError("deepcopy not allowed")
 
 
+class TestPolicyFactory:
+    class MPSRemoteWeightUpdater(RemoteWeightUpdaterBase):
+        def __init__(self, policy_weights, num_workers):
+            # Weights are on mps device, which cannot be shared
+            self.policy_weights = policy_weights.data
+            self.num_workers = num_workers
+
+        def _sync_weights_with_worker(
+            self, worker_id: int | torch.device, server_weights: TensorDictBase
+        ) -> TensorDictBase:
+            # Send weights on cpu - the local workers will do the cpu->mps copy
+            self.collector.pipes[worker_id].send((server_weights, "update"))
+            val, msg = self.collector.pipes[worker_id].recv()
+            assert msg == "updated"
+            return server_weights
+
+        def _get_server_weights(self) -> TensorDictBase:
+            return self.policy_weights.cpu()
+
+        def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+            return server_weights
+
+        def all_worker_ids(self) -> list[int] | list[torch.device]:
+            return list(range(self.num_workers))
+
+    @pytest.mark.skipif(not _has_cuda, reason="requires CUDA or another device than CPU.")
+    def test_weight_update(self):
+        device = "cuda:0"
+        env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
+        policy_factory = lambda: TensorDictModule(
+            nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+        ).to(device)
+        policy = policy_factory()
+        policy_weights = TensorDict.from_module(policy)
+
+        collector = MultiSyncDataCollector(
+            create_env_fn=[env_maker, env_maker],
+            policy_factory=policy_factory,
+            total_frames=2000,
+            max_frames_per_traj=50,
+            frames_per_batch=200,
+            init_random_frames=-1,
+            reset_at_each_iter=False,
+            device=device,
+            storing_device="cpu",
+            remote_weight_updater=self.MPSRemoteWeightUpdater(policy_weights, 2),
+        )
+
+        collector.update_policy_weights_()
+        try:
+            for i, data in enumerate(collector):
+                if i == 2:
+                    assert (data["action"] != 0).any()
+                    # zero the policy
+                    policy_weights.data.zero_()
+                    collector.update_policy_weights_()
+                elif i == 3:
+                    assert (data["action"] == 0).all(), data["action"]
+                    break
+        finally:
+            collector.shutdown()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
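
The new test is skipped unless a CUDA device is present (guarded by `_has_cuda`). To exercise just this test class locally, an invocation along the following lines should work; the `-k` filter and flags mirror the file's own `__main__` block and are a suggestion, not part of the commit.

import pytest

# Run only the new policy-factory test from the repository root;
# it is skipped automatically when no CUDA device is available.
pytest.main(["test/test_collector.py", "-k", "TestPolicyFactory", "--capture", "no"])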
