Skip to content

[WIP] Fix pipe close warnings #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 58 additions & 20 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
runtime users need to take care to not assume a static rank or world size.
"""

import atexit
import logging
import threading
from contextlib import contextmanager, nullcontext
Expand Down Expand Up @@ -75,7 +76,7 @@
logger: logging.Logger = logging.getLogger(__name__)

# TODO: use non strings which are cheaper
_QUEUE_CLOSE = "queue_close"
_PIPE_CLOSE = "pipe_close"
_FUTURE_RESULT = "fut_result"
_FUTURE_EXCEPTION = "fut_exception"

Expand Down Expand Up @@ -940,36 +941,69 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:

self._timeout: float = timeout

# Register the shutdown method to be called at exit
atexit.register(self.shutdown)

def shutdown(self) -> None:
"""
Shutdown the process group. This will kill the underlying process and
close all queues.
close all pipes.

We close the pipes by sending a _PIPE_CLOSE message from the writing end (local)
to the reading end (remote). The remote end will then exit its recv loop and we
will join the thread or process.

This is a no-op if the process group is already shutdown.

ProcessGroup can be reconfigured after shutdown.
"""

# Close the future pipe
if self._future_pipe is not None:
self._future_pipe.send((-1, _PIPE_CLOSE, None, None))
assert self._future_pipe is not None
self._future_pipe.close()
self._future_pipe = None
# Join the future thread
if self._future_thread is not None:
self._future_thread.join(timeout=10.0)
assert self._future_thread is not None
if self._future_thread.is_alive():
raise RuntimeError("Future thread did not exit")
self._future_thread = None
# Close the process pipe
if self._pipe is not None:
self._pipe.send((_PIPE_CLOSE,))
assert self._pipe is not None
self._pipe.close()

future_pipe = self._future_pipe
if future_pipe is not None:
# wait for the future thread to exit and then close the pipe
future_pipe.close()

future_thread = self._future_thread
assert future_thread is not None

future_thread.join(timeout=10.0)
if future_thread.is_alive():
raise RuntimeError("future thread did not exit")

# Kill after closing pipes to avoid log spam.
self._pipe = None
# Join the process
if self._p is not None:
self._p.kill()
self._p.join(timeout=10.0)
assert self._p is not None
if self._p.is_alive():
raise RuntimeError("Worker process did not exit")
self._p = None

def configure(self, store_addr: str, rank: int, world_size: int) -> None:
"""
Structure
+-------------------+
| |
| Main Process | (updates futures)
| | <---------------
+-------------------+ |
| Pipe 1 |
v |
+-------------------+ +-------------------+
| | | |
| Worker Process | -> | Future Thread |
| | Pipe 2 | |
+-------------------+ +-------------------+

Worker Process: Executes the collective operations.
Future Thread: Executes the user defined future callbacks.
"""

self._world_size = world_size

self.shutdown()
Expand All @@ -990,7 +1024,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
rank,
world_size,
req_remote,
future_remote,
future_local,
curr_device,
),
daemon=True,
Expand All @@ -1003,7 +1037,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
self._futures = {}
self._future_thread = threading.Thread(
target=self._future_handler,
args=(future_local,),
args=(_MonitoredPipe(future_remote),),
daemon=True,
)
self._future_thread.start()
Expand Down Expand Up @@ -1049,6 +1083,8 @@ def _worker(
while True:
op = cast(list[object], req_pipe.recv())
cmd = op[0]
if cmd == _PIPE_CLOSE:
break
if cmd == "func":
op_id: int
op_id, func_name, args, kwargs, stream_device, stream_id, event = (
Expand Down Expand Up @@ -1172,6 +1208,8 @@ def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
op_id, mode, data, event = cast(
Tuple[int, str, object, Optional[torch.cuda.Event]], cmd
)
if mode == _PIPE_CLOSE:
break
with self._futures_lock:
fut = self._futures[op_id]
del self._futures[op_id]
Expand Down
Loading