1
1
from __future__ import annotations
2
2
3
- import asyncio
4
3
import inspect
5
4
import sys
6
5
import warnings
7
- from collections .abc import Coroutine , Sequence
6
+ from asyncio import CancelledError , Event , create_task
7
+ from collections .abc import Awaitable , Coroutine , Sequence
8
8
from logging import getLogger
9
9
from types import FunctionType
10
10
from typing import (
@@ -95,9 +95,10 @@ def dispatch(new: _Type | Callable[[_Type], _Type]) -> None:
95
95
self .dispatch = dispatch
96
96
97
97
98
+ _Coro = Coroutine [None , None , _Type ]
98
99
_EffectCleanFunc : TypeAlias = "Callable[[], None]"
99
100
_SyncEffectFunc : TypeAlias = "Callable[[], _EffectCleanFunc | None]"
100
- _AsyncEffectFunc : TypeAlias = "Callable[[Effect], Coroutine[None, None, None]]"
101
+ _AsyncEffectFunc : TypeAlias = "Callable[[Effect], _Coro[Awaitable[Any] | None]]"
101
102
_EffectFunc : TypeAlias = "_SyncEffectFunc | _AsyncEffectFunc"
102
103
103
104
@@ -152,8 +153,7 @@ async def start_effect() -> StopEffect:
152
153
if effect_ref .current is not None :
153
154
await effect_ref .current .stop ()
154
155
155
- effect = effect_ref .current = Effect ()
156
- effect .task = asyncio .create_task (effect_func (effect ))
156
+ effect = effect_ref .current = Effect (effect_func )
157
157
await effect .started ()
158
158
159
159
return effect .stop
@@ -170,26 +170,37 @@ async def start_effect() -> StopEffect:
170
170
class Effect :
171
171
"""A context manager for running asynchronous effects."""
172
172
173
- task : asyncio .Task [Any ]
174
- """The task that is running the effect."""
175
-
176
- def __init__ (self ) -> None :
177
- self ._stop = asyncio .Event ()
178
- self ._started = asyncio .Event ()
173
+ def __init__ (self , effect_func : _AsyncEffectFunc ) -> None :
174
+ self .task = create_task (effect_func (self ))
175
+ self ._stop = Event ()
176
+ self ._started = Event ()
177
+ self ._stopped = Event ()
179
178
self ._cancel_count = 0
180
179
181
180
async def stop (self ) -> None :
182
181
"""Signal the effect to stop."""
182
+ if self ._stop .is_set ():
183
+ await self ._stopped .wait ()
184
+ return None
185
+
183
186
if self ._started .is_set ():
184
187
self ._cancel_task ()
185
188
self ._stop .set ()
186
189
try :
187
- await self .task
188
- except asyncio . CancelledError :
190
+ cleanup = await self .task
191
+ except CancelledError :
189
192
pass
190
193
except Exception :
191
194
logger .exception ("Error while stopping effect" )
192
195
196
+ if cleanup is not None :
197
+ try :
198
+ await cleanup
199
+ except Exception :
200
+ logger .exception ("Error while cleaning up effect" )
201
+
202
+ self ._stopped .set ()
203
+
193
204
async def started (self ) -> None :
194
205
"""Wait for the effect to start."""
195
206
await self ._started .wait ()
@@ -205,6 +216,7 @@ async def __aenter__(self) -> Self:
205
216
206
217
if sys .version_info < (3 , 11 ): # nocov
207
218
# Python<3.11 doesn't have Task.cancelling so we need to track it ourselves.
219
+ # Task.uncancel is a no-op since there's no way to backport the behavior.
208
220
209
221
async def __aenter__ (self ) -> Self :
210
222
cancel_count = 0
@@ -217,20 +229,22 @@ def new_cancel(*a, **kw) -> None:
217
229
218
230
self .task .cancel = new_cancel
219
231
self .task .cancelling = lambda : cancel_count
232
+ self .task .uncancel = lambda : None
220
233
221
234
return await self ._3_11__aenter__ ()
222
235
223
236
async def __aexit__ (self , exc_type : type [BaseException ], * exc : Any ) -> Any :
224
- if exc_type is not None and not issubclass (exc_type , asyncio . CancelledError ):
237
+ if exc_type is not None and not issubclass (exc_type , CancelledError ):
225
238
# propagate non-cancellation exceptions
226
239
return None
227
240
228
241
try :
229
242
await self ._stop .wait ()
230
- except asyncio . CancelledError :
243
+ except CancelledError :
231
244
if self .task .cancelling () > self ._cancel_count :
232
245
# Task has been cancelled by something else - propagate it
233
246
return None
247
+ self .task .uncancel ()
234
248
235
249
return True
236
250
0 commit comments