Skip to content

Commit 8332c03

Browse files
committed
Updates for custom records and method updates
1 parent 86a92f9 commit 8332c03

11 files changed

+949
-190
lines changed

asyncpg/cluster.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,15 @@ def get_status(self) -> str:
130130

131131
async def connect(self,
132132
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
133-
**kwargs: typing.Any) -> 'connection.Connection':
133+
**kwargs: typing.Any) \
134+
-> 'connection.Connection[typing.Any]':
134135
conn_info = typing.cast(typing.Dict[str, typing.Any],
135136
self.get_connection_spec())
136137
conn_info.update(kwargs)
137-
return await asyncpg.connect(loop=loop, **conn_info)
138+
return typing.cast(
139+
'connection.Connection[typing.Any]',
140+
await asyncpg.connect(loop=loop, **conn_info)
141+
)
138142

139143
def init(self, **settings: str) -> str:
140144
"""Initialize cluster."""

asyncpg/connect_utils.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
_Connection = typing.TypeVar('_Connection')
3131
_Protocol = typing.TypeVar('_Protocol', bound=asyncio.Protocol)
32+
_Record = typing.TypeVar('_Record', bound=protocol.Record)
3233

3334
_TPTupleType = typing.Tuple[asyncio.WriteTransport, _Protocol]
3435
AddrType = typing.Union[typing.Tuple[str, int], str]
@@ -654,7 +655,7 @@ async def _connect_addr(
654655
params: _ConnectionParameters,
655656
config: _ClientConfiguration,
656657
connection_class: typing.Type[_Connection],
657-
record_class: typing.Any
658+
record_class: typing.Type[_Record]
658659
) -> _Connection:
659660
assert loop is not None
660661

@@ -680,7 +681,7 @@ async def _connect_addr(
680681
assert not params.ssl
681682
connector = typing.cast(
682683
typing.Coroutine[typing.Any, None,
683-
_TPTupleType[protocol.Protocol]],
684+
_TPTupleType['protocol.Protocol[_Record]']],
684685
loop.create_unix_connection(proto_factory, addr))
685686
elif params.ssl:
686687
connector = _create_ssl_connection(
@@ -689,7 +690,7 @@ async def _connect_addr(
689690
else:
690691
connector = typing.cast(
691692
typing.Coroutine[typing.Any, None,
692-
_TPTupleType[protocol.Protocol]],
693+
_TPTupleType['protocol.Protocol[_Record]']],
693694
loop.create_connection(proto_factory, *addr))
694695

695696
connector_future = asyncio.ensure_future(connector)
@@ -721,7 +722,7 @@ async def _connect(
721722
loop: typing.Optional[asyncio.AbstractEventLoop],
722723
timeout: float,
723724
connection_class: typing.Type[_Connection],
724-
record_class: typing.Any,
725+
record_class: typing.Type[_Record],
725726
**kwargs: typing.Any
726727
) -> _Connection:
727728
if loop is None:

0 commit comments

Comments
 (0)