From e764a6a809e20beccfe873e1041612a6a97f6d0c Mon Sep 17 00:00:00 2001 From: shaheer Date: Tue, 18 Jul 2023 23:55:56 +0530 Subject: [PATCH] Add defer_connect config to allow eagerly verifying connection This commit adds a new connection parameter `defer_connect` which can be set to False to force creating a connection when `trino.dbapi.connect` is called. Any connection errors as a result of that get rewrapped into `trino.exceptions.TrinoConnectionError`. By default `defer_connect` is set to `False`, so the connection is verified eagerly; users who set it to `True` can explicitly call `trino.dbapi.Connection.connect` to do the connection check. This doesn't end up actually executing a query on the server because after the initial POST request the nextUri in the response is not followed, which leaves the query in QUEUED state. This is not documented in the Trino REST API but the server does behave like this today. The benefit is that we can very cheaply verify if the connection is valid without polluting the server's query history or adding queries to queue. Some unit tests today relied on the lazy connection behaviour so they have been adjusted accordingly. 
--- tests/unit/sqlalchemy/test_dialect.py | 3 ++- tests/unit/test_dbapi.py | 18 +++++++++-------- trino/dbapi.py | 28 +++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index c385ab90..cb62fc01 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -252,7 +252,8 @@ def test_get_default_isolation_level(self): assert isolation_level == "AUTOCOMMIT" def test_isolation_level(self): - dbapi_conn = Connection(host="localhost") + # The test only verifies that isolation level is correctly set, no need to attempt actual connection + dbapi_conn = Connection(host="localhost", defer_connect=True) self.dialect.set_isolation_level(dbapi_conn, "SERIALIZABLE") assert dbapi_conn._isolation_level == IsolationLevel.SERIALIZABLE diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index b56466a2..e0367c86 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -184,7 +184,8 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post conn2.cursor().execute("SELECT 2") conn2.cursor().execute("SELECT 3") - assert len(_post_statement_requests()) == 7 + assert len(_post_statement_requests()) == 9 + # assert only a single token request was sent assert len(_get_token_requests(challenge_id)) == 1 @@ -275,37 +276,38 @@ def test_role_is_set_when_specified(mock_client): def test_hostname_parsing(): - https_server_with_port = Connection("https://mytrinoserver.domain:9999") + # Since this test only verifies URL parsing there is no need to attempt actual connection + https_server_with_port = Connection("https://mytrinoserver.domain:9999", defer_connect=True) assert https_server_with_port.host == "mytrinoserver.domain" assert https_server_with_port.port == 9999 assert https_server_with_port.http_scheme == constants.HTTPS - https_server_without_port = 
Connection("https://mytrinoserver.domain") + https_server_without_port = Connection("https://mytrinoserver.domain", defer_connect=True) assert https_server_without_port.host == "mytrinoserver.domain" assert https_server_without_port.port == 8080 assert https_server_without_port.http_scheme == constants.HTTPS - http_server_with_port = Connection("http://mytrinoserver.domain:9999") + http_server_with_port = Connection("http://mytrinoserver.domain:9999", defer_connect=True) assert http_server_with_port.host == "mytrinoserver.domain" assert http_server_with_port.port == 9999 assert http_server_with_port.http_scheme == constants.HTTP - http_server_without_port = Connection("http://mytrinoserver.domain") + http_server_without_port = Connection("http://mytrinoserver.domain", defer_connect=True) assert http_server_without_port.host == "mytrinoserver.domain" assert http_server_without_port.port == 8080 assert http_server_without_port.http_scheme == constants.HTTP - http_server_with_path = Connection("http://mytrinoserver.domain/some_path") + http_server_with_path = Connection("http://mytrinoserver.domain/some_path", defer_connect=True) assert http_server_with_path.host == "mytrinoserver.domain/some_path" assert http_server_with_path.port == 8080 assert http_server_with_path.http_scheme == constants.HTTP - only_hostname = Connection("mytrinoserver.domain") + only_hostname = Connection("mytrinoserver.domain", defer_connect=True) assert only_hostname.host == "mytrinoserver.domain" assert only_hostname.port == 8080 assert only_hostname.http_scheme == constants.HTTP - only_hostname_with_path = Connection("mytrinoserver.domain/some_path") + only_hostname_with_path = Connection("mytrinoserver.domain/some_path", defer_connect=True) assert only_hostname_with_path.host == "mytrinoserver.domain/some_path" assert only_hostname_with_path.port == 8080 assert only_hostname_with_path.http_scheme == constants.HTTP diff --git a/trino/dbapi.py b/trino/dbapi.py index 62ce893b..ae1348a3 100644 
--- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -28,6 +28,8 @@ from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types from urllib.parse import urlparse +from requests.exceptions import RequestException + try: from zoneinfo import ZoneInfo except ModuleNotFoundError: @@ -157,6 +159,7 @@ def __init__( legacy_prepared_statements=None, roles=None, timezone=None, + defer_connect=False, ): # Automatically assign http_schema, port based on hostname parsed_host = urlparse(host, allow_fragments=False) @@ -201,6 +204,31 @@ def __init__( self.legacy_primitive_types = legacy_primitive_types self.legacy_prepared_statements = legacy_prepared_statements + if not defer_connect: + self.connect() + + def connect(self) -> None: + connection_test_request = trino.client.TrinoRequest( + self.host, + self.port, + self._client_session, + self._http_session, + self.http_scheme, + self.auth, + self.max_attempts, + self.request_timeout, + verify=self._http_session.verify, + ) + try: + test_response = connection_test_request.post("") + response_content = test_response.content if test_response.content else "" + if not test_response.ok: + raise trino.exceptions.TrinoConnectionError( + "error {}: {}".format(test_response.status_code, response_content)) + + except RequestException as e: + raise trino.exceptions.TrinoConnectionError("connection failed: {}".format(e)) + @property def isolation_level(self): return self._isolation_level