Skip to content

Commit 1ab1381

Browse files
SSYernarfacebook-github-bot
authored andcommitted
Collection of return values in Multiprocess functions runner (#3235)
Summary: Pull Request resolved: #3235 Added functionality to collect return values in our multiprocess runner. Reviewed By: jd7-tr Differential Revision: D78939650 fbshipit-source-id: ae3b284519840e498a6c03bb27d2ba008b4d05b0
1 parent d945603 commit 1ab1381

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

torchrec/distributed/test_utils/multi_process.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -189,17 +189,27 @@ def _run_multi_process_test_per_rank(
189189
self.assertEqual(0, p.exitcode)
190190

191191

192+
def _wrapper_func_for_multiprocessing(args): # pyre-ignore[2, 3]
193+
"""Wrapper function that unpacks arguments and calls the original func"""
194+
func, rank, world_size, kwargs = args
195+
kwargs["rank"] = rank
196+
kwargs["world_size"] = world_size
197+
return func(**kwargs)
198+
199+
200+
# pyre-ignore[3]
192201
def run_multi_process_func(
202+
# pyre-ignore[2]
193203
func: Callable[
194204
[int, int, ...], # rank, world_size, ...
195-
None,
205+
Any, # Changed from None to Any to allow return values
196206
],
197207
multiprocessing_method: str = "spawn",
198208
use_deterministic_algorithms: bool = True,
199209
world_size: int = 2,
200210
# pyre-ignore
201211
**kwargs,
202-
) -> None:
212+
) -> List[Any]:
203213
""" """
204214
os.environ["MASTER_ADDR"] = str("localhost")
205215
os.environ["MASTER_PORT"] = str(get_free_port())
@@ -215,22 +225,16 @@ def run_multi_process_func(
215225
if world_size == 1:
216226
kwargs["world_size"] = 1
217227
kwargs["rank"] = 0
218-
func(**kwargs)
219-
return
228+
result = func(**kwargs)
229+
return [result]
230+
220231
ctx = multiprocessing.get_context(multiprocessing_method)
221-
processes = []
222-
for rank in range(world_size):
223-
kwargs["rank"] = rank
224-
kwargs["world_size"] = world_size
225-
p = ctx.Process(
226-
target=func,
227-
name=f"rank{rank}",
228-
kwargs=kwargs,
229-
)
230-
p.start()
231-
processes.append(p)
232232

233-
for p in processes:
234-
p.join()
235-
if p.exitcode != 0:
236-
print(p)
233+
# Prepare arguments for each process
234+
args_list = [(func, rank, world_size, kwargs.copy()) for rank in range(world_size)]
235+
236+
# Create a pool of worker processes for each rank
237+
with ctx.Pool(processes=world_size) as pool:
238+
results = pool.map(_wrapper_func_for_multiprocessing, args_list)
239+
240+
return results

0 commit comments

Comments
 (0)