Skip to content

Commit

Permalink
Merge pull request #11 from finegrain-ai/pr/server-ping
Browse files Browse the repository at this point in the history
add server ping
  • Loading branch information
catwell authored Feb 3, 2025
2 parents 73bb403 + 5223428 commit f3e84d8
Showing 1 changed file with 50 additions and 11 deletions.
61 changes: 50 additions & 11 deletions finegrain/src/finegrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,23 @@ def success(self) -> None:
self.failures = 0


class TimeoutableAsyncIterator[T](AsyncIterator[T]):
def __init__(self, iterator: AsyncIterator[T], timeout: float) -> None:
self.iterator = iterator
self.timeout = timeout

async def __anext__(self) -> T:
return await asyncio.wait_for(self.iterator.__anext__(), timeout=self.timeout)


class ResilientEventSource:
get_url: Callable[[], Awaitable[str]]
get_ping_interval: Callable[[], Awaitable[float]]
verify: bool | str
retry_ctx: RetryContext

logger: logging.Logger
server_ping_grace_period: float

_last_event_id: str
_retry_ms: int
Expand All @@ -126,14 +137,19 @@ class ResilientEventSource:
def __init__(
self,
url: str | Callable[[], Awaitable[str]],
ping_interval: float | Callable[[], Awaitable[float]] = 0.0,
verify: bool | str = True,
retry_ctx: RetryContext | None = None,
) -> None:
self.get_url = self.async_return(url) if isinstance(url, str) else url
if isinstance(ping_interval, int | float):
ping_interval = self.async_return(ping_interval)
self.get_ping_interval = ping_interval
self.verify = verify
self.retry_ctx = RetryContext() if retry_ctx is None else retry_ctx

self.logger = logger
self.server_ping_grace_period = 3.0

def reset(self) -> None:
self._last_event_id = ""
Expand Down Expand Up @@ -184,17 +200,25 @@ async def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
if self.retry_ctx.failures > 0:
self.logger.info(
f"SSE loop retry attempt {self.retry_ctx.failures} "
f"(backoff {self.retry_ctx.backoff:.3f}, retry_ms {self._retry_ms})"
f"(backoff {self.retry_ctx.backoff:.3f}, retry_ms {self._retry_ms}, "
f"last error {self.retry_ctx.last_error})"
)
await asyncio.sleep(self.retry_ctx.backoff + self._retry_ms / 1000)
url = await self.get_url()
ping_interval = await self.get_ping_interval()

async with (
httpx.AsyncClient(timeout=None, verify=self.verify) as c,
httpx_sse.aconnect_sse(c, "GET", url, headers=self.headers) as es,
):
es.response.raise_for_status()
self.success()
async for sse in es.aiter_sse():
if ping_interval > 0:
timeout = ping_interval + self.server_ping_grace_period
it = TimeoutableAsyncIterator(es.aiter_sse(), timeout=timeout)
else:
it = es.aiter_sse()
async for sse in it:
self._last_event_id = sse.id
self._retry_ms = sse.retry or 0
if sse.event == "ping":
Expand All @@ -208,7 +232,9 @@ async def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
continue
yield event
raise SSELoopStopped(message="SSE loop exited")
except (SSELoopStopped, httpx.HTTPError) as exc:
except (SSELoopStopped, httpx.HTTPError, TimeoutError) as exc:
if isinstance(exc, TimeoutError) and not str(exc):
exc = TimeoutError("timeout")
self.failure(exc)


Expand All @@ -222,15 +248,13 @@ class EditorAPIContext:

token: str | None
logger: logging.Logger
max_sse_failures: int

_client: httpx.AsyncClient | None
_client_ctx_depth: int
_sse_futures: Futures[dict[str, Any]]
_sse_source: ResilientEventSource
_sse_task: asyncio.Task[None] | None
_sse_failures: int
_sse_last_event_id: str
_sse_retry_ms: int
_ping_interval: float

def __init__(
self,
Expand All @@ -248,15 +272,25 @@ def __init__(
self.verify = verify
self.default_timeout = default_timeout

self.token = None
self.logger = logger
self._sse_source = ResilientEventSource(
url=self.get_sub_url,
ping_interval=self.get_ping_interval,
verify=self.verify,
)
self.reset()

def reset(self) -> None:
self.token = None
self._client = None
self._client_ctx_depth = 0

self._sse_futures = Futures()
self._sse_source = ResilientEventSource(self.get_sub_url, verify=self.verify)
self._sse_task = None
self._ping_interval = 0.0
try:
self._sse_source.reset()
except RuntimeError: # outside asyncio
pass

async def __aenter__(self) -> httpx.AsyncClient:
if self._client:
Expand Down Expand Up @@ -326,9 +360,14 @@ async def _q() -> httpx.Response:

async def get_sub_url(self) -> str:
response = await self.request("POST", "sub-auth")
sub_token = response.json()["token"]
jdata = response.json()
sub_token = jdata["token"]
self._ping_interval = float(jdata.get("ping_interval", 0.0))
return f"{self.base_url}/sub/{sub_token}"

async def get_ping_interval(self) -> float:
return self._ping_interval

async def _sse_loop(self) -> None:
async for event in self._sse_source:
if "state" not in event:
Expand Down

0 comments on commit f3e84d8

Please sign in to comment.