Skip to content

Commit 7e78d65

Browse files
committed
use Coroutine in conn_retry
Signed-off-by: maximsmol <[email protected]>
1 parent 10b87f5 commit 7e78d65

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/latch_postgres/postgres.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import functools
33
import random
4-
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
4+
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterable
55
from contextlib import asynccontextmanager
66
from dataclasses import dataclass
77
from datetime import timedelta
@@ -52,6 +52,8 @@
5252

5353
from latch_postgres.retries import CABackoff
5454

55+
YT = TypeVar("YT")
56+
ST = TypeVar("ST")
5557
T = TypeVar("T")
5658

5759
tracer = get_tracer(__name__)
@@ -455,10 +457,10 @@ def pg_error_to_dict(x: PGError, *, short: bool = False):
455457

456458

457459
def with_conn_retry(
458-
f: Callable[Concatenate[LatchAsyncConnection[Any], P], Awaitable[T]],
460+
f: Callable[Concatenate[LatchAsyncConnection[Any], P], Coroutine[YT, ST, T]],
459461
pool: AsyncConnectionPool,
460462
db_config: PostgresConnectionConfig,
461-
) -> Callable[P, Awaitable[T]]:
463+
) -> Callable[P, Coroutine[YT, ST, T]]:
462464
@functools.wraps(f)
463465
async def inner(*args: P.args, **kwargs: P.kwargs):
464466
with tracer.start_as_current_span("database session") as s:
@@ -585,8 +587,8 @@ async def inner(*args: P.args, **kwargs: P.kwargs):
585587
def get_with_conn_retry(
586588
pool: AsyncConnectionPool, db_config: PostgresConnectionConfig
587589
) -> Callable[
588-
[Callable[Concatenate[LatchAsyncConnection[Any], P], Awaitable[T]]],
589-
Callable[P, Awaitable[T]],
590+
[Callable[Concatenate[LatchAsyncConnection[Any], P], Coroutine[YT, ST, T]]],
591+
Callable[P, Coroutine[YT, ST, T]],
590592
]:
591593
return functools.partial(with_conn_retry, pool=pool, db_config=db_config)
592594

0 commit comments

Comments
 (0)