diff --git a/torchft/process_group.py b/torchft/process_group.py index 00a0c4a..8322fa8 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -16,6 +16,7 @@ runtime users need to take care to not assume a static rank or world size. """ +import atexit import logging import threading from contextlib import contextmanager, nullcontext @@ -75,7 +76,7 @@ logger: logging.Logger = logging.getLogger(__name__) # TODO: use non strings which are cheaper -_QUEUE_CLOSE = "queue_close" +_PIPE_CLOSE = "pipe_close" _FUTURE_RESULT = "fut_result" _FUTURE_EXCEPTION = "fut_exception" @@ -940,36 +941,69 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None: self._timeout: float = timeout + # Register the shutdown method to be called at exit + atexit.register(self.shutdown) + def shutdown(self) -> None: """ Shutdown the process group. This will kill the underlying process and - close all queues. + close all pipes. + + We close the pipes by sending a _PIPE_CLOSE message from the writing end (local) + to the reading end (remote). The remote end will then exit its recv loop and we + will join the thread or process. This is a no-op if the process group is already shutdown. ProcessGroup can be reconfigured after shutdown. 
""" - + # Close the future pipe + if self._future_pipe is not None: + self._future_pipe.send((-1, _PIPE_CLOSE, None, None)) + assert self._future_pipe is not None + self._future_pipe.close() + self._future_pipe = None + # Join the future thread + if self._future_thread is not None: + self._future_thread.join(timeout=10.0) + assert self._future_thread is not None + if self._future_thread.is_alive(): + raise RuntimeError("Future thread did not exit") + self._future_thread = None + # Close the process pipe if self._pipe is not None: + self._pipe.send((_PIPE_CLOSE,)) + assert self._pipe is not None self._pipe.close() - - future_pipe = self._future_pipe - if future_pipe is not None: - # wait for the future thread to exit and then close the queue - future_pipe.close() - - future_thread = self._future_thread - assert future_thread is not None - - future_thread.join(timeout=10.0) - if future_thread.is_alive(): - raise RuntimeError("future thread did not exit") - - # Kill after closing queues to avoid log spam. + self._pipe = None + # Join the process if self._p is not None: - self._p.kill() + self._p.join(timeout=10.0) + assert self._p is not None + if self._p.is_alive(): + raise RuntimeError("Worker process did not exit") + self._p = None def configure(self, store_addr: str, rank: int, world_size: int) -> None: + """ + Structure + +-------------------+ + | | + | Main Process | (updates futures) + | | <--------------- + +-------------------+ | + | Pipe 1 | + v | + +-------------------+ +-------------------+ + | | | | + | Worker Process | -> | Future Thread | + | | Pipe 2 | | + +-------------------+ +-------------------+ + + Worker Process: Executes the collective operations. + Future Thread: Executes the user defined future callbacks. 
+ """ + self._world_size = world_size self.shutdown() @@ -990,7 +1024,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: rank, world_size, req_remote, - future_remote, + future_local, curr_device, ), daemon=True, @@ -1003,7 +1037,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: self._futures = {} self._future_thread = threading.Thread( target=self._future_handler, - args=(future_local,), + args=(_MonitoredPipe(future_remote),), daemon=True, ) self._future_thread.start() @@ -1049,6 +1083,8 @@ def _worker( while True: op = cast(list[object], req_pipe.recv()) cmd = op[0] + if cmd == _PIPE_CLOSE: + break if cmd == "func": op_id: int op_id, func_name, args, kwargs, stream_device, stream_id, event = ( @@ -1172,6 +1208,8 @@ def _future_handler(self, future_pipe: _MonitoredPipe) -> None: op_id, mode, data, event = cast( Tuple[int, str, object, Optional[torch.cuda.Event]], cmd ) + if mode == _PIPE_CLOSE: + break with self._futures_lock: fut = self._futures[op_id] del self._futures[op_id]