From b5f45eb2e9c7a3bcbbaf4b6537335242ff516441 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 10 Mar 2025 15:38:19 -0700 Subject: [PATCH 1/2] Cleanup process group pipe shutdown --- torchft/process_group.py | 76 +++++++++++++++++++++++++++++----------- 1 file changed, 56 insertions(+), 20 deletions(-) diff --git a/torchft/process_group.py b/torchft/process_group.py index 00a0c4a..aa55367 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -16,6 +16,7 @@ runtime users need to take care to not assume a static rank or world size. """ +import atexit import logging import threading from contextlib import contextmanager, nullcontext @@ -75,7 +76,7 @@ logger: logging.Logger = logging.getLogger(__name__) # TODO: use non strings which are cheaper -_QUEUE_CLOSE = "queue_close" +_PIPE_CLOSE = "pipe_close" _FUTURE_RESULT = "fut_result" _FUTURE_EXCEPTION = "fut_exception" @@ -940,36 +941,67 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None: self._timeout: float = timeout + # Register the shutdown method to be called at exit + atexit.register(self.shutdown) + def shutdown(self) -> None: """ Shutdown the process group. This will kill the underlying process and - close all queues. + close all pipes. This is a no-op if the process group is already shutdown. ProcessGroup can be reconfigured after shutdown. 
""" - + # Close the future pipe first + if self._future_pipe is not None: + # close future thread + self._future_pipe.send((-1, _PIPE_CLOSE, None, None)) + assert self._future_pipe is not None + self._future_pipe.close() + self._future_pipe = None + # Join the future thread after closing its pipe + if self._future_thread is not None: + self._future_thread.join(timeout=10.0) + assert self._future_thread is not None + if self._future_thread.is_alive(): + raise RuntimeError("Future thread did not exit") + self._future_thread = None + # Close the request pipe to signal the worker process to exit if self._pipe is not None: + self._pipe.send((_PIPE_CLOSE,)) + assert self._pipe is not None self._pipe.close() - - future_pipe = self._future_pipe - if future_pipe is not None: - # wait for the future thread to exit and then close the queue - future_pipe.close() - - future_thread = self._future_thread - assert future_thread is not None - - future_thread.join(timeout=10.0) - if future_thread.is_alive(): - raise RuntimeError("future thread did not exit") - - # Kill after closing queues to avoid log spam. + self._pipe = None + # Terminate the worker process after closing its pipe if self._p is not None: - self._p.kill() + self._p.join(timeout=10.0) + assert self._p is not None + if self._p.is_alive(): + raise RuntimeError("Worker process did not exit") + self._p = None def configure(self, store_addr: str, rank: int, world_size: int) -> None: + """ + Structure + +-------------------+ + | | + | Main Process | (updates futures) + | | <--------------- + +-------------------+ | + | Pipe 1 | + v | + +-------------------+ +-------------------+ + | | | | + | Worker Process | -> | Future Thread | + | | Pipe 2 | | + +-------------------+ +-------------------+ + + Main Process: Maintains self._futures + Worker Process: Handles tasks, communicates with Future Thread. + Future Thread: Manages asynchronous tasks, updates self._futures. 
+ """ + self._world_size = world_size self.shutdown() @@ -990,7 +1022,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: rank, world_size, req_remote, - future_remote, + future_local, curr_device, ), daemon=True, @@ -1003,7 +1035,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: self._futures = {} self._future_thread = threading.Thread( target=self._future_handler, - args=(future_local,), + args=(_MonitoredPipe(future_remote),), daemon=True, ) self._future_thread.start() @@ -1049,6 +1081,8 @@ def _worker( while True: op = cast(list[object], req_pipe.recv()) cmd = op[0] + if cmd == _PIPE_CLOSE: + break if cmd == "func": op_id: int op_id, func_name, args, kwargs, stream_device, stream_id, event = ( @@ -1172,6 +1206,8 @@ def _future_handler(self, future_pipe: _MonitoredPipe) -> None: op_id, mode, data, event = cast( Tuple[int, str, object, Optional[torch.cuda.Event]], cmd ) + if mode == _PIPE_CLOSE: + break with self._futures_lock: fut = self._futures[op_id] del self._futures[op_id] From a03c35d18f2f92f99dd052ff412f730eba751d81 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 12 Mar 2025 07:34:41 -0700 Subject: [PATCH 2/2] rebase / fix comments --- torchft/process_group.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/torchft/process_group.py b/torchft/process_group.py index aa55367..8322fa8 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -949,31 +949,34 @@ def shutdown(self) -> None: Shutdown the process group. This will kill the underlying process and close all pipes. + We close the pipes by sending a _PIPE_CLOSE message from the writing end (local) + to the reading end (remote). The remote end will then exit its recv loop and we + will join the thread or process. + This is a no-op if the process group is already shutdown. ProcessGroup can be reconfigured after shutdown. 
""" - # Close the future pipe first + # Close the future pipe if self._future_pipe is not None: - # close future thread self._future_pipe.send((-1, _PIPE_CLOSE, None, None)) assert self._future_pipe is not None self._future_pipe.close() self._future_pipe = None - # Join the future thread after closing its pipe + # Join the future thread if self._future_thread is not None: self._future_thread.join(timeout=10.0) assert self._future_thread is not None if self._future_thread.is_alive(): raise RuntimeError("Future thread did not exit") self._future_thread = None - # Close the request pipe to signal the worker process to exit + # Close the process pipe if self._pipe is not None: self._pipe.send((_PIPE_CLOSE,)) assert self._pipe is not None self._pipe.close() self._pipe = None - # Terminate the worker process after closing its pipe + # Join the process if self._p is not None: self._p.join(timeout=10.0) assert self._p is not None @@ -997,9 +1000,8 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: | | Pipe 2 | | +-------------------+ +-------------------+ - Main Process: Maintains self._futures - Worker Process: Handles tasks, communicates with Future Thread. - Future Thread: Manages asynchronous tasks, updates self._futures. + Worker Process: Executes the collective operations. + Future Thread: Executes the user defined future callbacks. """ self._world_size = world_size