55
55
from zoneinfo import ZoneInfo
56
56
57
57
import requests
58
+ from requests import Response
59
+ from requests import Session
60
+ from requests .structures import CaseInsensitiveDict
58
61
from tzlocal import get_localzone_name # type: ignore
59
62
60
63
import trino .logging
61
64
from trino import constants
62
65
from trino import exceptions
63
66
from trino ._version import __version__
67
+ from trino .auth import Authentication
68
+ from trino .exceptions import TrinoExternalError
69
+ from trino .exceptions import TrinoQueryError
70
+ from trino .exceptions import TrinoUserError
64
71
from trino .mapper import RowMapper
65
72
from trino .mapper import RowMapperFactory
66
73
@@ -271,27 +278,27 @@ def __setstate__(self, state):
271
278
self ._object_lock = threading .Lock ()
272
279
273
280
274
- def get_header_values (headers , header ) :
281
+ def get_header_values (headers : CaseInsensitiveDict [ str ] , header : str ) -> List [ str ] :
275
282
return [val .strip () for val in headers [header ].split ("," )]
276
283
277
284
278
- def get_session_property_values (headers , header ) :
285
+ def get_session_property_values (headers : CaseInsensitiveDict [ str ] , header : str ) -> List [ Tuple [ str , str ]] :
279
286
kvs = get_header_values (headers , header )
280
287
return [
281
288
(k .strip (), urllib .parse .unquote_plus (v .strip ()))
282
289
for k , v in (kv .split ("=" , 1 ) for kv in kvs if kv )
283
290
]
284
291
285
292
286
- def get_prepared_statement_values (headers , header ) :
293
+ def get_prepared_statement_values (headers : CaseInsensitiveDict [ str ] , header : str ) -> List [ Tuple [ str , str ]] :
287
294
kvs = get_header_values (headers , header )
288
295
return [
289
296
(k .strip (), urllib .parse .unquote_plus (v .strip ()))
290
297
for k , v in (kv .split ("=" , 1 ) for kv in kvs if kv )
291
298
]
292
299
293
300
294
- def get_roles_values (headers , header ) :
301
+ def get_roles_values (headers : CaseInsensitiveDict [ str ] , header : str ) -> List [ Tuple [ str , str ]] :
295
302
kvs = get_header_values (headers , header )
296
303
return [
297
304
(k .strip (), urllib .parse .unquote_plus (v .strip ()))
@@ -414,9 +421,9 @@ def __init__(
414
421
host : str ,
415
422
port : int ,
416
423
client_session : ClientSession ,
417
- http_session : Any = None ,
418
- http_scheme : str = None ,
419
- auth : Optional [Any ] = constants .DEFAULT_AUTH ,
424
+ http_session : Optional [ Session ] = None ,
425
+ http_scheme : Optional [ str ] = None ,
426
+ auth : Optional [Authentication ] = constants .DEFAULT_AUTH ,
420
427
max_attempts : int = MAX_ATTEMPTS ,
421
428
request_timeout : Union [float , Tuple [float , float ]] = constants .DEFAULT_REQUEST_TIMEOUT ,
422
429
handle_retry = _RetryWithExponentialBackoff (),
@@ -454,16 +461,16 @@ def __init__(
454
461
self .max_attempts = max_attempts
455
462
456
463
@property
457
- def transaction_id (self ):
464
+ def transaction_id (self ) -> Optional [ str ] :
458
465
return self ._client_session .transaction_id
459
466
460
467
@transaction_id .setter
461
- def transaction_id (self , value ) :
468
+ def transaction_id (self , value : Optional [ str ]) -> None :
462
469
self ._client_session .transaction_id = value
463
470
464
471
@property
465
- def http_headers (self ) -> Dict [ str , str ]:
466
- headers = requests . structures . CaseInsensitiveDict ()
472
+ def http_headers (self ) -> CaseInsensitiveDict [ str ]:
473
+ headers : CaseInsensitiveDict [ str ] = CaseInsensitiveDict ()
467
474
468
475
headers [constants .HEADER_CATALOG ] = self ._client_session .catalog
469
476
headers [constants .HEADER_SCHEMA ] = self ._client_session .schema
@@ -525,7 +532,7 @@ def max_attempts(self) -> int:
525
532
return self ._max_attempts
526
533
527
534
@max_attempts .setter
528
- def max_attempts (self , value ) -> None :
535
+ def max_attempts (self , value : int ) -> None :
529
536
self ._max_attempts = value
530
537
if value == 1 : # No retry
531
538
self ._get = self ._http_session .get
@@ -547,7 +554,7 @@ def max_attempts(self, value) -> None:
547
554
self ._post = with_retry (self ._http_session .post )
548
555
self ._delete = with_retry (self ._http_session .delete )
549
556
550
- def get_url (self , path ) -> str :
557
+ def get_url (self , path : str ) -> str :
551
558
return "{protocol}://{host}:{port}{path}" .format (
552
559
protocol = self ._http_scheme , host = self ._host , port = self ._port , path = path
553
560
)
@@ -560,7 +567,7 @@ def statement_url(self) -> str:
560
567
def next_uri (self ) -> Optional [str ]:
561
568
return self ._next_uri
562
569
563
- def post (self , sql : str , additional_http_headers : Optional [Dict [str , Any ]] = None ):
570
+ def post (self , sql : str , additional_http_headers : Optional [Dict [str , Any ]] = None ) -> Response :
564
571
data = sql .encode ("utf-8" )
565
572
# Deep copy of the http_headers dict since they may be modified for this
566
573
# request by the provided additional_http_headers
@@ -578,18 +585,19 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non
578
585
)
579
586
return http_response
580
587
581
- def get (self , url : str ):
588
+ def get (self , url : str ) -> Response :
582
589
return self ._get (
583
590
url ,
584
591
headers = self .http_headers ,
585
592
timeout = self ._request_timeout ,
586
593
proxies = PROXIES ,
587
594
)
588
595
589
- def delete (self , url ) :
596
+ def delete (self , url : str ) -> Response :
590
597
return self ._delete (url , timeout = self ._request_timeout , proxies = PROXIES )
591
598
592
- def _process_error (self , error , query_id ):
599
+ @staticmethod
600
+ def _process_error (error , query_id : Optional [str ]) -> Union [TrinoExternalError , TrinoQueryError , TrinoUserError ]:
593
601
error_type = error ["errorType" ]
594
602
if error_type == "EXTERNAL" :
595
603
raise exceptions .TrinoExternalError (error , query_id )
@@ -598,7 +606,8 @@ def _process_error(self, error, query_id):
598
606
599
607
return exceptions .TrinoQueryError (error , query_id )
600
608
601
- def raise_response_error (self , http_response ):
609
+ @staticmethod
610
+ def raise_response_error (http_response : Response ) -> None :
602
611
if http_response .status_code == 502 :
603
612
raise exceptions .Http502Error ("error 502: bad gateway" )
604
613
@@ -615,7 +624,7 @@ def raise_response_error(self, http_response):
615
624
)
616
625
)
617
626
618
- def process (self , http_response ) -> TrinoStatus :
627
+ def process (self , http_response : Response ) -> TrinoStatus :
619
628
if not http_response .ok :
620
629
self .raise_response_error (http_response )
621
630
@@ -682,7 +691,8 @@ def process(self, http_response) -> TrinoStatus:
682
691
columns = response .get ("columns" ),
683
692
)
684
693
685
- def _verify_extra_credential (self , header ):
694
+ @staticmethod
695
+ def _verify_extra_credential (header : Tuple [str , str ]) -> None :
686
696
"""
687
697
Verifies that key has ASCII only and non-whitespace characters.
688
698
"""
0 commit comments