Skip to content

Commit

Permalink
add server ping interval
Browse files Browse the repository at this point in the history
  • Loading branch information
catwell committed Jan 31, 2025
1 parent 133ed73 commit 58572c9
Showing 1 changed file with 38 additions and 4 deletions.
42 changes: 38 additions & 4 deletions finegrain/src/finegrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,23 @@ def success(self) -> None:
self.failures = 0


class TimeoutableAsyncIterator[T](AsyncIterator[T]):
def __init__(self, iterator: AsyncIterator[T], timeout: float) -> None:
self.iterator = iterator
self.timeout = timeout

async def __anext__(self) -> T:
return await asyncio.wait_for(self.iterator.__anext__(), timeout=self.timeout)


class ResilientEventSource:
get_url: Callable[[], Awaitable[str]]
get_ping_interval: Callable[[], Awaitable[float]]
verify: bool | str
retry_ctx: RetryContext

logger: logging.Logger
server_ping_grace_period: float

_last_event_id: str
_retry_ms: int
Expand All @@ -126,14 +137,19 @@ class ResilientEventSource:
def __init__(
self,
url: str | Callable[[], Awaitable[str]],
ping_interval: float | Callable[[], Awaitable[float]] = 0.0,
verify: bool | str = True,
retry_ctx: RetryContext | None = None,
) -> None:
self.get_url = self.async_return(url) if isinstance(url, str) else url
if isinstance(ping_interval, int | float):
ping_interval = self.async_return(ping_interval)
self.get_ping_interval = ping_interval
self.verify = verify
self.retry_ctx = RetryContext() if retry_ctx is None else retry_ctx

self.logger = logger
self.server_ping_grace_period = 3.0

def reset(self) -> None:
self._last_event_id = ""
Expand Down Expand Up @@ -188,13 +204,20 @@ async def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
)
await asyncio.sleep(self.retry_ctx.backoff + self._retry_ms / 1000)
url = await self.get_url()
ping_interval = await self.get_ping_interval()

async with (
httpx.AsyncClient(timeout=None, verify=self.verify) as c,
httpx_sse.aconnect_sse(c, "GET", url, headers=self.headers) as es,
):
es.response.raise_for_status()
self.success()
async for sse in es.aiter_sse():
if ping_interval > 0:
timeout = ping_interval + self.server_ping_grace_period
it = TimeoutableAsyncIterator(es.aiter_sse(), timeout=timeout)
else:
it = es.aiter_sse()
async for sse in it:
self._last_event_id = sse.id
self._retry_ms = sse.retry or 0
if sse.event == "ping":
Expand All @@ -208,7 +231,7 @@ async def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
continue
yield event
raise SSELoopStopped(message="SSE loop exited")
except (SSELoopStopped, httpx.HTTPError) as exc:
except (SSELoopStopped, httpx.HTTPError, TimeoutError) as exc:
self.failure(exc)


Expand All @@ -228,6 +251,7 @@ class EditorAPIContext:
_sse_futures: Futures[dict[str, Any]]
_sse_source: ResilientEventSource
_sse_task: asyncio.Task[None] | None
_ping_interval: float

def __init__(
self,
Expand All @@ -246,7 +270,11 @@ def __init__(
self.default_timeout = default_timeout

self.logger = logger
self._sse_source = ResilientEventSource(url=self.get_sub_url, verify=self.verify)
self._sse_source = ResilientEventSource(
url=self.get_sub_url,
ping_interval=self.get_ping_interval,
verify=self.verify,
)
self.reset()

def reset(self) -> None:
Expand All @@ -255,6 +283,7 @@ def reset(self) -> None:
self._client_ctx_depth = 0
self._sse_futures = Futures()
self._sse_task = None
self._ping_interval = 0.0
try:
self._sse_source.reset()
except RuntimeError: # outside asyncio
Expand Down Expand Up @@ -326,9 +355,14 @@ async def _q() -> httpx.Response:

async def get_sub_url(self) -> str:
response = await self.request("POST", "sub-auth")
sub_token = response.json()["token"]
jdata = response.json()
sub_token = jdata["token"]
self._ping_interval = float(jdata.get("ping_interval", 0.0))
return f"{self.base_url}/sub/{sub_token}"

async def get_ping_interval(self) -> float:
return self._ping_interval

async def _sse_loop(self) -> None:
async for event in self._sse_source:
if "state" not in event:
Expand Down

0 comments on commit 58572c9

Please sign in to comment.