
Commit 667466c

[Feature] pass policy-factory in mp data collectors
ghstack-source-id: 369e690
Pull Request resolved: #2859
1 parent 774dbeb commit 667466c

File tree

5 files changed (+231, -5 lines)

+122 (new file)
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Updating MPS weights in multiprocess/distributed data collectors
================================================================

Overview of the Script
----------------------

This script demonstrates a weight update in TorchRL.
The script uses a custom `MPSRemoteWeightUpdater` class to update the weights of a policy network across multiple workers.

Key Features
------------

- Multi-Worker Setup: The script creates two worker processes that collect data from a Gym environment
  ("Pendulum-v1") using a policy network.
- MPS (Metal Performance Shaders) Device: The policy network is placed on an MPS device.
- Custom Weight Updater: The `MPSRemoteWeightUpdater` class is used to update the policy weights across workers. This
  class is necessary because MPS tensors cannot be sent over a pipe due to serialization/pickling issues in PyTorch.

Workaround for MPS Tensor Serialization Issue
---------------------------------------------

In PyTorch, MPS tensors cannot be serialized or pickled, which means they cannot be sent over a pipe or shared between
processes. To work around this issue, the MPSRemoteWeightUpdater class sends the policy weights on the CPU device
instead of the MPS device. The local workers then copy the weights from the CPU device to the MPS device.

Script Flow
-----------

1. Initialize the environment, policy network, and collector.
2. Update the policy weights using the MPSRemoteWeightUpdater.
3. Collect data from the environment using the policy network.
4. Zero out the policy weights after a few iterations.
5. Verify that the updated policy weights are being used by checking the actions generated by the policy network.

"""

import tensordict
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import MultiSyncDataCollector, RemoteWeightUpdaterBase

from torchrl.envs.libs.gym import GymEnv


class MPSRemoteWeightUpdater(RemoteWeightUpdaterBase):
    def __init__(self, policy_weights, num_workers):
        # Weights are on mps device, which cannot be shared
        self.policy_weights = policy_weights.data
        self.num_workers = num_workers

    def _sync_weights_with_worker(
        self, worker_id: int | torch.device, server_weights: TensorDictBase
    ) -> TensorDictBase:
        # Send weights on cpu - the local workers will do the cpu->mps copy
        self.collector.pipes[worker_id].send((server_weights, "update"))
        val, msg = self.collector.pipes[worker_id].recv()
        assert msg == "updated"
        return server_weights

    def _get_server_weights(self) -> TensorDictBase:
        print((self.policy_weights == 0).all())
        return self.policy_weights.cpu()

    def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
        print((server_weights == 0).all())
        return server_weights

    def all_worker_ids(self) -> list[int] | list[torch.device]:
        return list(range(self.num_workers))


if __name__ == "__main__":
    device = "mps"

    def env_maker():
        return GymEnv("Pendulum-v1", device="cpu")

    def policy_factory(device=device):
        return TensorDictModule(
            nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
        ).to(device=device)

    policy = policy_factory()
    policy_weights = tensordict.from_module(policy)

    collector = MultiSyncDataCollector(
        create_env_fn=[env_maker, env_maker],
        policy_factory=policy_factory,
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        init_random_frames=-1,
        reset_at_each_iter=False,
        device=device,
        storing_device="cpu",
        remote_weights_updater=MPSRemoteWeightUpdater(policy_weights, 2),
        # use_buffers=False,
        # cat_results="stack",
    )

    collector.update_policy_weights_()
    try:
        for i, data in enumerate(collector):
            if i == 2:
                print(data)
                assert (data["action"] != 0).any()
                # zero the policy
                policy_weights.data.zero_()
                collector.update_policy_weights_()
            elif i == 3:
                assert (data["action"] == 0).all(), data["action"]
                break
    finally:
        collector.shutdown()
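The docstring's workaround section above says the local workers copy the CPU-staged weights onto the MPS device. As a rough, hedged illustration of that copy (not part of this commit; the worker_policy and cpu_weights names are made up, and it assumes an MPS-enabled PyTorch build with a tensordict-compatible policy):

import tensordict
from tensordict.nn import TensorDictModule
from torch import nn

# Worker-side view (illustrative): the worker owns its own copy of the policy on mps ...
worker_policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
).to("mps")
worker_weights = tensordict.from_module(worker_policy)

# ... and receives weights that were staged on cpu by the updater (here just a clone).
cpu_weights = worker_weights.data.clone().cpu()

# An in-place copy moves the values onto the mps parameters, so no mps tensor
# ever has to be pickled or sent through a pipe.
worker_weights.data.copy_(cpu_weights)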

test/test_collector.py

+69 -1
@@ -39,7 +39,11 @@
     prod,
     seed_generator,
 )
-from torchrl.collectors import aSyncDataCollector, SyncDataCollector
+from torchrl.collectors import (
+    aSyncDataCollector,
+    RemoteWeightUpdaterBase,
+    SyncDataCollector,
+)
 from torchrl.collectors.collectors import (
     _Interruptor,
     MultiaSyncDataCollector,
@@ -146,6 +150,7 @@
 PYTHON_3_10 = sys.version_info.major == 3 and sys.version_info.minor == 10
 PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+_has_cuda = torch.cuda.is_available()


 class WrappablePolicy(nn.Module):
@@ -3476,6 +3481,69 @@ def __deepcopy_error__(*args, **kwargs):
         raise RuntimeError("deepcopy not allowed")


+class TestPolicyFactory:
+    class MPSRemoteWeightUpdater(RemoteWeightUpdaterBase):
+        def __init__(self, policy_weights, num_workers):
+            # Weights are on mps device, which cannot be shared
+            self.policy_weights = policy_weights.data
+            self.num_workers = num_workers
+
+        def _sync_weights_with_worker(
+            self, worker_id: int | torch.device, server_weights: TensorDictBase
+        ) -> TensorDictBase:
+            # Send weights on cpu - the local workers will do the cpu->mps copy
+            self.collector.pipes[worker_id].send((server_weights, "update"))
+            val, msg = self.collector.pipes[worker_id].recv()
+            assert msg == "updated"
+            return server_weights
+
+        def _get_server_weights(self) -> TensorDictBase:
+            return self.policy_weights.cpu()
+
+        def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+            return server_weights
+
+        def all_worker_ids(self) -> list[int] | list[torch.device]:
+            return list(range(self.num_workers))
+
+    @pytest.mark.skipif(not _has_cuda, reason="requires cuda or another device than CPU.")
+    def test_weight_update(self):
+        device = "cuda:0"
+        env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
+        policy_factory = lambda: TensorDictModule(
+            nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+        ).to(device)
+        policy = policy_factory()
+        policy_weights = TensorDict.from_module(policy)
+
+        collector = MultiSyncDataCollector(
+            create_env_fn=[env_maker, env_maker],
+            policy_factory=policy_factory,
+            total_frames=2000,
+            max_frames_per_traj=50,
+            frames_per_batch=200,
+            init_random_frames=-1,
+            reset_at_each_iter=False,
+            device=device,
+            storing_device="cpu",
+            remote_weights_updater=self.MPSRemoteWeightUpdater(policy_weights, 2),
+        )
+
+        collector.update_policy_weights_()
+        try:
+            for i, data in enumerate(collector):
+                if i == 2:
+                    assert (data["action"] != 0).any()
+                    # zero the policy
+                    policy_weights.data.zero_()
+                    collector.update_policy_weights_()
+                elif i == 3:
+                    assert (data["action"] == 0).all(), data["action"]
+                    break
+        finally:
+            collector.shutdown()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/collectors/collectors.py

+15 -2
@@ -837,6 +837,9 @@ def __init__(
         )

         self.local_weights_updater = local_weights_updater
+        if remote_weights_updater is not None:
+            remote_weights_updater.register_collector(self)
+
         self.remote_weights_updater = remote_weights_updater

     @property
@@ -1827,10 +1830,13 @@ def __init__(
                 "remote_weights_updater cannot be None when policy_factory is provided."
             )

+        if remote_weights_updater is not None:
+            remote_weights_updater.register_collector(self)
         self.remote_weights_updater = remote_weights_updater
         self.local_weights_updater = local_weights_updater

         self.policy = policy
+        self.policy_factory = policy_factory

         remainder = 0
         if total_frames is None or total_frames < 0:
@@ -2012,6 +2018,10 @@ def _run_processes(self) -> None:
                 env_fun = CloudpickleWrapper(env_fun)

             # Create a policy on the right device
+            policy_factory = self.policy_factory
+            if policy_factory is not None:
+                policy_factory = CloudpickleWrapper(policy_factory)
+
             policy_device = self.policy_device[i]
             storing_device = self.storing_device[i]
             env_device = self.env_device[i]
@@ -2020,13 +2030,14 @@ def _run_processes(self) -> None:
             # This makes sure that a given set of shared weights for a given device are
             # shared for all policies that rely on that device.
             policy = self.policy
-            policy_weights = self._policy_weights_dict[policy_device]
+            policy_weights = self._policy_weights_dict.get(policy_device)
             if policy is not None and policy_weights is not None:
                 cm = policy_weights.to_module(policy)
             else:
                 cm = contextlib.nullcontext()
             with cm:
                 kwargs = {
+                    "policy_factory": policy_factory,
                     "pipe_parent": pipe_parent,
                     "pipe_child": pipe_child,
                     "queue_out": queue_out,
@@ -3107,6 +3118,7 @@ def _main_async_collector(
     compile_policy: bool = False,
     cudagraph_policy: bool = False,
     no_cuda_sync: bool = False,
+    policy_factory: Callable | None = None,
 ) -> None:
     pipe_parent.close()
     # init variables that will be cleared when closing
@@ -3116,6 +3128,7 @@ def _main_async_collector(
         create_env_fn,
         create_env_kwargs=create_env_kwargs,
         policy=policy,
+        policy_factory=policy_factory,
         total_frames=-1,
         max_frames_per_traj=max_frames_per_traj,
         frames_per_batch=frames_per_batch,
@@ -3278,7 +3291,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
             continue

         elif msg == "update":
-            inner_collector.update_policy_weights_()
+            inner_collector.update_policy_weights_(policy_weights=data_in)
             pipe_child.send((j, "updated"))
             has_timed_out = False
             continue
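The _run_processes change above wraps the policy factory in CloudpickleWrapper before handing it to the worker processes. A likely reason (a hedged aside, not stated in the diff) is that factories are often lambdas or locally defined closures, which the standard pickle module used by multiprocessing cannot serialize, while cloudpickle can:

import pickle

import cloudpickle

# A lambda factory cannot be serialized by the standard-library pickler ...
policy_factory = lambda: "policy built on the worker"

try:
    pickle.dumps(policy_factory)
except (pickle.PicklingError, AttributeError) as err:
    print(f"plain pickle fails: {err}")

# ... but cloudpickle handles it, which is what wrapping the factory in
# CloudpickleWrapper relies on before sending it to the worker process.
payload = cloudpickle.dumps(policy_factory)
rebuilt = cloudpickle.loads(payload)
print(rebuilt())  # -> "policy built on the worker"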

torchrl/collectors/weight_update.py

+16 -1
@@ -6,7 +6,7 @@

 import abc
 from abc import abstractmethod
-from typing import Callable, TypeVar
+from typing import Any, Callable, TypeVar

 import torch
 from tensordict import TensorDictBase
@@ -110,6 +110,21 @@ class RemoteWeightUpdaterBase(metaclass=abc.ABCMeta):

     """

+    collector: Any = None
+
+    def register_collector(self, collector: DataCollectorBase):  # noqa
+        """Register a collector in the updater.
+
+        Once registered, the updater will not accept another collector.
+
+        Args:
+            collector (DataCollectorBase): The collector to register.
+
+        """
+        if self.collector is not None:
+            raise RuntimeError("Cannot register collector twice.")
+        self.collector = collector
+
     @abstractmethod
     def _sync_weights_with_worker(
         self, worker_id: int | torch.device, server_weights: TensorDictBase

torchrl/data/utils.py

+9 -1
@@ -222,7 +222,15 @@ def contains_lazy_spec(spec: TensorSpec) -> bool:
     return False


-class CloudpickleWrapper:
+class _CloudpickleWrapperMeta(type):
+    def __call__(cls, obj):
+        if isinstance(obj, cls):
+            return obj
+        else:
+            return super().__call__(obj)
+
+
+class CloudpickleWrapper(metaclass=_CloudpickleWrapperMeta):
     """A wrapper for functions that allow for serialization in multiprocessed settings."""

     def __init__(self, fn: Callable, **kwargs):
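The metaclass above makes CloudpickleWrapper construction idempotent: wrapping an already-wrapped callable returns the same object. A small sketch of the effect (assuming CloudpickleWrapper is imported from torchrl.data.utils, where this diff defines it):

from torchrl.data.utils import CloudpickleWrapper


def make_env():
    return "env"


wrapped = CloudpickleWrapper(make_env)

# Re-wrapping is now a no-op, so call sites that may receive either a raw
# callable or an already-wrapped one (like the policy_factory plumbing in
# collectors.py above) can wrap unconditionally.
assert CloudpickleWrapper(wrapped) is wrapped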
