Skip to content

Commit d2f2429

Browse files
committed
mypy checks
1 parent 61b8133 commit d2f2429

File tree

3 files changed

+79
-77
lines changed

3 files changed

+79
-77
lines changed

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@ ignore_missing_imports = true
1919
no_implicit_optional = true
2020
warn_unused_ignores = true
2121

22-
[mypy-tests.*,trino.client,trino.dbapi,trino.sqlalchemy.*]
22+
[mypy-tests.*,trino.client,trino.sqlalchemy.*,trino.dbapi]
2323
ignore_errors = true

trino/client.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,10 @@ class ClientSession(object):
125125

126126
def __init__(
127127
self,
128-
user: str,
129-
catalog: str = None,
130-
schema: str = None,
131-
source: str = None,
128+
user: Optional[str],
129+
catalog: Optional[str] = None,
130+
schema: Optional[str] = None,
131+
source: Optional[str] = None,
132132
properties: Dict[str, str] = None,
133133
headers: Dict[str, str] = None,
134134
transaction_id: str = None,

trino/dbapi.py

+74-72
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
import binascii
2121
import datetime
2222
import math
23+
import time
2324
import uuid
2425
from decimal import Decimal
25-
from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types
26+
from types import TracebackType
27+
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Type, Union
2628

2729
import trino.client
2830
import trino.exceptions
@@ -72,7 +74,7 @@
7274
logger = trino.logging.get_logger(__name__)
7375

7476

75-
def connect(*args, **kwargs):
77+
def connect(*args: Any, **kwargs: Any) -> trino.dbapi.Connection:
7678
"""Constructor for creating a connection to the database.
7779
7880
See class :py:class:`Connection` for arguments.
@@ -92,28 +94,28 @@ class Connection(object):
9294

9395
def __init__(
9496
self,
95-
host,
96-
port=constants.DEFAULT_PORT,
97-
user=None,
98-
source=constants.DEFAULT_SOURCE,
99-
catalog=constants.DEFAULT_CATALOG,
100-
schema=constants.DEFAULT_SCHEMA,
101-
session_properties=None,
102-
http_headers=None,
103-
http_scheme=constants.HTTP,
104-
auth=constants.DEFAULT_AUTH,
105-
extra_credential=None,
106-
redirect_handler=None,
107-
max_attempts=constants.DEFAULT_MAX_ATTEMPTS,
108-
request_timeout=constants.DEFAULT_REQUEST_TIMEOUT,
109-
isolation_level=IsolationLevel.AUTOCOMMIT,
110-
verify=True,
111-
http_session=None,
112-
client_tags=None,
113-
legacy_primitive_types=False,
114-
roles=None,
97+
host: str,
98+
port: int = constants.DEFAULT_PORT,
99+
user: Optional[str] = None,
100+
source: str = constants.DEFAULT_SOURCE,
101+
catalog: Optional[str] = constants.DEFAULT_CATALOG,
102+
schema: Optional[str] = constants.DEFAULT_SCHEMA,
103+
session_properties: Optional[Dict[str, str]] = None,
104+
http_headers: Optional[Dict[str, str]] = None,
105+
http_scheme: str = constants.HTTP,
106+
auth: Optional[trino.auth.Authentication] = constants.DEFAULT_AUTH,
107+
extra_credential: Optional[List[Tuple[str, str]]] = None,
108+
redirect_handler: Optional[str] = None,
109+
max_attempts: int = constants.DEFAULT_MAX_ATTEMPTS,
110+
request_timeout: float = constants.DEFAULT_REQUEST_TIMEOUT,
111+
isolation_level: IsolationLevel = IsolationLevel.AUTOCOMMIT,
112+
verify: Union[bool | str] = True,
113+
http_session: Optional[trino.client.TrinoRequest.http.Session] = None,
114+
client_tags: Optional[List[str]] = None,
115+
legacy_primitive_types: Optional[bool] = False,
116+
roles: Optional[Dict[str, str]] = None,
115117
timezone=None,
116-
):
118+
) -> None:
117119
self.host = host
118120
self.port = port
119121
self.user = user
@@ -151,50 +153,53 @@ def __init__(
151153

152154
self._isolation_level = isolation_level
153155
self._request = None
154-
self._transaction = None
156+
self._transaction: Optional[Transaction] = None
155157
self.legacy_primitive_types = legacy_primitive_types
156158

157159
@property
158-
def isolation_level(self):
160+
def isolation_level(self) -> IsolationLevel:
159161
return self._isolation_level
160162

161163
@property
162-
def transaction(self):
164+
def transaction(self) -> Optional[Transaction]:
163165
return self._transaction
164166

165-
def __enter__(self):
167+
def __enter__(self) -> object:
166168
return self
167169

168-
def __exit__(self, exc_type, exc_value, traceback):
170+
def __exit__(self,
171+
exc_type: Optional[Type[BaseException]],
172+
exc_value: Optional[BaseException],
173+
traceback: Optional[TracebackType]) -> None:
169174
try:
170175
self.commit()
171176
except Exception:
172177
self.rollback()
173178
else:
174179
self.close()
175180

176-
def close(self):
181+
def close(self) -> None:
177182
# TODO cancel outstanding queries?
178183
self._http_session.close()
179184

180-
def start_transaction(self):
185+
def start_transaction(self) -> Transaction:
181186
self._transaction = Transaction(self._create_request())
182187
self._transaction.begin()
183188
return self._transaction
184189

185-
def commit(self):
190+
def commit(self) -> None:
186191
if self.transaction is None:
187192
return
188-
self._transaction.commit()
193+
self.transaction.commit()
189194
self._transaction = None
190195

191-
def rollback(self):
196+
def rollback(self) -> None:
192197
if self.transaction is None:
193198
raise RuntimeError("no transaction was started")
194-
self._transaction.rollback()
199+
self.transaction.rollback()
195200
self._transaction = None
196201

197-
def _create_request(self):
202+
def _create_request(self) -> trino.client.TrinoRequest:
198203
return trino.client.TrinoRequest(
199204
self.host,
200205
self.port,
@@ -207,7 +212,7 @@ def _create_request(self):
207212
self.request_timeout,
208213
)
209214

210-
def cursor(self, legacy_primitive_types: bool = None):
215+
def cursor(self, legacy_primitive_types: bool = None) -> 'trino.dbapi.Cursor':
211216
"""Return a new :py:class:`Cursor` object using the connection."""
212217
if self.isolation_level != IsolationLevel.AUTOCOMMIT:
213218
if self.transaction is None:
@@ -271,7 +276,10 @@ class Cursor(object):
271276
272277
"""
273278

274-
def __init__(self, connection, request, legacy_primitive_types: bool = False):
279+
def __init__(self,
280+
connection: Connection,
281+
request: trino.client.TrinoRequest,
282+
legacy_primitive_types: bool = False) -> None:
275283
if not isinstance(connection, Connection):
276284
raise ValueError(
277285
"connection must be a Connection object: {}".format(type(connection))
@@ -280,32 +288,32 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False):
280288
self._request = request
281289

282290
self.arraysize = 1
283-
self._iterator = None
284-
self._query = None
291+
self._iterator: Optional[Iterator[Any]] = None
292+
self._query: Optional[trino.client.TrinoQuery] = None
285293
self._legacy_primitive_types = legacy_primitive_types
286294

287-
def __iter__(self):
295+
def __iter__(self) -> Optional[Iterator[Any]]:
288296
return self._iterator
289297

290298
@property
291-
def connection(self):
299+
def connection(self) -> Connection:
292300
return self._connection
293301

294302
@property
295-
def info_uri(self):
303+
def info_uri(self) -> Optional[str]:
296304
if self._query is not None:
297305
return self._query.info_uri
298306
return None
299307

300308
@property
301-
def update_type(self):
309+
def update_type(self) -> Optional[str]:
302310
if self._query is not None:
303311
return self._query.update_type
304312
return None
305313

306314
@property
307-
def description(self) -> List[ColumnDescription]:
308-
if self._query.columns is None:
315+
def description(self) -> Optional[List[Tuple[Any, ...]]]:
316+
if self._query is None or self._query.columns is None:
309317
return None
310318

311319
# [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
@@ -314,7 +322,7 @@ def description(self) -> List[ColumnDescription]:
314322
]
315323

316324
@property
317-
def rowcount(self):
325+
def rowcount(self) -> int:
318326
"""Not supported.
319327
320328
Trino cannot reliablity determine the number of rows returned by an
@@ -325,27 +333,21 @@ def rowcount(self):
325333
return -1
326334

327335
@property
328-
def stats(self):
336+
def stats(self) -> Optional[Dict[Any, Any]]:
329337
if self._query is not None:
330338
return self._query.stats
331339
return None
332340

333341
@property
334-
def query_id(self) -> Optional[str]:
335-
if self._query is not None:
336-
return self._query.query_id
337-
return None
338-
339-
@property
340-
def warnings(self):
342+
def warnings(self) -> Optional[List[Dict[Any, Any]]]:
341343
if self._query is not None:
342344
return self._query.warnings
343345
return None
344346

345-
def setinputsizes(self, sizes):
347+
def setinputsizes(self, sizes: Sequence[Any]) -> None:
346348
raise trino.exceptions.NotSupportedError
347349

348-
def setoutputsize(self, size, column):
350+
def setoutputsize(self, size: int, column: Optional[int]) -> None:
349351
raise trino.exceptions.NotSupportedError
350352

351353
def _prepare_statement(self, statement: str, name: str) -> None:
@@ -363,13 +365,13 @@ def _prepare_statement(self, statement: str, name: str) -> None:
363365

364366
def _execute_prepared_statement(
365367
self,
366-
statement_name,
367-
params
368-
):
368+
statement_name: str,
369+
params: Any
370+
) -> trino.client.TrinoQuery:
369371
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))
370372
return trino.client.TrinoQuery(self._request, sql=sql, legacy_primitive_types=self._legacy_primitive_types)
371373

372-
def _format_prepared_param(self, param):
374+
def _format_prepared_param(self, param: Any) -> str:
373375
"""
374376
Formats parameters to be passed in an
375377
EXECUTE statement.
@@ -451,10 +453,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None:
451453
legacy_primitive_types=self._legacy_primitive_types)
452454
query.execute()
453455

454-
def _generate_unique_statement_name(self):
456+
def _generate_unique_statement_name(self) -> str:
455457
return 'st_' + uuid.uuid4().hex.replace('-', '')
456458

457-
def execute(self, operation, params=None):
459+
def execute(self, operation: str, params: Optional[Any] = None) -> trino.client.TrinoResult:
458460
if params:
459461
assert isinstance(params, (list, tuple)), (
460462
'params must be a list or tuple containing the query '
@@ -484,7 +486,7 @@ def execute(self, operation, params=None):
484486
self._iterator = iter(self._query.execute())
485487
return self
486488

487-
def executemany(self, operation, seq_of_params):
489+
def executemany(self, operation: str, seq_of_params: Any) -> None:
488490
"""
489491
PEP-0249: Prepare a database operation (query or command) and then
490492
execute it against all parameter sequences or mappings found in the sequence seq_of_parameters.
@@ -529,7 +531,7 @@ def fetchone(self) -> Optional[List[Any]]:
529531
except trino.exceptions.HttpError as err:
530532
raise trino.exceptions.OperationalError(str(err))
531533

532-
def fetchmany(self, size=None) -> List[List[Any]]:
534+
def fetchmany(self, size: Optional[int] = None) -> List[List[Any]]:
533535
"""
534536
PEP-0249: Fetch the next set of rows of a query result, returning a
535537
sequence of sequences (e.g. a list of tuples). An empty sequence is
@@ -584,20 +586,20 @@ def describe(self, sql: str) -> List[DescribeOutput]:
584586

585587
return list(map(lambda x: DescribeOutput.from_row(x), result))
586588

587-
def genall(self):
589+
def genall(self) -> trino.client.TrinoResult:
588590
return self._query.result
589591

590592
def fetchall(self) -> List[List[Any]]:
591593
return list(self.genall())
592594

593-
def cancel(self):
595+
def cancel(self) -> None:
594596
if self._query is None:
595597
raise trino.exceptions.OperationalError(
596598
"Cancel query failed; no running query"
597599
)
598600
self._query.cancel()
599601

600-
def close(self):
602+
def close(self) -> None:
601603
self.cancel()
602604
# TODO: Cancel not only the last query executed on this cursor
603605
# but also any other outstanding queries executed through this cursor.
@@ -610,19 +612,19 @@ def close(self):
610612
TimestampFromTicks = datetime.datetime.fromtimestamp
611613

612614

613-
def TimeFromTicks(ticks):
614-
return datetime.time(*datetime.localtime(ticks)[3:6])
615+
def TimeFromTicks(ticks: int) -> datetime.time:
616+
return datetime.time(*time.localtime(ticks)[3:6])
615617

616618

617-
def Binary(string):
619+
def Binary(string: str) -> bytes:
618620
return string.encode("utf-8")
619621

620622

621623
class DBAPITypeObject:
622-
def __init__(self, *values):
624+
def __init__(self, *values: str):
623625
self.values = [v.lower() for v in values]
624626

625-
def __eq__(self, other):
627+
def __eq__(self, other: object) -> bool:
626628
return other.lower() in self.values
627629

628630

0 commit comments

Comments
 (0)