Skip to content

Commit 435a30f

Browse files
committed
Reformatting
1 parent a45c000 commit 435a30f

File tree

1 file changed

+79
-34
lines changed

1 file changed

+79
-34
lines changed

trino/client.py

+79-34
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,30 @@
4747
from datetime import date, datetime, time, timedelta, timezone, tzinfo
4848
from decimal import Decimal
4949
from time import sleep
50-
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
50+
from typing import (
51+
Any,
52+
Callable,
53+
Dict,
54+
Generator,
55+
Generic,
56+
List,
57+
Optional,
58+
Tuple,
59+
Type,
60+
TypeVar,
61+
Union,
62+
)
5163

5264
import pytz
5365
import requests
5466
from pytz.tzinfo import BaseTzInfo
55-
from tzlocal import get_localzone_name # type: ignore
67+
from tzlocal import get_localzone_name
5668

5769
import trino.logging
5870
from trino import constants, exceptions
5971

6072
try:
61-
from zoneinfo import ZoneInfo # type: ignore
73+
from zoneinfo import ZoneInfo
6274

6375
except ModuleNotFoundError:
6476
from backports.zoneinfo import ZoneInfo # type: ignore
@@ -75,7 +87,7 @@
7587
else:
7688
PROXIES = {}
7789

78-
_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')
90+
_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r"^\S[^\s=]*$")
7991

8092
T = TypeVar("T")
8193

@@ -461,8 +473,13 @@ def http_headers(self) -> Dict[str, str]:
461473
"{}={}".format(catalog, urllib.parse.quote(str(role)))
462474
for catalog, role in self._client_session.roles.items()
463475
)
464-
if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0:
465-
headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags)
476+
if (
477+
self._client_session.client_tags is not None
478+
and len(self._client_session.client_tags) > 0
479+
):
480+
headers[constants.HEADER_CLIENT_TAGS] = ",".join(
481+
self._client_session.client_tags
482+
)
466483

467484
headers[constants.HEADER_SESSION] = ",".join(
468485
# ``name`` must not contain ``=``
@@ -486,18 +503,23 @@ def http_headers(self) -> Dict[str, str]:
486503
transaction_id = self._client_session.transaction_id
487504
headers[constants.HEADER_TRANSACTION] = transaction_id
488505

489-
if self._client_session.extra_credential is not None and \
490-
len(self._client_session.extra_credential) > 0:
506+
if (
507+
self._client_session.extra_credential is not None
508+
and len(self._client_session.extra_credential) > 0
509+
):
491510

492511
for tup in self._client_session.extra_credential:
493512
self._verify_extra_credential(tup)
494513

495514
# HTTP 1.1 section 4.2 combine multiple extra credentials into a
496515
# comma-separated value
497516
# extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format)
498-
headers[constants.HEADER_EXTRA_CREDENTIAL] = \
499-
", ".join(
500-
[f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential])
517+
headers[constants.HEADER_EXTRA_CREDENTIAL] = ", ".join(
518+
[
519+
f"{tup[0]}={urllib.parse.quote_plus(tup[1])}"
520+
for tup in self._client_session.extra_credential
521+
]
522+
)
501523

502524
return headers
503525

@@ -562,7 +584,12 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non
562584
while http_response is not None and http_response.is_redirect:
563585
location = http_response.headers["Location"]
564586
url = self._redirect_handler.handle(location)
565-
logger.info("redirect %s from %s to %s", http_response.status_code, location, url)
587+
logger.info(
588+
"redirect %s from %s to %s",
589+
http_response.status_code,
590+
location,
591+
url,
592+
)
566593
http_response = self._post(
567594
url,
568595
data=data,
@@ -606,7 +633,7 @@ def raise_response_error(self, http_response):
606633
raise exceptions.HttpError(
607634
"error {}{}".format(
608635
http_response.status_code,
609-
": {}".format(http_response.content) if http_response.content else "",
636+
": {}".format(repr(http_response.content)) if http_response.content else "",
610637
)
611638
)
612639

@@ -633,14 +660,18 @@ def process(self, http_response) -> TrinoStatus:
633660
self._client_session.properties[key] = value
634661

635662
if constants.HEADER_SET_CATALOG in http_response.headers:
636-
self._client_session.catalog = http_response.headers[constants.HEADER_SET_CATALOG]
663+
self._client_session.catalog = http_response.headers[
664+
constants.HEADER_SET_CATALOG
665+
]
637666

638667
if constants.HEADER_SET_SCHEMA in http_response.headers:
639-
self._client_session.schema = http_response.headers[constants.HEADER_SET_SCHEMA]
668+
self._client_session.schema = http_response.headers[
669+
constants.HEADER_SET_SCHEMA
670+
]
640671

641672
if constants.HEADER_SET_ROLE in http_response.headers:
642673
for key, value in get_roles_values(
643-
http_response.headers, constants.HEADER_SET_ROLE
674+
http_response.headers, constants.HEADER_SET_ROLE
644675
):
645676
self._client_session.roles[key] = value
646677

@@ -676,12 +707,16 @@ def _verify_extra_credential(self, header):
676707
key = header[0]
677708

678709
if not _HEADER_EXTRA_CREDENTIAL_KEY_REGEX.match(key):
679-
raise ValueError(f"whitespace or '=' are disallowed in extra credential '{key}'")
710+
raise ValueError(
711+
f"whitespace or '=' are disallowed in extra credential '{key}'"
712+
)
680713

681714
try:
682-
key.encode().decode('ascii')
715+
key.encode().decode("ascii")
683716
except UnicodeDecodeError:
684-
raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")
717+
raise ValueError(
718+
f"only ASCII characters are allowed in extra credential '{key}'"
719+
)
685720

686721

687722
class TrinoResult(object):
@@ -847,7 +882,10 @@ def cancel(self) -> None:
847882

848883
def is_finished(self) -> bool:
849884
import warnings
850-
warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning)
885+
886+
warnings.warn(
887+
"is_finished is deprecated, use finished instead", DeprecationWarning
888+
)
851889
return self.finished
852890

853891
@property
@@ -910,11 +948,11 @@ class DoubleValueMapper(ValueMapper[float]):
910948
def map(self, value) -> Optional[float]:
911949
if value is None:
912950
return None
913-
if value == 'Infinity':
951+
if value == "Infinity":
914952
return float("inf")
915-
if value == '-Infinity':
953+
if value == "-Infinity":
916954
return float("-inf")
917-
if value == 'NaN':
955+
if value == "NaN":
918956
return float("nan")
919957
return float(value)
920958

@@ -1119,7 +1157,9 @@ def __init__(self, mappers: List[ValueMapper[Any]]):
11191157
def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]:
11201158
if values is None:
11211159
return None
1122-
return tuple(self.mappers[index].map(value) for index, value in enumerate(values))
1160+
return tuple(
1161+
self.mappers[index].map(value) for index, value in enumerate(values)
1162+
)
11231163

11241164

11251165
class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
@@ -1131,7 +1171,8 @@ def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
11311171
if values is None:
11321172
return None
11331173
return {
1134-
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
1174+
self.key_mapper.map(key): self.value_mapper.map(value)
1175+
for key, value in values.items()
11351176
}
11361177

11371178

@@ -1151,6 +1192,7 @@ class RowMapperFactory:
11511192
lambda functions (one for each column) which will process a data value
11521193
and returns a RowMapper instance which will process rows of data
11531194
"""
1195+
11541196
NO_OP_ROW_MAPPER = NoOpRowMapper()
11551197

11561198
def create(self, columns, legacy_primitive_types):
@@ -1163,19 +1205,22 @@ def create(self, columns, legacy_primitive_types):
11631205
def _create_value_mapper(self, column) -> ValueMapper:
11641206
col_type = column['rawType']
11651207

1166-
if col_type == 'array':
1167-
value_mapper = self._create_value_mapper(column['arguments'][0]['value'])
1208+
if col_type == "array":
1209+
value_mapper = self._create_value_mapper(column["arguments"][0]["value"])
11681210
return ArrayValueMapper(value_mapper)
1169-
elif col_type == 'row':
1170-
mappers = [self._create_value_mapper(arg['value']['typeSignature']) for arg in column['arguments']]
1211+
elif col_type == "row":
1212+
mappers = [
1213+
self._create_value_mapper(arg["value"]["typeSignature"])
1214+
for arg in column["arguments"]
1215+
]
11711216
return RowValueMapper(mappers)
1172-
elif col_type == 'map':
1173-
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
1174-
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
1217+
elif col_type == "map":
1218+
key_mapper = self._create_value_mapper(column["arguments"][0]["value"])
1219+
value_mapper = self._create_value_mapper(column["arguments"][1]["value"])
11751220
return MapValueMapper(key_mapper, value_mapper)
1176-
elif col_type.startswith('decimal'):
1221+
elif col_type.startswith("decimal"):
11771222
return DecimalValueMapper()
1178-
elif col_type.startswith('double') or col_type.startswith('real'):
1223+
elif col_type.startswith("double") or col_type.startswith("real"):
11791224
return DoubleValueMapper()
11801225
elif col_type.startswith('timestamp') and 'with time zone' in col_type:
11811226
return TimestampWithTimeZoneValueMapper(self._get_precision(column))

0 commit comments

Comments
 (0)