|
1 | 1 | import asyncio
|
2 | 2 | import functools
|
3 | 3 | import random
|
4 |
| -from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable |
| 4 | +from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Iterable |
5 | 5 | from contextlib import asynccontextmanager
|
6 | 6 | from dataclasses import dataclass
|
7 | 7 | from datetime import timedelta
|
|
52 | 52 |
|
53 | 53 | from latch_postgres.retries import CABackoff
|
54 | 54 |
|
| 55 | +YT = TypeVar("YT") |
| 56 | +ST = TypeVar("ST") |
55 | 57 | T = TypeVar("T")
|
56 | 58 |
|
57 | 59 | tracer = get_tracer(__name__)
|
@@ -455,10 +457,10 @@ def pg_error_to_dict(x: PGError, *, short: bool = False):
|
455 | 457 |
|
456 | 458 |
|
457 | 459 | def with_conn_retry(
|
458 |
| - f: Callable[Concatenate[LatchAsyncConnection[Any], P], Awaitable[T]], |
| 460 | + f: Callable[Concatenate[LatchAsyncConnection[Any], P], Coroutine[YT, ST, T]], |
459 | 461 | pool: AsyncConnectionPool,
|
460 | 462 | db_config: PostgresConnectionConfig,
|
461 |
| -) -> Callable[P, Awaitable[T]]: |
| 463 | +) -> Callable[P, Coroutine[YT, ST, T]]: |
462 | 464 | @functools.wraps(f)
|
463 | 465 | async def inner(*args: P.args, **kwargs: P.kwargs):
|
464 | 466 | with tracer.start_as_current_span("database session") as s:
|
@@ -585,8 +587,8 @@ async def inner(*args: P.args, **kwargs: P.kwargs):
|
585 | 587 | def get_with_conn_retry(
|
586 | 588 | pool: AsyncConnectionPool, db_config: PostgresConnectionConfig
|
587 | 589 | ) -> Callable[
|
588 |
| - [Callable[Concatenate[LatchAsyncConnection[Any], P], Awaitable[T]]], |
589 |
| - Callable[P, Awaitable[T]], |
| 590 | + [Callable[Concatenate[LatchAsyncConnection[Any], P], Coroutine[YT, ST, T]]], |
| 591 | + Callable[P, Coroutine[YT, ST, T]], |
590 | 592 | ]:
|
591 | 593 | return functools.partial(with_conn_retry, pool=pool, db_config=db_config)
|
592 | 594 |
|
|
0 commit comments