Skip to content

Commit 649d48d

Browse files
arturdryomovhashhar
authored andcommitted
Add missing type annotations to TrinoRequest class
1 parent 9d01f76 commit 649d48d

File tree

1 file changed

+30
-20
lines changed

1 file changed

+30
-20
lines changed

trino/client.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,19 @@
5555
from zoneinfo import ZoneInfo
5656

5757
import requests
58+
from requests import Response
59+
from requests import Session
60+
from requests.structures import CaseInsensitiveDict
5861
from tzlocal import get_localzone_name # type: ignore
5962

6063
import trino.logging
6164
from trino import constants
6265
from trino import exceptions
6366
from trino._version import __version__
67+
from trino.auth import Authentication
68+
from trino.exceptions import TrinoExternalError
69+
from trino.exceptions import TrinoQueryError
70+
from trino.exceptions import TrinoUserError
6471
from trino.mapper import RowMapper
6572
from trino.mapper import RowMapperFactory
6673

@@ -271,27 +278,27 @@ def __setstate__(self, state):
271278
self._object_lock = threading.Lock()
272279

273280

274-
def get_header_values(headers, header):
281+
def get_header_values(headers: CaseInsensitiveDict[str], header: str) -> List[str]:
275282
return [val.strip() for val in headers[header].split(",")]
276283

277284

278-
def get_session_property_values(headers, header):
285+
def get_session_property_values(headers: CaseInsensitiveDict[str], header: str) -> List[Tuple[str, str]]:
279286
kvs = get_header_values(headers, header)
280287
return [
281288
(k.strip(), urllib.parse.unquote_plus(v.strip()))
282289
for k, v in (kv.split("=", 1) for kv in kvs if kv)
283290
]
284291

285292

286-
def get_prepared_statement_values(headers, header):
293+
def get_prepared_statement_values(headers: CaseInsensitiveDict[str], header: str) -> List[Tuple[str, str]]:
287294
kvs = get_header_values(headers, header)
288295
return [
289296
(k.strip(), urllib.parse.unquote_plus(v.strip()))
290297
for k, v in (kv.split("=", 1) for kv in kvs if kv)
291298
]
292299

293300

294-
def get_roles_values(headers, header):
301+
def get_roles_values(headers: CaseInsensitiveDict[str], header: str) -> List[Tuple[str, str]]:
295302
kvs = get_header_values(headers, header)
296303
return [
297304
(k.strip(), urllib.parse.unquote_plus(v.strip()))
@@ -414,9 +421,9 @@ def __init__(
414421
host: str,
415422
port: int,
416423
client_session: ClientSession,
417-
http_session: Any = None,
418-
http_scheme: str = None,
419-
auth: Optional[Any] = constants.DEFAULT_AUTH,
424+
http_session: Optional[Session] = None,
425+
http_scheme: Optional[str] = None,
426+
auth: Optional[Authentication] = constants.DEFAULT_AUTH,
420427
max_attempts: int = MAX_ATTEMPTS,
421428
request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT,
422429
handle_retry=_RetryWithExponentialBackoff(),
@@ -454,16 +461,16 @@ def __init__(
454461
self.max_attempts = max_attempts
455462

456463
@property
457-
def transaction_id(self):
464+
def transaction_id(self) -> Optional[str]:
458465
return self._client_session.transaction_id
459466

460467
@transaction_id.setter
461-
def transaction_id(self, value):
468+
def transaction_id(self, value: Optional[str]) -> None:
462469
self._client_session.transaction_id = value
463470

464471
@property
465-
def http_headers(self) -> Dict[str, str]:
466-
headers = requests.structures.CaseInsensitiveDict()
472+
def http_headers(self) -> CaseInsensitiveDict[str]:
473+
headers: CaseInsensitiveDict[str] = CaseInsensitiveDict()
467474

468475
headers[constants.HEADER_CATALOG] = self._client_session.catalog
469476
headers[constants.HEADER_SCHEMA] = self._client_session.schema
@@ -525,7 +532,7 @@ def max_attempts(self) -> int:
525532
return self._max_attempts
526533

527534
@max_attempts.setter
528-
def max_attempts(self, value) -> None:
535+
def max_attempts(self, value: int) -> None:
529536
self._max_attempts = value
530537
if value == 1: # No retry
531538
self._get = self._http_session.get
@@ -547,7 +554,7 @@ def max_attempts(self, value) -> None:
547554
self._post = with_retry(self._http_session.post)
548555
self._delete = with_retry(self._http_session.delete)
549556

550-
def get_url(self, path) -> str:
557+
def get_url(self, path: str) -> str:
551558
return "{protocol}://{host}:{port}{path}".format(
552559
protocol=self._http_scheme, host=self._host, port=self._port, path=path
553560
)
@@ -560,7 +567,7 @@ def statement_url(self) -> str:
560567
def next_uri(self) -> Optional[str]:
561568
return self._next_uri
562569

563-
def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None):
570+
def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None) -> Response:
564571
data = sql.encode("utf-8")
565572
# Deep copy of the http_headers dict since they may be modified for this
566573
# request by the provided additional_http_headers
@@ -578,18 +585,19 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non
578585
)
579586
return http_response
580587

581-
def get(self, url: str):
588+
def get(self, url: str) -> Response:
582589
return self._get(
583590
url,
584591
headers=self.http_headers,
585592
timeout=self._request_timeout,
586593
proxies=PROXIES,
587594
)
588595

589-
def delete(self, url):
596+
def delete(self, url: str) -> Response:
590597
return self._delete(url, timeout=self._request_timeout, proxies=PROXIES)
591598

592-
def _process_error(self, error, query_id):
599+
@staticmethod
600+
def _process_error(error, query_id: Optional[str]) -> Union[TrinoExternalError, TrinoQueryError, TrinoUserError]:
593601
error_type = error["errorType"]
594602
if error_type == "EXTERNAL":
595603
raise exceptions.TrinoExternalError(error, query_id)
@@ -598,7 +606,8 @@ def _process_error(self, error, query_id):
598606

599607
return exceptions.TrinoQueryError(error, query_id)
600608

601-
def raise_response_error(self, http_response):
609+
@staticmethod
610+
def raise_response_error(http_response: Response) -> None:
602611
if http_response.status_code == 502:
603612
raise exceptions.Http502Error("error 502: bad gateway")
604613

@@ -615,7 +624,7 @@ def raise_response_error(self, http_response):
615624
)
616625
)
617626

618-
def process(self, http_response) -> TrinoStatus:
627+
def process(self, http_response: Response) -> TrinoStatus:
619628
if not http_response.ok:
620629
self.raise_response_error(http_response)
621630

@@ -682,7 +691,8 @@ def process(self, http_response) -> TrinoStatus:
682691
columns=response.get("columns"),
683692
)
684693

685-
def _verify_extra_credential(self, header):
694+
@staticmethod
695+
def _verify_extra_credential(header: Tuple[str, str]) -> None:
686696
"""
687697
Verifies that key has ASCII only and non-whitespace characters.
688698
"""

0 commit comments

Comments
 (0)