Skip to content

Commit 986e7b0

Browse files
committed
allow for cleanup task in effect for py<3.11
1 parent 3d81311 commit 986e7b0

File tree

2 files changed

+92
-15
lines changed

2 files changed

+92
-15
lines changed

src/py/reactpy/reactpy/core/hooks.py

+29-15
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import inspect
54
import sys
65
import warnings
7-
from collections.abc import Coroutine, Sequence
6+
from asyncio import CancelledError, Event, create_task
7+
from collections.abc import Awaitable, Coroutine, Sequence
88
from logging import getLogger
99
from types import FunctionType
1010
from typing import (
@@ -95,9 +95,10 @@ def dispatch(new: _Type | Callable[[_Type], _Type]) -> None:
9595
self.dispatch = dispatch
9696

9797

98+
_Coro = Coroutine[None, None, _Type]
9899
_EffectCleanFunc: TypeAlias = "Callable[[], None]"
99100
_SyncEffectFunc: TypeAlias = "Callable[[], _EffectCleanFunc | None]"
100-
_AsyncEffectFunc: TypeAlias = "Callable[[Effect], Coroutine[None, None, None]]"
101+
_AsyncEffectFunc: TypeAlias = "Callable[[Effect], _Coro[Awaitable[Any] | None]]"
101102
_EffectFunc: TypeAlias = "_SyncEffectFunc | _AsyncEffectFunc"
102103

103104

@@ -152,8 +153,7 @@ async def start_effect() -> StopEffect:
152153
if effect_ref.current is not None:
153154
await effect_ref.current.stop()
154155

155-
effect = effect_ref.current = Effect()
156-
effect.task = asyncio.create_task(effect_func(effect))
156+
effect = effect_ref.current = Effect(effect_func)
157157
await effect.started()
158158

159159
return effect.stop
@@ -170,26 +170,37 @@ async def start_effect() -> StopEffect:
170170
class Effect:
171171
"""A context manager for running asynchronous effects."""
172172

173-
task: asyncio.Task[Any]
174-
"""The task that is running the effect."""
175-
176-
def __init__(self) -> None:
177-
self._stop = asyncio.Event()
178-
self._started = asyncio.Event()
173+
def __init__(self, effect_func: _AsyncEffectFunc) -> None:
174+
self.task = create_task(effect_func(self))
175+
self._stop = Event()
176+
self._started = Event()
177+
self._stopped = Event()
179178
self._cancel_count = 0
180179

181180
async def stop(self) -> None:
182181
"""Signal the effect to stop."""
182+
if self._stop.is_set():
183+
await self._stopped.wait()
184+
return None
185+
183186
if self._started.is_set():
184187
self._cancel_task()
185188
self._stop.set()
186189
try:
187-
await self.task
188-
except asyncio.CancelledError:
190+
cleanup = await self.task
191+
except CancelledError:
189192
pass
190193
except Exception:
191194
logger.exception("Error while stopping effect")
192195

196+
if cleanup is not None:
197+
try:
198+
await cleanup
199+
except Exception:
200+
logger.exception("Error while cleaning up effect")
201+
202+
self._stopped.set()
203+
193204
async def started(self) -> None:
194205
"""Wait for the effect to start."""
195206
await self._started.wait()
@@ -205,6 +216,7 @@ async def __aenter__(self) -> Self:
205216

206217
if sys.version_info < (3, 11): # nocov
207218
# Python<3.11 doesn't have Task.cancelling so we need to track it ourselves.
219+
# Task.uncancel is a no-op since there's no way to backport the behavior.
208220

209221
async def __aenter__(self) -> Self:
210222
cancel_count = 0
@@ -217,20 +229,22 @@ def new_cancel(*a, **kw) -> None:
217229

218230
self.task.cancel = new_cancel
219231
self.task.cancelling = lambda: cancel_count
232+
self.task.uncancel = lambda: None
220233

221234
return await self._3_11__aenter__()
222235

223236
async def __aexit__(self, exc_type: type[BaseException], *exc: Any) -> Any:
224-
if exc_type is not None and not issubclass(exc_type, asyncio.CancelledError):
237+
if exc_type is not None and not issubclass(exc_type, CancelledError):
225238
# propagate non-cancellation exceptions
226239
return None
227240

228241
try:
229242
await self._stop.wait()
230-
except asyncio.CancelledError:
243+
except CancelledError:
231244
if self.task.cancelling() > self._cancel_count:
232245
# Task has been cancelled by something else - propagate it
233246
return None
247+
self.task.uncancel()
234248

235249
return True
236250

src/py/reactpy/tests/test_core/test_hooks.py

+63
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import sys
23

34
import pytest
45

@@ -537,6 +538,68 @@ async def effect(e):
537538
await asyncio.wait_for(cleanup_ran.wait(), 1)
538539

539540

541+
@pytest.mark.skipif(
542+
sys.version_info < (3, 11),
543+
reason="asyncio.Task.uncancel does not exist",
544+
)
545+
async def test_use_async_effect_with_await_in_cleanup():
546+
component_hook = HookCatcher()
547+
effect_ran = WaitForEvent()
548+
cleanup_ran = WaitForEvent()
549+
550+
async def cleanup_task():
551+
cleanup_ran.set()
552+
553+
@reactpy.component
554+
@component_hook.capture
555+
def ComponentWithAsyncEffect():
556+
@reactpy.use_effect(dependencies=None) # force this to run every time
557+
async def effect(e):
558+
async with e:
559+
effect_ran.set()
560+
await cleanup_task()
561+
562+
return reactpy.html.div()
563+
564+
async with reactpy.Layout(ComponentWithAsyncEffect()) as layout:
565+
await layout.render()
566+
567+
component_hook.latest.schedule_render()
568+
569+
await layout.render()
570+
571+
await asyncio.wait_for(cleanup_ran.wait(), 1)
572+
573+
574+
async def test_use_async_effect_cleanup_task():
575+
component_hook = HookCatcher()
576+
effect_ran = WaitForEvent()
577+
cleanup_ran = WaitForEvent()
578+
579+
async def cleanup_task():
580+
cleanup_ran.set()
581+
582+
@reactpy.component
583+
@component_hook.capture
584+
def ComponentWithAsyncEffect():
585+
@reactpy.use_effect(dependencies=None) # force this to run every time
586+
async def effect(e):
587+
async with e:
588+
effect_ran.set()
589+
return cleanup_task()
590+
591+
return reactpy.html.div()
592+
593+
async with reactpy.Layout(ComponentWithAsyncEffect()) as layout:
594+
await layout.render()
595+
596+
component_hook.latest.schedule_render()
597+
598+
await layout.render()
599+
600+
await asyncio.wait_for(cleanup_ran.wait(), 1)
601+
602+
540603
async def test_use_async_effect_cancel(caplog):
541604
component_hook = HookCatcher()
542605
effect_ran = WaitForEvent()

0 commit comments

Comments
 (0)