Skip to content

Commit 21e2367

Browse files
committed
abort PG on error
1 parent 2ab329e commit 21e2367

File tree

3 files changed

+313
-155
lines changed

3 files changed

+313
-155
lines changed

Diff for: torchft/futures.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import asyncio
22
import threading
33
from datetime import timedelta
4-
from typing import Optional, TypeVar
4+
from typing import Callable, Optional, TypeVar
55
from unittest.mock import Mock
66

7+
import torch
78
from torch.futures import Future
89

910
T = TypeVar("T")
@@ -17,7 +18,6 @@ def __init__(self) -> None:
1718

1819
def set_timer(self, timer_handle: asyncio.TimerHandle) -> None:
1920
assert self._lock.locked()
20-
2121
self._timer_handle = timer_handle
2222
self._lock.release()
2323

@@ -99,6 +99,18 @@ def callback(fut: Future[T]) -> None:
9999
fut.add_done_callback(callback)
100100
return timed_fut
101101

102+
def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
103+
loop = self._maybe_start_event_loop()
104+
105+
event = torch.cuda.Event()
106+
event.record()
107+
108+
def handler() -> None:
109+
if not event.query():
110+
callback()
111+
112+
loop.call_soon_threadsafe(self._register_handler, loop, handler, timeout)
113+
102114
@classmethod
103115
def _register(
104116
cls,
@@ -116,6 +128,18 @@ def _register(
116128
)
117129
handle.set_timer(timer_handle)
118130

131+
@classmethod
132+
def _register_handler(
133+
cls,
134+
loop,
135+
handler: Callable[[], None],
136+
timeout: timedelta,
137+
) -> None:
138+
loop.call_later(
139+
timeout.total_seconds(),
140+
handler,
141+
)
142+
119143

120144
_TIMEOUT_MANAGER = _TimeoutManager()
121145

@@ -163,3 +187,18 @@ def callback(fut: Future[T]) -> T:
163187
raise TimeoutError(f"future did not complete within {timeout}")
164188

165189
return fut.wait()
190+
191+
192+
def stream_timeout(callback: Callable[[], None], timeout: timedelta) -> None:
193+
"""
194+
Registers a callback that will be called after the specified timeout if
195+
the current stream doesn't complete in time.
196+
197+
This uses a CUDA event to track completion of the work queued on the current stream. If
198+
the stream is not complete after the timeout, the callback is called.
199+
200+
Args:
201+
callback: The callback to call if the stream doesn't complete in time.
202+
timeout: The timeout to wait for the stream to complete.
203+
"""
204+
_TIMEOUT_MANAGER.stream_timeout(callback, timeout)

Diff for: torchft/manager.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from concurrent.futures import ThreadPoolExecutor
3434
from datetime import timedelta
3535
from enum import Enum
36-
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
36+
from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar
3737

3838
import torch
3939
from torch.distributed import ReduceOp, TCPStore
@@ -477,7 +477,7 @@ def _async_quorum(
477477
self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
478478
self._quorum_id = quorum_id
479479

480-
if allow_heal:
480+
if allow_heal and False:
481481
if quorum.recover_dst_ranks:
482482
self._logger.info(
483483
f"peers need recovery from us {quorum.recover_dst_ranks}"

0 commit comments

Comments
 (0)