
[BUG] DataCollectors fail when device is set to MPS #2858


Open
3 tasks done
LCarmi opened this issue Mar 18, 2025 · 0 comments · Fixed by #2859
Labels
bug Something isn't working

Comments


LCarmi commented Mar 18, 2025

Describe the bug

When running experiments with multiprocess-based sampling of trajectories on macOS, the initialization of the data collectors fails.

To Reproduce

from torchrl.envs.libs.gym import GymEnv
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import MultiSyncDataCollector

if __name__ == "__main__":
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    collector = MultiSyncDataCollector(
        create_env_fn=[env_maker, env_maker],
        policy=policy,
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        init_random_frames=-1,
        reset_at_each_iter=False,
        device="mps",
        storing_device="cpu",
        # cat_results="stack",
    )
    for i, data in enumerate(collector):
        if i == 2:
            print(data)
            break

This fails as follows:

Traceback (most recent call last):
  File "..././torchrl_test_mps_fail.py", line 9, in <module>
    collector = MultiSyncDataCollector(
  File ".../.venv/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1779, in __init__
    self._run_processes()
  File ".../.venv/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1976, in _run_processes
    proc.start()
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/nix/store/ra1l4hyhxw3zlq62y8vg6fpxysq9ln6s-python3-3.10.16/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File ".../.venv/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 607, in reduce_storage
    metadata = storage._share_filename_cpu_()
  File ".../.venv/lib/python3.10/site-packages/torch/storage.py", line 450, in wrapper
    return fn(self, *args, **kwargs)
  File ".../.venv/lib/python3.10/site-packages/torch/storage.py", line 529, in _share_filename_cpu_
    return super()._share_filename_cpu_(*args, **kwargs)
RuntimeError: _share_filename_: only available on CPU
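
The traceback shows that the failure happens inside torch's multiprocessing reductions rather than in torchrl itself: while the collector is being pickled for the spawned worker, the policy's MPS-resident storages go through the CPU-only file-system sharing path. The snippet below is a minimal, torchrl-free sketch of what I believe is the same code path (an assumption based on reading the traceback, not torchrl-specific behavior):

# Minimal sketch (assumption): send an MPS-resident tensor to a spawn-started
# worker. Based on the traceback above, torch's ForkingPickler reduces the
# storage via _share_filename_cpu_(), which is only implemented for CPU
# storages, so this should raise the same
# "RuntimeError: _share_filename_: only available on CPU".
import torch
import torch.multiprocessing as mp


def worker(t):
    print(t.device)


if __name__ == "__main__":
    mp.set_start_method("spawn")
    t = torch.zeros(3, device="mps")
    p = mp.Process(target=worker, args=(t,))
    p.start()  # expected to fail while pickling the MPS tensor for the child
    p.join()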

System info

>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.7.2 2.2.4 3.10.16 (main, Dec  3 2024, 17:27:57) [Clang 16.0.6 ] darwin

Reason and Possible fixes

I suspect this issue boils down to the following (a possible workaround is sketched after the list):

  • limitations of the mps device, which does not work well with pickle-based sharing of parameters: as the traceback shows, the storage-sharing path used here is only implemented for CPU
  • limitations of torchrl, which assumes a spawn-based multiprocessing context
    • as opposed to a fork-based one: forcing fork through multiprocessing.set_start_method('fork') emits a warning and makes the collectors crash
    • the spawn context is imposed by torchrl via mp.set_start_method("spawn")
  • the spawn start method uses pickle to copy the state of the parent process into the newly spawned one, which is what triggers the error above
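
As a possible interim workaround (just a sketch, not a proper fix, and assuming it is acceptable to run the policy on CPU in the workers and train on MPS only in the main process): build the collector entirely on CPU so that nothing MPS-resident has to be pickled into the spawned workers, and move each collected batch to MPS afterwards:

from torchrl.envs.libs.gym import GymEnv
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import MultiSyncDataCollector

if __name__ == "__main__":
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    collector = MultiSyncDataCollector(
        create_env_fn=[env_maker, env_maker],
        policy=policy,            # policy stays on CPU, so pickling it for spawn works
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        device="cpu",             # avoid sharing MPS storages across processes
        storing_device="cpu",
    )
    for i, data in enumerate(collector):
        data = data.to("mps")     # move the batch to MPS in the main process
        if i == 2:
            print(data)
            break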

Checklist

  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • I have provided a minimal working example to reproduce the bug
@LCarmi LCarmi added the bug Something isn't working label Mar 18, 2025
@vmoens vmoens linked a pull request Mar 18, 2025 that will close this issue