Skip to content

Commit d449a6f

Browse files
committed
Use specific ignore comments and updates for mypy 0.780
1 parent 38b8c60 commit d449a6f

9 files changed

+56
-48
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ docs/_build
3434
/.eggs
3535
/.vscode
3636
/.mypy_cache
37-
/.venv
37+
/.venv*
3838
/.ci
3939
/.vim

asyncpg/cluster.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,15 @@ def get_status(self) -> str:
126126
return self._test_connection(timeout=0)
127127
else:
128128
raise ClusterError(
129-
'pg_ctl status exited with status {:d}: {}'.format(
129+
'pg_ctl status exited with status {:d}: {}'.format( # type: ignore[str-bytes-safe] # noqa: E501
130130
process.returncode, stderr))
131131

132132
async def connect(self,
133133
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
134134
**kwargs: typing.Any) -> 'connection.Connection':
135135
conn_info = self.get_connection_spec() # type: typing.Optional[typing.Any] # noqa: E501
136-
conn_info.update(kwargs)
137-
return await asyncpg.connect(loop=loop, **conn_info)
136+
conn_info.update(kwargs) # type: ignore[union-attr]
137+
return await asyncpg.connect(loop=loop, **conn_info) # type: ignore[misc] # noqa: E501
138138

139139
def init(self, **settings: str) -> str:
140140
"""Initialize cluster."""
@@ -301,7 +301,7 @@ def _get_connection_spec(self) -> typing.Optional[_ConnectionSpec]:
301301
if self._connection_addr is not None:
302302
if self._connection_spec_override:
303303
args = self._connection_addr.copy()
304-
args.update(self._connection_spec_override) # type: ignore
304+
args.update(self._connection_spec_override) # type: ignore[arg-type] # noqa: E501
305305
return args
306306
else:
307307
return self._connection_addr
@@ -401,7 +401,7 @@ def add_hba_entry(self, *, type: str = 'host',
401401

402402
if auth_options is not None:
403403
record += ' ' + ' '.join(
404-
'{}={}'.format(k, v) for k, v in auth_options)
404+
'{}={}'.format(k, v) for k, v in auth_options.items())
405405

406406
try:
407407
with open(pg_hba, 'a') as f:
@@ -516,7 +516,7 @@ def _test_connection(self, timeout: int = 60) -> str:
516516

517517
try:
518518
con = loop.run_until_complete(
519-
asyncpg.connect(database='postgres',
519+
asyncpg.connect(database='postgres', # type: ignore[misc] # noqa: E501
520520
user='postgres',
521521
timeout=5, loop=loop,
522522
**self._connection_addr))
@@ -544,7 +544,7 @@ def _run_pg_config(self, pg_config_path: str) -> typing.Dict[str, str]:
544544
stdout, stderr = process.stdout, process.stderr
545545

546546
if process.returncode != 0:
547-
raise ClusterError('pg_config exited with status {:d}: {}'.format(
547+
raise ClusterError('pg_config exited with status {:d}: {}'.format( # type: ignore[str-bytes-safe] # noqa: E501
548548
process.returncode, stderr))
549549
else:
550550
config = {}
@@ -601,7 +601,7 @@ def _get_pg_version(self) -> 'types.ServerVersion':
601601

602602
if process.returncode != 0:
603603
raise ClusterError(
604-
'postgres --version exited with status {:d}: {}'.format(
604+
'postgres --version exited with status {:d}: {}'.format( # type: ignore[str-bytes-safe] # noqa: E501
605605
process.returncode, stderr))
606606

607607
version_string = stdout.decode('utf-8').strip(' \n')

asyncpg/compat.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def wrapper(self: typing.Any) -> typing.Any:
3434
return func(self)
3535
return typing.cast(_F_35, wrapper)
3636
else:
37-
def aiter_compat(func: _F) -> _F: # type: ignore
37+
def aiter_compat(func: _F) -> _F: # type: ignore[misc]
3838
return func
3939

4040

@@ -88,7 +88,7 @@ def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
8888
# home directory, whereas Postgres stores its config in
8989
# %AppData% on Windows.
9090
buf = ctypes.create_unicode_buffer(ctypes.wintypes.MAX_PATH)
91-
r = ctypes.windll.shell32.SHGetFolderPathW(0, CSIDL_APPDATA, 0, 0, buf) # type: ignore # noqa: E501
91+
r = ctypes.windll.shell32.SHGetFolderPathW(0, CSIDL_APPDATA, 0, 0, buf) # type: ignore[attr-defined] # noqa: E501
9292
if r:
9393
return None
9494
else:

asyncpg/connect_utils.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
317317
if 'sslmode' in query_str:
318318
val_str = query_str.pop('sslmode')
319319
if ssl is None:
320-
ssl = val_str
320+
ssl = val_str # type: ignore[assignment]
321321

322322
if query_str:
323323
if server_settings is None:
@@ -392,17 +392,17 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
392392
if passfile is None:
393393
homedir = compat.get_pg_home_directory()
394394
if homedir:
395-
passfile = homedir / PGPASSFILE
395+
passfile = homedir / PGPASSFILE # type: ignore[assignment]
396396
else:
397397
passfile = None
398398
else:
399-
passfile = pathlib.Path(passfile)
399+
passfile = pathlib.Path(passfile) # type: ignore[assignment]
400400

401401
if passfile is not None:
402402
password = _read_password_from_pgpass(
403403
hosts=auth_hosts, ports=port,
404404
database=database, user=user,
405-
passfile=passfile)
405+
passfile=passfile) # type: ignore[arg-type]
406406

407407
addrs = [] # type: typing.List[AddrType]
408408
for h, p in zip(host, port):
@@ -420,7 +420,7 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
420420
'could not determine the database address to connect to')
421421

422422
if ssl is None:
423-
ssl = os.getenv('PGSSLMODE')
423+
ssl = os.getenv('PGSSLMODE') # type: ignore[assignment]
424424

425425
# ssl_is_advisory is only allowed to come from the sslmode parameter.
426426
ssl_is_advisory = None
@@ -594,7 +594,7 @@ async def _create_ssl_connection(protocol_factory: typing.Callable[[],
594594
typing.Tuple[asyncio.WriteTransport, TLSUpgradeProto],
595595
await loop.create_connection(
596596
lambda: TLSUpgradeProto(loop, host, port,
597-
typing.cast(ssl_module.SSLContext,
597+
typing.cast(typing.Any,
598598
ssl_context),
599599
ssl_is_advisory),
600600
host, port))
@@ -614,7 +614,7 @@ async def _create_ssl_connection(protocol_factory: typing.Callable[[],
614614
asyncio.WriteTransport,
615615
await typing.cast(typing.Any, loop).start_tls(
616616
tr, pr,
617-
typing.cast(ssl_module.SSLContext, ssl_context),
617+
typing.cast(typing.Any, ssl_context),
618618
server_hostname=host))
619619
except (Exception, asyncio.CancelledError):
620620
tr.close()
@@ -711,7 +711,7 @@ async def _connect_addr(*, addr: AddrType,
711711
tr.close()
712712
raise
713713

714-
con = connection_class(pr, tr, loop, addr, config, # type: ignore
714+
con = connection_class(pr, tr, loop, addr, config, # type: ignore[call-arg] # noqa: E501
715715
params_input)
716716
pr.set_connection(con)
717717
return con
@@ -805,7 +805,7 @@ def _set_nodelay(sock: typing.Any) -> None:
805805
def _create_future(loop: typing.Optional[asyncio.AbstractEventLoop]) \
806806
-> 'asyncio.Future[typing.Any]':
807807
try:
808-
create_future = loop.create_future # type: ignore
808+
create_future = loop.create_future # type: ignore[union-attr]
809809
except AttributeError:
810810
return asyncio.Future(loop=loop)
811811
else:

asyncpg/connection.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
_Connection = typing.TypeVar('_Connection', bound='Connection')
3939
_Writer = typing.Callable[[bytes],
4040
typing.Coroutine[typing.Any, typing.Any, None]]
41+
_Record = typing.TypeVar('_Record', bound='_cprotocol.Record')
4142
_RecordsType = typing.List['_cprotocol.Record']
4243
_RecordsExtraType = typing.Tuple[_RecordsType, bytes, bool]
4344
_AnyCallable = typing.Callable[..., typing.Any]
@@ -447,7 +448,8 @@ async def _introspect_types(self, typeoids: typing.Set[int],
447448

448449
def cursor(self, query: str, *args: typing.Any,
449450
prefetch: typing.Optional[int] = None,
450-
timeout: typing.Optional[float] = None) -> cursor.CursorFactory:
451+
timeout: typing.Optional[float] = None) \
452+
-> 'cursor.CursorFactory[_cprotocol.Record]':
451453
"""Return a *cursor factory* for the specified query.
452454
453455
:param args: Query arguments.
@@ -463,7 +465,7 @@ def cursor(self, query: str, *args: typing.Any,
463465

464466
async def prepare(self, query: str, *,
465467
timeout: typing.Optional[float] = None) \
466-
-> prepared_stmt.PreparedStatement:
468+
-> prepared_stmt.PreparedStatement['_cprotocol.Record']:
467469
"""Create a *prepared statement* for the specified query.
468470
469471
:param str query: Text of the query to create a prepared statement for.
@@ -476,7 +478,7 @@ async def prepare(self, query: str, *,
476478
async def _prepare(self, query: str, *,
477479
timeout: typing.Optional[float] = None,
478480
use_cache: bool = False) \
479-
-> prepared_stmt.PreparedStatement:
481+
-> prepared_stmt.PreparedStatement['_cprotocol.Record']:
480482
self._check_open()
481483
stmt = await self._get_statement(query, timeout, named=True,
482484
use_cache=use_cache)
@@ -886,7 +888,7 @@ async def _copy_out(self, copy_stmt: str,
886888
output: OutputType[typing.AnyStr],
887889
timeout: typing.Optional[float]) -> str:
888890
try:
889-
path = compat.fspath(output) # type: typing.Optional[typing.AnyStr] # type: ignore # noqa: E501
891+
path = compat.fspath(output) # type: typing.Optional[typing.AnyStr] # type: ignore[arg-type] # noqa: E501
890892
except TypeError:
891893
# output is not a path-like object
892894
path = None
@@ -913,7 +915,7 @@ async def _copy_out(self, copy_stmt: str,
913915
)
914916

915917
if writer is None:
916-
async def _writer(data: bytes) -> None: # type: ignore
918+
async def _writer(data: bytes) -> None: # type: ignore[return]
917919
await run_in_executor(None, f.write, data)
918920

919921
writer = _writer
@@ -928,7 +930,7 @@ async def _copy_in(self, copy_stmt: str,
928930
source: SourceType[typing.AnyStr],
929931
timeout: typing.Optional[float]) -> str:
930932
try:
931-
path = compat.fspath(source) # type: typing.Optional[typing.AnyStr] # type: ignore # noqa: E501
933+
path = compat.fspath(source) # type: typing.Optional[typing.AnyStr] # type: ignore[arg-type] # noqa: E501
932934
except TypeError:
933935
# source is not a path-like object
934936
path = None
@@ -967,7 +969,7 @@ async def __anext__(self) -> bytes:
967969
if len(data) == 0:
968970
raise StopAsyncIteration
969971
else:
970-
return data # type: ignore
972+
return data # type: ignore[return-value]
971973

972974
reader = _Reader()
973975

@@ -1259,7 +1261,7 @@ def _abort(self) -> None:
12591261
# Put the connection into the aborted state.
12601262
self._aborted = True
12611263
self._protocol.abort()
1262-
self._protocol = None # type: ignore
1264+
self._protocol = None # type: ignore[assignment]
12631265

12641266
def _cleanup(self) -> None:
12651267
# Free the resources associated with this connection.
@@ -1352,7 +1354,7 @@ async def _cancel(self, waiter: 'asyncio.Future[None]') -> None:
13521354
waiter.set_exception(ex)
13531355
finally:
13541356
self._cancellations.discard(
1355-
compat.current_asyncio_task(self._loop))
1357+
compat.current_asyncio_task(self._loop)) # type: ignore[arg-type] # noqa: E501
13561358
if not waiter.done():
13571359
waiter.set_result(None)
13581360

@@ -1747,7 +1749,7 @@ async def connect(dsn: typing.Optional[str] = None, *,
17471749
max_cacheable_statement_size: int = 1024 * 15,
17481750
command_timeout: typing.Optional[float] = None,
17491751
ssl: typing.Optional[connect_utils.SSLType] = None,
1750-
connection_class: typing.Type[_Connection] = Connection, # type: ignore # noqa: E501
1752+
connection_class: typing.Type[_Connection] = Connection, # type: ignore[assignment] # noqa: E501
17511753
server_settings: typing.Optional[
17521754
typing.Dict[str, str]] = None) -> _Connection:
17531755
r"""A coroutine to establish a connection to a PostgreSQL server.
@@ -2180,15 +2182,15 @@ def _extract_stack(limit: int = 10) -> str:
21802182
frame = sys._getframe().f_back
21812183
try:
21822184
stack = traceback.StackSummary.extract(
2183-
traceback.walk_stack(frame), lookup_lines=False) # type: typing.Union[traceback.StackSummary, typing.List[traceback.FrameSummary]] # noqa: E501
2185+
traceback.walk_stack(frame), lookup_lines=False) # type: ignore[arg-type] # noqa: E501
21842186
finally:
21852187
del frame
21862188

2187-
apg_path = asyncpg.__path__[0]
2189+
apg_path = asyncpg.__path__[0] # type: ignore[attr-defined]
21882190
i = 0
21892191
while i < len(stack) and stack[i][0].startswith(apg_path):
21902192
i += 1
2191-
stack = stack[i:i + limit]
2193+
stack = stack[i:i + limit] # type: ignore[assignment]
21922194

21932195
stack.reverse()
21942196
return ''.join(traceback.format_list(stack))

asyncpg/cursor.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
_Record = typing.TypeVar('_Record', bound='_cprotocol.Record')
2525

2626

27-
class CursorFactory(connresource.ConnectionResource):
27+
class CursorFactory(connresource.ConnectionResource, typing.Generic[_Record]):
2828
"""A cursor interface for the results of a query.
2929
3030
A cursor interface can be used to initiate efficient traversal of the
@@ -49,7 +49,7 @@ def __init__(self, connection: '_connection.Connection', query: str,
4949

5050
@compat.aiter_compat
5151
@connresource.guarded
52-
def __aiter__(self) -> 'CursorIterator[_cprotocol.Record]':
52+
def __aiter__(self) -> 'CursorIterator[_Record]':
5353
prefetch = 50 if self._prefetch is None else self._prefetch
5454
return CursorIterator(self._connection,
5555
self._query, self._state,
@@ -58,13 +58,13 @@ def __aiter__(self) -> 'CursorIterator[_cprotocol.Record]':
5858

5959
@connresource.guarded
6060
def __await__(self) -> typing.Generator[
61-
typing.Any, None, 'Cursor[_cprotocol.Record]']:
61+
typing.Any, None, 'Cursor[_Record]']:
6262
if self._prefetch is not None:
6363
raise exceptions.InterfaceError(
6464
'prefetch argument can only be specified for iterable cursor')
6565
cursor = Cursor(self._connection, self._query,
6666
self._state,
67-
self._args) # type: Cursor[_cprotocol.Record]
67+
self._args) # type: Cursor[_Record]
6868
return cursor._init(self._timeout).__await__()
6969

7070
def __del__(self) -> None:
@@ -166,7 +166,7 @@ def __repr__(self) -> str:
166166

167167
return '<{}.{} "{!s:.30}" {}{:#x}>'.format(
168168
mod, self.__class__.__name__,
169-
self._state.query,
169+
self._state.query, # type: ignore[union-attr]
170170
' '.join(attrs), id(self))
171171

172172
def __del__(self) -> None:

asyncpg/pool.pyi

+3-2
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ class PoolConnectionProxy(connection._ConnectionProxy,
4343
*, timeout: typing.Optional[float] = ...) -> None: ...
4444
def cursor(self, query: str, *args: typing.Any,
4545
prefetch: typing.Optional[int] = ...,
46-
timeout: typing.Optional[float] = ...) -> cursor.CursorFactory: ...
46+
timeout: typing.Optional[float] = ...) \
47+
-> cursor.CursorFactory[_cprotocol.Record]: ...
4748
async def prepare(self, query: str, *,
4849
timeout: typing.Optional[float] = ...) \
49-
-> prepared_stmt.PreparedStatement: ...
50+
-> prepared_stmt.PreparedStatement[_cprotocol.Record]: ...
5051
async def fetch(self, query: str, *args: typing.Any,
5152
timeout: typing.Optional[float] = ...) \
5253
-> typing.List[_cprotocol.Record]: ...

asyncpg/prepared_stmt.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
from . import connection as _connection
2020

2121

22-
class PreparedStatement(connresource.ConnectionResource):
22+
_Record = typing.TypeVar('_Record', bound='_cprotocol.Record')
23+
24+
25+
class PreparedStatement(connresource.ConnectionResource,
26+
typing.Generic[_Record]):
2327
"""A representation of a prepared statement."""
2428

2529
__slots__ = ('_state', '_query', '_last_status')
@@ -101,7 +105,8 @@ def get_attributes(self) -> typing.Tuple[types.Attribute, ...]:
101105

102106
@connresource.guarded
103107
def cursor(self, *args: typing.Any, prefetch: typing.Optional[int] = None,
104-
timeout: typing.Optional[float] = None) -> cursor.CursorFactory:
108+
timeout: typing.Optional[float] = None) \
109+
-> cursor.CursorFactory[_Record]:
105110
"""Return a *cursor factory* for the prepared statement.
106111
107112
:param args: Query arguments.
@@ -161,7 +166,7 @@ async def explain(self, *args: typing.Any,
161166
@connresource.guarded
162167
async def fetch(self, *args: typing.Any,
163168
timeout: typing.Optional[float] = None) \
164-
-> typing.List['_cprotocol.Record']:
169+
-> typing.List[_Record]:
165170
r"""Execute the statement and return a list of :class:`Record` objects.
166171
167172
:param str query: Query text
@@ -196,7 +201,7 @@ async def fetchval(self, *args: typing.Any, column: int = 0,
196201
@connresource.guarded
197202
async def fetchrow(self, *args: typing.Any,
198203
timeout: typing.Optional[float] = None) \
199-
-> typing.Optional['_cprotocol.Record']:
204+
-> typing.Optional[_Record]:
200205
"""Execute the statement and return the first row.
201206
202207
:param str query: Query text
@@ -213,7 +218,7 @@ async def fetchrow(self, *args: typing.Any,
213218
async def __bind_execute(self, args: typing.Tuple[typing.Any, ...],
214219
limit: int,
215220
timeout: typing.Optional[float]) \
216-
-> typing.List['_cprotocol.Record']:
221+
-> typing.List[_Record]:
217222
protocol = self._connection._protocol
218223
try:
219224
data, status, _ = await protocol.bind_execute(
@@ -227,7 +232,7 @@ async def __bind_execute(self, args: typing.Tuple[typing.Any, ...],
227232
self._state.mark_closed()
228233
raise
229234
self._last_status = status
230-
return data
235+
return data # type: ignore[return-value]
231236

232237
def _check_open(self, meth_name: str) -> None:
233238
if self._state.closed:

0 commit comments

Comments
 (0)