[Feature] pass policy-factory in mp data collectors #2859

Merged (13 commits, Mar 20, 2025)

122 changes: 122 additions & 0 deletions examples/collectors/mp_collector_mps.py
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Updating MPS weights in multiprocess/distributed data collectors
================================================================

Overview of the Script
----------------------

This script demonstrates how to update policy weights in TorchRL's multiprocess data collectors.
It uses a custom `MPSRemoteWeightUpdater` class to propagate the weights of a policy network to multiple worker processes.

Key Features
------------

- Multi-Worker Setup: The script creates two worker processes that collect data from a Gym environment
("Pendulum-v1") using a policy network.
- MPS (Metal Performance Shaders) Device: The policy network is placed on an MPS device.
- Custom Weight Updater: The `MPSRemoteWeightUpdater` class is used to update the policy weights across workers. This
class is necessary because MPS tensors cannot be sent over a pipe due to serialization/pickling issues in PyTorch.

Workaround for MPS Tensor Serialization Issue
---------------------------------------------

In PyTorch, MPS tensors cannot be serialized or pickled, which means they cannot be sent over a pipe or shared between
processes. To work around this issue, the MPSRemoteWeightUpdater class sends the policy weights on the CPU device
instead of the MPS device. The local workers then copy the weights from the CPU device to the MPS device.
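In short, the weights live on MPS in the main process, are copied to CPU for transport through the pipes, and are copied
back to MPS inside each worker.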

Script Flow
-----------

1. Initialize the environment, policy network, and collector.
2. Update the policy weights using the MPSRemoteWeightUpdater.
3. Collect data from the environment using the policy network.
4. Zero out the policy weights after a few iterations.
5. Verify that the updated policy weights are being used by checking the actions generated by the policy network.

"""

import tensordict
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import MultiSyncDataCollector, RemoteWeightUpdaterBase

from torchrl.envs.libs.gym import GymEnv


class MPSRemoteWeightUpdater(RemoteWeightUpdaterBase):
def __init__(self, policy_weights, num_workers):
# Weights are on mps device, which cannot be shared
self.policy_weights = policy_weights.data
self.num_workers = num_workers

def _sync_weights_with_worker(
self, worker_id: int | torch.device, server_weights: TensorDictBase
) -> TensorDictBase:
# Send weights on cpu - the local workers will do the cpu->mps copy
self.collector.pipes[worker_id].send((server_weights, "update"))
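        # Wait for the worker to acknowledge that it has copied the weights onto its own device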
val, msg = self.collector.pipes[worker_id].recv()
assert msg == "updated"
return server_weights

def _get_server_weights(self) -> TensorDictBase:
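        # The print is a sanity check showing whether the stored weights have been zeroed;
        # a CPU copy is returned so the weights can be pickled and sent through a pipe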
print((self.policy_weights == 0).all())
return self.policy_weights.cpu()

def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
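        # Nothing to remap: the server weights are already CPU copies; the print is another sanity check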
print((server_weights == 0).all())
return server_weights

def all_worker_ids(self) -> list[int] | list[torch.device]:
return list(range(self.num_workers))


if __name__ == "__main__":
device = "mps"

def env_maker():
return GymEnv("Pendulum-v1", device="cpu")

def policy_factory(device=device):
return TensorDictModule(
nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
).to(device=device)

policy = policy_factory()
policy_weights = tensordict.from_module(policy)

collector = MultiSyncDataCollector(
create_env_fn=[env_maker, env_maker],
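        # policy_factory: each worker builds its own policy locally, so the MPS policy
        # object itself never needs to be pickled and sent over a pipe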
policy_factory=policy_factory,
total_frames=2000,
max_frames_per_traj=50,
frames_per_batch=200,
init_random_frames=-1,
reset_at_each_iter=False,
device=device,
storing_device="cpu",
remote_weight_updater=MPSRemoteWeightUpdater(policy_weights, 2),
# use_buffers=False,
# cat_results="stack",
)

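    # Push the current weights to both workers via the MPSRemoteWeightUpdater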
collector.update_policy_weights_()
try:
for i, data in enumerate(collector):
if i == 2:
print(data)
assert (data["action"] != 0).any()
# zero the policy
policy_weights.data.zero_()
collector.update_policy_weights_()
elif i == 3:
assert (data["action"] == 0).all(), data["action"]
break
finally:
collector.shutdown()
70 changes: 69 additions & 1 deletion test/test_collector.py
@@ -39,7 +39,11 @@
prod,
seed_generator,
)
from torchrl.collectors import aSyncDataCollector, SyncDataCollector
from torchrl.collectors import (
aSyncDataCollector,
RemoteWeightUpdaterBase,
SyncDataCollector,
)
from torchrl.collectors.collectors import (
_Interruptor,
MultiaSyncDataCollector,
@@ -146,6 +150,7 @@
PYTHON_3_10 = sys.version_info.major == 3 and sys.version_info.minor == 10
PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
_has_cuda = torch.cuda.is_available()


class WrappablePolicy(nn.Module):
@@ -3476,6 +3481,69 @@ def __deepcopy_error__(*args, **kwargs):
raise RuntimeError("deepcopy not allowed")


class TestPolicyFactory:
class MPSRemoteWeightUpdater(RemoteWeightUpdaterBase):
def __init__(self, policy_weights, num_workers):
# Weights are on mps device, which cannot be shared
self.policy_weights = policy_weights.data
self.num_workers = num_workers

def _sync_weights_with_worker(
self, worker_id: int | torch.device, server_weights: TensorDictBase
) -> TensorDictBase:
# Send weights on cpu - the local workers will do the cpu->mps copy
self.collector.pipes[worker_id].send((server_weights, "update"))
val, msg = self.collector.pipes[worker_id].recv()
assert msg == "updated"
return server_weights

def _get_server_weights(self) -> TensorDictBase:
return self.policy_weights.cpu()

def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
return server_weights

def all_worker_ids(self) -> list[int] | list[torch.device]:
return list(range(self.num_workers))

@pytest.mark.skipif(not _has_cuda, reason="requires CUDA (a device other than CPU).")
def test_weight_update(self):
device = "cuda:0"
env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
policy_factory = lambda: TensorDictModule(
nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
).to(device)
policy = policy_factory()
policy_weights = TensorDict.from_module(policy)

collector = MultiSyncDataCollector(
create_env_fn=[env_maker, env_maker],
policy_factory=policy_factory,
total_frames=2000,
max_frames_per_traj=50,
frames_per_batch=200,
init_random_frames=-1,
reset_at_each_iter=False,
device=device,
storing_device="cpu",
remote_weight_updater=self.MPSRemoteWeightUpdater(policy_weights, 2),
)

collector.update_policy_weights_()
try:
for i, data in enumerate(collector):
if i == 2:
assert (data["action"] != 0).any()
# zero the policy
policy_weights.data.zero_()
collector.update_policy_weights_()
elif i == 3:
assert (data["action"] == 0).all(), data["action"]
break
finally:
collector.shutdown()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)