From 45f9f35ed027d1d838ed5a20fa4c793bbaabf40e Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 31 Jan 2025 14:20:18 +0100 Subject: [PATCH 1/3] make context reset-able --- finegrain/src/finegrain/__init__.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/finegrain/src/finegrain/__init__.py b/finegrain/src/finegrain/__init__.py index 771abc3..e52635a 100644 --- a/finegrain/src/finegrain/__init__.py +++ b/finegrain/src/finegrain/__init__.py @@ -222,15 +222,12 @@ class EditorAPIContext: token: str | None logger: logging.Logger - max_sse_failures: int _client: httpx.AsyncClient | None _client_ctx_depth: int _sse_futures: Futures[dict[str, Any]] + _sse_source: ResilientEventSource _sse_task: asyncio.Task[None] | None - _sse_failures: int - _sse_last_event_id: str - _sse_retry_ms: int def __init__( self, @@ -248,15 +245,20 @@ def __init__( self.verify = verify self.default_timeout = default_timeout - self.token = None self.logger = logger + self._sse_source = ResilientEventSource(url=self.get_sub_url, verify=self.verify) + self.reset() + def reset(self) -> None: + self.token = None self._client = None self._client_ctx_depth = 0 - self._sse_futures = Futures() - self._sse_source = ResilientEventSource(self.get_sub_url, verify=self.verify) self._sse_task = None + try: + self._sse_source.reset() + except RuntimeError: # outside asyncio + pass async def __aenter__(self) -> httpx.AsyncClient: if self._client: From b99d583b21e7e025594dab7cbad44a98f4e35dc3 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 31 Jan 2025 14:20:32 +0100 Subject: [PATCH 2/3] add server ping interval --- finegrain/src/finegrain/__init__.py | 42 ++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/finegrain/src/finegrain/__init__.py b/finegrain/src/finegrain/__init__.py index e52635a..16924af 100644 --- a/finegrain/src/finegrain/__init__.py +++ b/finegrain/src/finegrain/__init__.py @@ -111,12 +111,23 @@ def success(self) -> None: self.failures = 0 +class TimeoutableAsyncIterator[T](AsyncIterator[T]): + def __init__(self, iterator: AsyncIterator[T], timeout: float) -> None: + self.iterator = iterator + self.timeout = timeout + + async def __anext__(self) -> T: + return await asyncio.wait_for(self.iterator.__anext__(), timeout=self.timeout) + + class ResilientEventSource: get_url: Callable[[], Awaitable[str]] + get_ping_interval: Callable[[], Awaitable[float]] verify: bool | str retry_ctx: RetryContext logger: logging.Logger + server_ping_grace_period: float _last_event_id: str _retry_ms: int @@ -126,14 +137,19 @@ class ResilientEventSource: def __init__( self, url: str | Callable[[], Awaitable[str]], + ping_interval: float | Callable[[], Awaitable[float]] = 0.0, verify: bool | str = True, retry_ctx: RetryContext | None = None, ) -> None: self.get_url = self.async_return(url) if isinstance(url, str) else url + if isinstance(ping_interval, int | float): + ping_interval = self.async_return(ping_interval) + self.get_ping_interval = ping_interval self.verify = verify self.retry_ctx = RetryContext() if retry_ctx is None else retry_ctx self.logger = logger + self.server_ping_grace_period = 3.0 def reset(self) -> None: self._last_event_id = "" @@ -188,13 +204,20 @@ async def __aiter__(self) -> AsyncIterator[dict[str, Any]]: ) await asyncio.sleep(self.retry_ctx.backoff + self._retry_ms / 1000) url = await self.get_url() + ping_interval = await self.get_ping_interval() + async with ( httpx.AsyncClient(timeout=None, verify=self.verify) as c, httpx_sse.aconnect_sse(c, "GET", url, headers=self.headers) as es, ): es.response.raise_for_status() self.success() - async for sse in es.aiter_sse(): + if ping_interval > 0: + timeout = ping_interval + self.server_ping_grace_period + it = TimeoutableAsyncIterator(es.aiter_sse(), timeout=timeout) + else: + it = es.aiter_sse() + async for sse in it: self._last_event_id = sse.id self._retry_ms = sse.retry or 0 if sse.event == "ping": @@ -208,7 +231,7 @@ async def __aiter__(self) -> AsyncIterator[dict[str, Any]]: continue yield event raise SSELoopStopped(message="SSE loop exited") - except (SSELoopStopped, httpx.HTTPError) as exc: + except (SSELoopStopped, httpx.HTTPError, TimeoutError) as exc: self.failure(exc) @@ -228,6 +251,7 @@ class EditorAPIContext: _sse_futures: Futures[dict[str, Any]] _sse_source: ResilientEventSource _sse_task: asyncio.Task[None] | None + _ping_interval: float def __init__( self, @@ -246,7 +270,11 @@ def __init__( self.default_timeout = default_timeout self.logger = logger - self._sse_source = ResilientEventSource(url=self.get_sub_url, verify=self.verify) + self._sse_source = ResilientEventSource( + url=self.get_sub_url, + ping_interval=self.get_ping_interval, + verify=self.verify, + ) self.reset() def reset(self) -> None: @@ -255,6 +283,7 @@ def reset(self) -> None: self._client_ctx_depth = 0 self._sse_futures = Futures() self._sse_task = None + self._ping_interval = 0.0 try: self._sse_source.reset() except RuntimeError: # outside asyncio @@ -326,9 +355,14 @@ async def _q() -> httpx.Response: async def get_sub_url(self) -> str: response = await self.request("POST", "sub-auth") - sub_token = response.json()["token"] + jdata = response.json() + sub_token = jdata["token"] + self._ping_interval = float(jdata.get("ping_interval", 0.0)) return f"{self.base_url}/sub/{sub_token}" + async def get_ping_interval(self) -> float: + return self._ping_interval + async def _sse_loop(self) -> None: async for event in self._sse_source: if "state" not in event: From 52234282800006d1ad5531d44f126d7fae25fcdc Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 31 Jan 2025 14:57:34 +0100 Subject: [PATCH 3/3] add last error in SSE retry log --- finegrain/src/finegrain/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/finegrain/src/finegrain/__init__.py b/finegrain/src/finegrain/__init__.py index 16924af..d677c02 100644 --- a/finegrain/src/finegrain/__init__.py +++ b/finegrain/src/finegrain/__init__.py @@ -200,7 +200,8 @@ async def __aiter__(self) -> AsyncIterator[dict[str, Any]]: if self.retry_ctx.failures > 0: self.logger.info( f"SSE loop retry attempt {self.retry_ctx.failures} " - f"(backoff {self.retry_ctx.backoff:.3f}, retry_ms {self._retry_ms})" + f"(backoff {self.retry_ctx.backoff:.3f}, retry_ms {self._retry_ms}, " + f"last error {self.retry_ctx.last_error})" ) await asyncio.sleep(self.retry_ctx.backoff + self._retry_ms / 1000) url = await self.get_url() @@ -232,6 +233,8 @@ async def __aiter__(self) -> AsyncIterator[dict[str, Any]]: yield event raise SSELoopStopped(message="SSE loop exited") except (SSELoopStopped, httpx.HTTPError, TimeoutError) as exc: + if isinstance(exc, TimeoutError) and not str(exc): + exc = TimeoutError("timeout") self.failure(exc)