Skip to content

Commit 86a92f9

Browse files
committed
Address feedback and fix runtime errors
1 parent 59b24b9 commit 86a92f9

File tree

3 files changed

+13
-12
lines changed

3 files changed

+13
-12
lines changed

asyncpg/cluster.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,10 @@ def get_status(self) -> str:
131131
async def connect(self,
132132
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
133133
**kwargs: typing.Any) -> 'connection.Connection':
134-
conn_info: typing.Optional[typing.Any] = self.get_connection_spec()
135-
conn_info.update(kwargs) # type: ignore[union-attr]
136-
return await asyncpg.connect(loop=loop, **conn_info) # type: ignore[misc] # noqa: E501
134+
conn_info = typing.cast(typing.Dict[str, typing.Any],
135+
self.get_connection_spec())
136+
conn_info.update(kwargs)
137+
return await asyncpg.connect(loop=loop, **conn_info)
137138

138139
def init(self, **settings: str) -> str:
139140
"""Initialize cluster."""
@@ -307,12 +308,17 @@ def _get_connection_spec(self) -> typing.Optional[_ConnectionSpec]:
307308

308309
return None
309310

310-
def get_connection_spec(self) -> typing.Optional[_ConnectionSpec]:
311+
def get_connection_spec(self) -> _ConnectionSpec:
311312
status = self.get_status()
312313
if status != 'running':
313314
raise ClusterError('cluster is not running')
314315

315-
return self._get_connection_spec()
316+
spec = self._get_connection_spec()
317+
318+
if spec is None:
319+
raise ClusterError('cannot determine server connection address')
320+
321+
return spec
316322

317323
def override_connection_spec(self, **kwargs: str) -> None:
318324
self._connection_spec_override = typing.cast(_ConnectionSpec, kwargs)

asyncpg/connection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@
4545
_AnyCallable = typing.Callable[..., typing.Any]
4646

4747
OutputType = typing.Union[typing.AnyStr,
48-
os.PathLike[typing.AnyStr],
48+
os.PathLike,
4949
typing.IO[typing.AnyStr],
5050
_Writer]
5151
SourceType = typing.Union[typing.AnyStr,
52-
os.PathLike[typing.AnyStr],
52+
os.PathLike,
5353
typing.IO[typing.AnyStr],
5454
typing.AsyncIterable[bytes]]
5555

tests/test_copy.py

-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import asyncpg
1515
from asyncpg import _testbase as tb
16-
from asyncpg import compat
1716

1817

1918
class TestCopyFrom(tb.ConnectedTestCase):
@@ -467,7 +466,6 @@ class _Source:
467466
def __init__(self):
468467
self.rowcount = 0
469468

470-
@compat.aiter_compat
471469
def __aiter__(self):
472470
return self
473471

@@ -507,7 +505,6 @@ class _Source:
507505
def __init__(self):
508506
self.rowcount = 0
509507

510-
@compat.aiter_compat
511508
def __aiter__(self):
512509
return self
513510

@@ -533,7 +530,6 @@ class _Source:
533530
def __init__(self):
534531
self.rowcount = 0
535532

536-
@compat.aiter_compat
537533
def __aiter__(self):
538534
return self
539535

@@ -564,7 +560,6 @@ def __init__(self, loop):
564560
self.rowcount = 0
565561
self.loop = loop
566562

567-
@compat.aiter_compat
568563
def __aiter__(self):
569564
return self
570565

0 commit comments

Comments
 (0)