Skip to content

Commit c39bd60

Browse files
yvonne-labfacebook-github-bot
authored andcommitted
Use spawn instead of forkserver for MTIA (#2691)
Summary: Pull Request resolved: #2691 This is to support multiprocessing for MTIA. MTIA runtime doesn't seem to work with forkserver; folly::Singleton cannot work with fork. Switch to `spawn` when running on MTIA. Error paste: P1684403488 Reviewed By: qcyuan Differential Revision: D66351758 fbshipit-source-id: eee0bc4bb7a0fa527d0b318c8b5fac18564fe05e
1 parent f52fd32 commit c39bd60

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

torchrec/distributed/test_utils/multi_process.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#!/usr/bin/env python3
1111

12+
import logging
1213
import multiprocessing
1314
import os
1415
import unittest
@@ -24,11 +25,6 @@
2425
)
2526

2627

27-
# AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
28-
# Therefore we use spawn for HIP runtime until AMD fixes the issue
29-
_MP_INIT_MODE = "forkserver" if torch.version.hip is None else "spawn"
30-
31-
3228
class MultiProcessContext:
3329
def __init__(
3430
self,
@@ -98,6 +94,15 @@ def __exit__(self, exc_type, exc_instance, traceback) -> None:
9894

9995

10096
class MultiProcessTestBase(unittest.TestCase):
97+
def __init__(
98+
self, methodName: str = "runTest", mp_init_mode: str = "forkserver"
99+
) -> None:
100+
super().__init__(methodName)
101+
102+
# AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
103+
# Therefore we use spawn for HIP runtime until AMD fixes the issue
104+
self._mp_init_mode: str = mp_init_mode if torch.version.hip is None else "spawn"
105+
logging.info(f"Using {self._mp_init_mode} for multiprocessing")
101106

102107
@seed_and_log
103108
def setUp(self) -> None:
@@ -131,7 +136,7 @@ def _run_multi_process_test(
131136
# pyre-ignore
132137
**kwargs,
133138
) -> None:
134-
ctx = multiprocessing.get_context(_MP_INIT_MODE)
139+
ctx = multiprocessing.get_context(self._mp_init_mode)
135140
processes = []
136141
for rank in range(world_size):
137142
kwargs["rank"] = rank
@@ -157,7 +162,7 @@ def _run_multi_process_test_per_rank(
157162
world_size: int,
158163
kwargs_per_rank: List[Dict[str, Any]],
159164
) -> None:
160-
ctx = multiprocessing.get_context(_MP_INIT_MODE)
165+
ctx = multiprocessing.get_context(self._mp_init_mode)
161166
processes = []
162167
for rank in range(world_size):
163168
kwargs = {}

0 commit comments

Comments
 (0)