1
1
import asyncio
2
2
import threading
3
3
from datetime import timedelta
4
- from typing import Optional , TypeVar
4
+ from typing import Callable , Optional , TypeVar
5
5
from unittest .mock import Mock
6
6
7
+ import torch
7
8
from torch .futures import Future
8
9
9
10
T = TypeVar ("T" )
@@ -17,7 +18,6 @@ def __init__(self) -> None:
17
18
18
19
def set_timer(self, timer_handle: asyncio.TimerHandle) -> None:
    """Store the scheduled timer handle and release this object's lock.

    The lock in ``self._lock`` must already be held when this is called;
    it is released once ``self._timer_handle`` has been assigned.

    Args:
        timer_handle: The asyncio timer handle to remember on this object.

    Raises:
        AssertionError: If ``self._lock`` is not currently locked.
    """
    # Entry precondition: the lock is held, so the assignment below is
    # published before anyone else can acquire it.
    assert self._lock.locked()
    self._timer_handle = timer_handle
    self._lock.release()
23
23
@@ -99,6 +99,18 @@ def callback(fut: Future[T]) -> None:
99
99
fut .add_done_callback (callback )
100
100
return timed_fut
101
101
102
def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
    """Call ``callback`` if the current CUDA stream is still busy after ``timeout``.

    A CUDA event is recorded when this method is called. After ``timeout``
    has elapsed, a handler scheduled on the event loop queries that event
    and invokes ``callback`` only if the event has not completed yet.

    Args:
        callback: Zero-argument callable to run when the recorded event is
            still pending once the timeout fires.
        timeout: How long to wait before checking the event.
    """
    loop = self._maybe_start_event_loop()

    # Marker for "all work queued so far"; record() with no argument
    # records on the current stream.
    event = torch.cuda.Event()
    event.record()

    def handler() -> None:
        # query() stays False while the work captured by the event is
        # still outstanding, which is the timeout condition.
        if not event.query():
            callback()

    # The delayed call is registered from the loop's own thread, so hand
    # the registration over with the thread-safe scheduling call.
    loop.call_soon_threadsafe(self._register_handler, loop, handler, timeout)
113
+
102
114
@classmethod
103
115
def _register (
104
116
cls ,
@@ -116,6 +128,18 @@ def _register(
116
128
)
117
129
handle .set_timer (timer_handle )
118
130
131
@classmethod
def _register_handler(
    cls,
    loop: asyncio.AbstractEventLoop,
    handler: Callable[[], None],
    timeout: timedelta,
) -> None:
    """Schedule ``handler`` to run on ``loop`` once ``timeout`` has elapsed.

    Args:
        loop: Event loop on which the delayed call is registered.
        handler: Zero-argument callable to invoke when the delay expires.
        timeout: Delay before ``handler`` runs; converted to seconds for
            ``loop.call_later``.
    """
    loop.call_later(timeout.total_seconds(), handler)
142
+
119
143
120
144
# Shared module-level _TimeoutManager instance; the module-level
# stream_timeout() helper delegates to it.
_TIMEOUT_MANAGER = _TimeoutManager ()
121
145
@@ -163,3 +187,18 @@ def callback(fut: Future[T]) -> T:
163
187
raise TimeoutError (f"future did not complete within { timeout } " )
164
188
165
189
return fut .wait ()
190
+
191
+
192
def stream_timeout(callback: Callable[[], None], timeout: timedelta) -> None:
    """
    Registers a callback that will be called after the specified timeout if
    the current stream doesn't complete in time.

    This uses a CUDA event to track the completion of the current stream. If
    the stream is not complete after the timeout, the callback is called.

    Args:
        callback: The callback to call if the stream doesn't complete in time.
        timeout: The timeout to wait for the stream to complete.
    """
    # Delegate to the shared module-level manager instance.
    _TIMEOUT_MANAGER.stream_timeout(callback, timeout)
0 commit comments