20
20
import binascii
21
21
import datetime
22
22
import math
23
+ import time
23
24
import uuid
24
25
from decimal import Decimal
25
- from typing import Any , Dict , List , NamedTuple , Optional # NOQA for mypy types
26
+ from types import TracebackType
27
+ from typing import Any , Dict , Iterator , List , NamedTuple , Optional , Sequence , Tuple , Type , Union
26
28
27
29
import trino .client
28
30
import trino .exceptions
72
74
logger = trino .logging .get_logger (__name__ )
73
75
74
76
75
- def connect (* args , ** kwargs ) :
77
+ def connect (* args : Any , ** kwargs : Any ) -> trino . dbapi . Connection :
76
78
"""Constructor for creating a connection to the database.
77
79
78
80
See class :py:class:`Connection` for arguments.
@@ -92,28 +94,28 @@ class Connection(object):
92
94
93
95
def __init__ (
94
96
self ,
95
- host ,
96
- port = constants .DEFAULT_PORT ,
97
- user = None ,
98
- source = constants .DEFAULT_SOURCE ,
99
- catalog = constants .DEFAULT_CATALOG ,
100
- schema = constants .DEFAULT_SCHEMA ,
101
- session_properties = None ,
102
- http_headers = None ,
103
- http_scheme = constants .HTTP ,
104
- auth = constants .DEFAULT_AUTH ,
105
- extra_credential = None ,
106
- redirect_handler = None ,
107
- max_attempts = constants .DEFAULT_MAX_ATTEMPTS ,
108
- request_timeout = constants .DEFAULT_REQUEST_TIMEOUT ,
109
- isolation_level = IsolationLevel .AUTOCOMMIT ,
110
- verify = True ,
111
- http_session = None ,
112
- client_tags = None ,
113
- legacy_primitive_types = False ,
114
- roles = None ,
97
+ host : str ,
98
+ port : int = constants .DEFAULT_PORT ,
99
+ user : Optional [ str ] = None ,
100
+ source : str = constants .DEFAULT_SOURCE ,
101
+ catalog : Optional [ str ] = constants .DEFAULT_CATALOG ,
102
+ schema : Optional [ str ] = constants .DEFAULT_SCHEMA ,
103
+ session_properties : Optional [ Dict [ str , str ]] = None ,
104
+ http_headers : Optional [ Dict [ str , str ]] = None ,
105
+ http_scheme : str = constants .HTTP ,
106
+ auth : Optional [ trino . auth . Authentication ] = constants .DEFAULT_AUTH ,
107
+ extra_credential : Optional [ List [ Tuple [ str , str ]]] = None ,
108
+ redirect_handler : Optional [ str ] = None ,
109
+ max_attempts : int = constants .DEFAULT_MAX_ATTEMPTS ,
110
+ request_timeout : float = constants .DEFAULT_REQUEST_TIMEOUT ,
111
+ isolation_level : IsolationLevel = IsolationLevel .AUTOCOMMIT ,
112
+ verify : Union [ bool | str ] = True ,
113
+ http_session : Optional [ trino . client . TrinoRequest . http . Session ] = None ,
114
+ client_tags : Optional [ List [ str ]] = None ,
115
+ legacy_primitive_types : Optional [ bool ] = False ,
116
+ roles : Optional [ Dict [ str , str ]] = None ,
115
117
timezone = None ,
116
- ):
118
+ ) -> None :
117
119
self .host = host
118
120
self .port = port
119
121
self .user = user
@@ -151,50 +153,53 @@ def __init__(
151
153
152
154
self ._isolation_level = isolation_level
153
155
self ._request = None
154
- self ._transaction = None
156
+ self ._transaction : Optional [ Transaction ] = None
155
157
self .legacy_primitive_types = legacy_primitive_types
156
158
157
159
@property
158
- def isolation_level (self ):
160
+ def isolation_level (self ) -> IsolationLevel :
159
161
return self ._isolation_level
160
162
161
163
@property
162
- def transaction (self ):
164
+ def transaction (self ) -> Optional [ Transaction ] :
163
165
return self ._transaction
164
166
165
- def __enter__ (self ):
167
+ def __enter__ (self ) -> object :
166
168
return self
167
169
168
- def __exit__ (self , exc_type , exc_value , traceback ):
170
+ def __exit__ (self ,
171
+ exc_type : Optional [Type [BaseException ]],
172
+ exc_value : Optional [BaseException ],
173
+ traceback : Optional [TracebackType ]) -> None :
169
174
try :
170
175
self .commit ()
171
176
except Exception :
172
177
self .rollback ()
173
178
else :
174
179
self .close ()
175
180
176
- def close (self ):
181
+ def close (self ) -> None :
177
182
# TODO cancel outstanding queries?
178
183
self ._http_session .close ()
179
184
180
- def start_transaction (self ):
185
+ def start_transaction (self ) -> Transaction :
181
186
self ._transaction = Transaction (self ._create_request ())
182
187
self ._transaction .begin ()
183
188
return self ._transaction
184
189
185
- def commit (self ):
190
+ def commit (self ) -> None :
186
191
if self .transaction is None :
187
192
return
188
- self ._transaction .commit ()
193
+ self .transaction .commit ()
189
194
self ._transaction = None
190
195
191
- def rollback (self ):
196
+ def rollback (self ) -> None :
192
197
if self .transaction is None :
193
198
raise RuntimeError ("no transaction was started" )
194
- self ._transaction .rollback ()
199
+ self .transaction .rollback ()
195
200
self ._transaction = None
196
201
197
- def _create_request (self ):
202
+ def _create_request (self ) -> trino . client . TrinoRequest :
198
203
return trino .client .TrinoRequest (
199
204
self .host ,
200
205
self .port ,
@@ -207,7 +212,7 @@ def _create_request(self):
207
212
self .request_timeout ,
208
213
)
209
214
210
- def cursor (self , legacy_primitive_types : bool = None ):
215
+ def cursor (self , legacy_primitive_types : bool = None ) -> 'trino.dbapi.Cursor' :
211
216
"""Return a new :py:class:`Cursor` object using the connection."""
212
217
if self .isolation_level != IsolationLevel .AUTOCOMMIT :
213
218
if self .transaction is None :
@@ -271,7 +276,10 @@ class Cursor(object):
271
276
272
277
"""
273
278
274
- def __init__ (self , connection , request , legacy_primitive_types : bool = False ):
279
+ def __init__ (self ,
280
+ connection : Connection ,
281
+ request : trino .client .TrinoRequest ,
282
+ legacy_primitive_types : bool = False ) -> None :
275
283
if not isinstance (connection , Connection ):
276
284
raise ValueError (
277
285
"connection must be a Connection object: {}" .format (type (connection ))
@@ -280,32 +288,32 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False):
280
288
self ._request = request
281
289
282
290
self .arraysize = 1
283
- self ._iterator = None
284
- self ._query = None
291
+ self ._iterator : Optional [ Iterator [ Any ]] = None
292
+ self ._query : Optional [ trino . client . TrinoQuery ] = None
285
293
self ._legacy_primitive_types = legacy_primitive_types
286
294
287
- def __iter__ (self ):
295
+ def __iter__ (self ) -> Optional [ Iterator [ Any ]] :
288
296
return self ._iterator
289
297
290
298
@property
291
- def connection (self ):
299
+ def connection (self ) -> Connection :
292
300
return self ._connection
293
301
294
302
@property
295
- def info_uri (self ):
303
+ def info_uri (self ) -> Optional [ str ] :
296
304
if self ._query is not None :
297
305
return self ._query .info_uri
298
306
return None
299
307
300
308
@property
301
- def update_type (self ):
309
+ def update_type (self ) -> Optional [ str ] :
302
310
if self ._query is not None :
303
311
return self ._query .update_type
304
312
return None
305
313
306
314
@property
307
- def description (self ) -> List [ColumnDescription ]:
308
- if self ._query .columns is None :
315
+ def description (self ) -> Optional [ List [Tuple [ Any , ...]] ]:
316
+ if self ._query is None or self . _query .columns is None :
309
317
return None
310
318
311
319
# [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
@@ -314,7 +322,7 @@ def description(self) -> List[ColumnDescription]:
314
322
]
315
323
316
324
@property
317
- def rowcount (self ):
325
+ def rowcount (self ) -> int :
318
326
"""Not supported.
319
327
320
328
Trino cannot reliablity determine the number of rows returned by an
@@ -325,27 +333,21 @@ def rowcount(self):
325
333
return - 1
326
334
327
335
@property
328
- def stats (self ):
336
+ def stats (self ) -> Optional [ Dict [ Any , Any ]] :
329
337
if self ._query is not None :
330
338
return self ._query .stats
331
339
return None
332
340
333
341
@property
334
- def query_id (self ) -> Optional [str ]:
335
- if self ._query is not None :
336
- return self ._query .query_id
337
- return None
338
-
339
- @property
340
- def warnings (self ):
342
+ def warnings (self ) -> Optional [List [Dict [Any , Any ]]]:
341
343
if self ._query is not None :
342
344
return self ._query .warnings
343
345
return None
344
346
345
- def setinputsizes (self , sizes ) :
347
+ def setinputsizes (self , sizes : Sequence [ Any ]) -> None :
346
348
raise trino .exceptions .NotSupportedError
347
349
348
- def setoutputsize (self , size , column ) :
350
+ def setoutputsize (self , size : int , column : Optional [ int ]) -> None :
349
351
raise trino .exceptions .NotSupportedError
350
352
351
353
def _prepare_statement (self , statement : str , name : str ) -> None :
@@ -363,13 +365,13 @@ def _prepare_statement(self, statement: str, name: str) -> None:
363
365
364
366
def _execute_prepared_statement (
365
367
self ,
366
- statement_name ,
367
- params
368
- ):
368
+ statement_name : str ,
369
+ params : Any
370
+ ) -> trino . client . TrinoQuery :
369
371
sql = 'EXECUTE ' + statement_name + ' USING ' + ',' .join (map (self ._format_prepared_param , params ))
370
372
return trino .client .TrinoQuery (self ._request , sql = sql , legacy_primitive_types = self ._legacy_primitive_types )
371
373
372
- def _format_prepared_param (self , param ) :
374
+ def _format_prepared_param (self , param : Any ) -> str :
373
375
"""
374
376
Formats parameters to be passed in an
375
377
EXECUTE statement.
@@ -451,10 +453,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None:
451
453
legacy_primitive_types = self ._legacy_primitive_types )
452
454
query .execute ()
453
455
454
- def _generate_unique_statement_name (self ):
456
+ def _generate_unique_statement_name (self ) -> str :
455
457
return 'st_' + uuid .uuid4 ().hex .replace ('-' , '' )
456
458
457
- def execute (self , operation , params = None ):
459
+ def execute (self , operation : str , params : Optional [ Any ] = None ) -> trino . client . TrinoResult :
458
460
if params :
459
461
assert isinstance (params , (list , tuple )), (
460
462
'params must be a list or tuple containing the query '
@@ -484,7 +486,7 @@ def execute(self, operation, params=None):
484
486
self ._iterator = iter (self ._query .execute ())
485
487
return self
486
488
487
- def executemany (self , operation , seq_of_params ) :
489
+ def executemany (self , operation : str , seq_of_params : Any ) -> None :
488
490
"""
489
491
PEP-0249: Prepare a database operation (query or command) and then
490
492
execute it against all parameter sequences or mappings found in the sequence seq_of_parameters.
@@ -529,7 +531,7 @@ def fetchone(self) -> Optional[List[Any]]:
529
531
except trino .exceptions .HttpError as err :
530
532
raise trino .exceptions .OperationalError (str (err ))
531
533
532
- def fetchmany (self , size = None ) -> List [List [Any ]]:
534
+ def fetchmany (self , size : Optional [ int ] = None ) -> List [List [Any ]]:
533
535
"""
534
536
PEP-0249: Fetch the next set of rows of a query result, returning a
535
537
sequence of sequences (e.g. a list of tuples). An empty sequence is
@@ -584,20 +586,20 @@ def describe(self, sql: str) -> List[DescribeOutput]:
584
586
585
587
return list (map (lambda x : DescribeOutput .from_row (x ), result ))
586
588
587
- def genall (self ):
589
+ def genall (self ) -> trino . client . TrinoResult :
588
590
return self ._query .result
589
591
590
592
def fetchall (self ) -> List [List [Any ]]:
591
593
return list (self .genall ())
592
594
593
- def cancel (self ):
595
+ def cancel (self ) -> None :
594
596
if self ._query is None :
595
597
raise trino .exceptions .OperationalError (
596
598
"Cancel query failed; no running query"
597
599
)
598
600
self ._query .cancel ()
599
601
600
- def close (self ):
602
+ def close (self ) -> None :
601
603
self .cancel ()
602
604
# TODO: Cancel not only the last query executed on this cursor
603
605
# but also any other outstanding queries executed through this cursor.
@@ -610,19 +612,19 @@ def close(self):
610
612
TimestampFromTicks = datetime .datetime .fromtimestamp
611
613
612
614
613
- def TimeFromTicks (ticks ) :
614
- return datetime .time (* datetime .localtime (ticks )[3 :6 ])
615
+ def TimeFromTicks (ticks : int ) -> datetime . time :
616
+ return datetime .time (* time .localtime (ticks )[3 :6 ])
615
617
616
618
617
- def Binary (string ) :
619
+ def Binary (string : str ) -> bytes :
618
620
return string .encode ("utf-8" )
619
621
620
622
621
623
class DBAPITypeObject :
622
- def __init__ (self , * values ):
624
+ def __init__ (self , * values : str ):
623
625
self .values = [v .lower () for v in values ]
624
626
625
- def __eq__ (self , other ) :
627
+ def __eq__ (self , other : object ) -> bool :
626
628
return other .lower () in self .values
627
629
628
630
0 commit comments