47
47
from datetime import date , datetime , time , timedelta , timezone , tzinfo
48
48
from decimal import Decimal
49
49
from time import sleep
50
- from typing import Any , Dict , Generic , List , Optional , Tuple , TypeVar , Union
50
+ from typing import (
51
+ Any ,
52
+ Callable ,
53
+ Dict ,
54
+ Generator ,
55
+ Generic ,
56
+ List ,
57
+ Optional ,
58
+ Tuple ,
59
+ Type ,
60
+ TypeVar ,
61
+ Union ,
62
+ )
51
63
52
64
import pytz
53
65
import requests
54
66
from pytz .tzinfo import BaseTzInfo
55
- from tzlocal import get_localzone_name # type: ignore
67
+ from tzlocal import get_localzone_name
56
68
57
69
import trino .logging
58
70
from trino import constants , exceptions
59
71
60
72
try :
61
- from zoneinfo import ZoneInfo # type: ignore
73
+ from zoneinfo import ZoneInfo
62
74
63
75
except ModuleNotFoundError :
64
76
from backports .zoneinfo import ZoneInfo # type: ignore
75
87
else :
76
88
PROXIES = {}
77
89
78
- _HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re .compile (r' ^\S[^\s=]*$' )
90
+ _HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re .compile (r" ^\S[^\s=]*$" )
79
91
80
92
T = TypeVar ("T" )
81
93
@@ -461,8 +473,13 @@ def http_headers(self) -> Dict[str, str]:
461
473
"{}={}" .format (catalog , urllib .parse .quote (str (role )))
462
474
for catalog , role in self ._client_session .roles .items ()
463
475
)
464
- if self ._client_session .client_tags is not None and len (self ._client_session .client_tags ) > 0 :
465
- headers [constants .HEADER_CLIENT_TAGS ] = "," .join (self ._client_session .client_tags )
476
+ if (
477
+ self ._client_session .client_tags is not None
478
+ and len (self ._client_session .client_tags ) > 0
479
+ ):
480
+ headers [constants .HEADER_CLIENT_TAGS ] = "," .join (
481
+ self ._client_session .client_tags
482
+ )
466
483
467
484
headers [constants .HEADER_SESSION ] = "," .join (
468
485
# ``name`` must not contain ``=``
@@ -486,18 +503,23 @@ def http_headers(self) -> Dict[str, str]:
486
503
transaction_id = self ._client_session .transaction_id
487
504
headers [constants .HEADER_TRANSACTION ] = transaction_id
488
505
489
- if self ._client_session .extra_credential is not None and \
490
- len (self ._client_session .extra_credential ) > 0 :
506
+ if (
507
+ self ._client_session .extra_credential is not None
508
+ and len (self ._client_session .extra_credential ) > 0
509
+ ):
491
510
492
511
for tup in self ._client_session .extra_credential :
493
512
self ._verify_extra_credential (tup )
494
513
495
514
# HTTP 1.1 section 4.2 combine multiple extra credentials into a
496
515
# comma-separated value
497
516
# extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format)
498
- headers [constants .HEADER_EXTRA_CREDENTIAL ] = \
499
- ", " .join (
500
- [f"{ tup [0 ]} ={ urllib .parse .quote_plus (tup [1 ])} " for tup in self ._client_session .extra_credential ])
517
+ headers [constants .HEADER_EXTRA_CREDENTIAL ] = ", " .join (
518
+ [
519
+ f"{ tup [0 ]} ={ urllib .parse .quote_plus (tup [1 ])} "
520
+ for tup in self ._client_session .extra_credential
521
+ ]
522
+ )
501
523
502
524
return headers
503
525
@@ -562,7 +584,12 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non
562
584
while http_response is not None and http_response .is_redirect :
563
585
location = http_response .headers ["Location" ]
564
586
url = self ._redirect_handler .handle (location )
565
- logger .info ("redirect %s from %s to %s" , http_response .status_code , location , url )
587
+ logger .info (
588
+ "redirect %s from %s to %s" ,
589
+ http_response .status_code ,
590
+ location ,
591
+ url ,
592
+ )
566
593
http_response = self ._post (
567
594
url ,
568
595
data = data ,
@@ -606,7 +633,7 @@ def raise_response_error(self, http_response):
606
633
raise exceptions .HttpError (
607
634
"error {}{}" .format (
608
635
http_response .status_code ,
609
- ": {}" .format (http_response .content ) if http_response .content else "" ,
636
+ ": {}" .format (repr ( http_response .content ) ) if http_response .content else "" ,
610
637
)
611
638
)
612
639
@@ -633,14 +660,18 @@ def process(self, http_response) -> TrinoStatus:
633
660
self ._client_session .properties [key ] = value
634
661
635
662
if constants .HEADER_SET_CATALOG in http_response .headers :
636
- self ._client_session .catalog = http_response .headers [constants .HEADER_SET_CATALOG ]
663
+ self ._client_session .catalog = http_response .headers [
664
+ constants .HEADER_SET_CATALOG
665
+ ]
637
666
638
667
if constants .HEADER_SET_SCHEMA in http_response .headers :
639
- self ._client_session .schema = http_response .headers [constants .HEADER_SET_SCHEMA ]
668
+ self ._client_session .schema = http_response .headers [
669
+ constants .HEADER_SET_SCHEMA
670
+ ]
640
671
641
672
if constants .HEADER_SET_ROLE in http_response .headers :
642
673
for key , value in get_roles_values (
643
- http_response .headers , constants .HEADER_SET_ROLE
674
+ http_response .headers , constants .HEADER_SET_ROLE
644
675
):
645
676
self ._client_session .roles [key ] = value
646
677
@@ -676,12 +707,16 @@ def _verify_extra_credential(self, header):
676
707
key = header [0 ]
677
708
678
709
if not _HEADER_EXTRA_CREDENTIAL_KEY_REGEX .match (key ):
679
- raise ValueError (f"whitespace or '=' are disallowed in extra credential '{ key } '" )
710
+ raise ValueError (
711
+ f"whitespace or '=' are disallowed in extra credential '{ key } '"
712
+ )
680
713
681
714
try :
682
- key .encode ().decode (' ascii' )
715
+ key .encode ().decode (" ascii" )
683
716
except UnicodeDecodeError :
684
- raise ValueError (f"only ASCII characters are allowed in extra credential '{ key } '" )
717
+ raise ValueError (
718
+ f"only ASCII characters are allowed in extra credential '{ key } '"
719
+ )
685
720
686
721
687
722
class TrinoResult (object ):
@@ -847,7 +882,10 @@ def cancel(self) -> None:
847
882
848
883
def is_finished (self ) -> bool :
849
884
import warnings
850
- warnings .warn ("is_finished is deprecated, use finished instead" , DeprecationWarning )
885
+
886
+ warnings .warn (
887
+ "is_finished is deprecated, use finished instead" , DeprecationWarning
888
+ )
851
889
return self .finished
852
890
853
891
@property
@@ -910,11 +948,11 @@ class DoubleValueMapper(ValueMapper[float]):
910
948
def map (self , value ) -> Optional [float ]:
911
949
if value is None :
912
950
return None
913
- if value == ' Infinity' :
951
+ if value == " Infinity" :
914
952
return float ("inf" )
915
- if value == ' -Infinity' :
953
+ if value == " -Infinity" :
916
954
return float ("-inf" )
917
- if value == ' NaN' :
955
+ if value == " NaN" :
918
956
return float ("nan" )
919
957
return float (value )
920
958
@@ -1119,7 +1157,9 @@ def __init__(self, mappers: List[ValueMapper[Any]]):
1119
1157
def map (self , values : List [Any ]) -> Optional [Tuple [Optional [Any ], ...]]:
1120
1158
if values is None :
1121
1159
return None
1122
- return tuple (self .mappers [index ].map (value ) for index , value in enumerate (values ))
1160
+ return tuple (
1161
+ self .mappers [index ].map (value ) for index , value in enumerate (values )
1162
+ )
1123
1163
1124
1164
1125
1165
class MapValueMapper (ValueMapper [Dict [Any , Optional [Any ]]]):
@@ -1131,7 +1171,8 @@ def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
1131
1171
if values is None :
1132
1172
return None
1133
1173
return {
1134
- self .key_mapper .map (key ): self .value_mapper .map (value ) for key , value in values .items ()
1174
+ self .key_mapper .map (key ): self .value_mapper .map (value )
1175
+ for key , value in values .items ()
1135
1176
}
1136
1177
1137
1178
@@ -1151,6 +1192,7 @@ class RowMapperFactory:
1151
1192
lambda functions (one for each column) which will process a data value
1152
1193
and returns a RowMapper instance which will process rows of data
1153
1194
"""
1195
+
1154
1196
NO_OP_ROW_MAPPER = NoOpRowMapper ()
1155
1197
1156
1198
def create (self , columns , legacy_primitive_types ):
@@ -1163,19 +1205,22 @@ def create(self, columns, legacy_primitive_types):
1163
1205
def _create_value_mapper (self , column ) -> ValueMapper :
1164
1206
col_type = column ['rawType' ]
1165
1207
1166
- if col_type == ' array' :
1167
- value_mapper = self ._create_value_mapper (column [' arguments' ][0 ][' value' ])
1208
+ if col_type == " array" :
1209
+ value_mapper = self ._create_value_mapper (column [" arguments" ][0 ][" value" ])
1168
1210
return ArrayValueMapper (value_mapper )
1169
- elif col_type == 'row' :
1170
- mappers = [self ._create_value_mapper (arg ['value' ]['typeSignature' ]) for arg in column ['arguments' ]]
1211
+ elif col_type == "row" :
1212
+ mappers = [
1213
+ self ._create_value_mapper (arg ["value" ]["typeSignature" ])
1214
+ for arg in column ["arguments" ]
1215
+ ]
1171
1216
return RowValueMapper (mappers )
1172
- elif col_type == ' map' :
1173
- key_mapper = self ._create_value_mapper (column [' arguments' ][0 ][' value' ])
1174
- value_mapper = self ._create_value_mapper (column [' arguments' ][1 ][' value' ])
1217
+ elif col_type == " map" :
1218
+ key_mapper = self ._create_value_mapper (column [" arguments" ][0 ][" value" ])
1219
+ value_mapper = self ._create_value_mapper (column [" arguments" ][1 ][" value" ])
1175
1220
return MapValueMapper (key_mapper , value_mapper )
1176
- elif col_type .startswith (' decimal' ):
1221
+ elif col_type .startswith (" decimal" ):
1177
1222
return DecimalValueMapper ()
1178
- elif col_type .startswith (' double' ) or col_type .startswith (' real' ):
1223
+ elif col_type .startswith (" double" ) or col_type .startswith (" real" ):
1179
1224
return DoubleValueMapper ()
1180
1225
elif col_type .startswith ('timestamp' ) and 'with time zone' in col_type :
1181
1226
return TimestampWithTimeZoneValueMapper (self ._get_precision (column ))
0 commit comments