Skip to content

Commit e764a6a

Browse files
Add defer_connect config to allow eagerly verifying connection
This commit adds a new connection parameter `defer_connect` which can be set to False to force creating a connection when `trino.dbapi.connect` is called. Any connection errors as a result of that get rewrapped into `trino.exceptions.TrinoConnectionError`. By default `defer_connect` is set to `True` so users can explicitly call `trino.dbapi.Connection.connect` to do the connection check. This doesn't end up actually executing a query on the server because after the initial POST request the nextUri in the response is not followed which leaves the query in QUEUED state. This is not documented in the Trino REST API but the server does behave like this today. The benefit is that we can very cheaply verify if the connection is valid without polluting the server's query history or adding queries to queue. Some unit tests today relied on the lazy connection behaviour so they have been adjusted accrodingly.
1 parent f98a608 commit e764a6a

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

tests/unit/sqlalchemy/test_dialect.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ def test_get_default_isolation_level(self):
252252
assert isolation_level == "AUTOCOMMIT"
253253

254254
def test_isolation_level(self):
255-
dbapi_conn = Connection(host="localhost")
255+
# The test only verifies that isolation level is correctly set, no need to attempt actual connection
256+
dbapi_conn = Connection(host="localhost", defer_connect=True)
256257

257258
self.dialect.set_isolation_level(dbapi_conn, "SERIALIZABLE")
258259
assert dbapi_conn._isolation_level == IsolationLevel.SERIALIZABLE

tests/unit/test_dbapi.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post
184184
conn2.cursor().execute("SELECT 2")
185185
conn2.cursor().execute("SELECT 3")
186186

187-
assert len(_post_statement_requests()) == 7
187+
assert len(_post_statement_requests()) == 9
188+
# assert only a single token request was sent
188189
assert len(_get_token_requests(challenge_id)) == 1
189190

190191

@@ -275,37 +276,38 @@ def test_role_is_set_when_specified(mock_client):
275276

276277

277278
def test_hostname_parsing():
278-
https_server_with_port = Connection("https://mytrinoserver.domain:9999")
279+
# Since this test only verifies URL parsing there is no need to attempt actual connection
280+
https_server_with_port = Connection("https://mytrinoserver.domain:9999", defer_connect=True)
279281
assert https_server_with_port.host == "mytrinoserver.domain"
280282
assert https_server_with_port.port == 9999
281283
assert https_server_with_port.http_scheme == constants.HTTPS
282284

283-
https_server_without_port = Connection("https://mytrinoserver.domain")
285+
https_server_without_port = Connection("https://mytrinoserver.domain", defer_connect=True)
284286
assert https_server_without_port.host == "mytrinoserver.domain"
285287
assert https_server_without_port.port == 8080
286288
assert https_server_without_port.http_scheme == constants.HTTPS
287289

288-
http_server_with_port = Connection("http://mytrinoserver.domain:9999")
290+
http_server_with_port = Connection("http://mytrinoserver.domain:9999", defer_connect=True)
289291
assert http_server_with_port.host == "mytrinoserver.domain"
290292
assert http_server_with_port.port == 9999
291293
assert http_server_with_port.http_scheme == constants.HTTP
292294

293-
http_server_without_port = Connection("http://mytrinoserver.domain")
295+
http_server_without_port = Connection("http://mytrinoserver.domain", defer_connect=True)
294296
assert http_server_without_port.host == "mytrinoserver.domain"
295297
assert http_server_without_port.port == 8080
296298
assert http_server_without_port.http_scheme == constants.HTTP
297299

298-
http_server_with_path = Connection("http://mytrinoserver.domain/some_path")
300+
http_server_with_path = Connection("http://mytrinoserver.domain/some_path", defer_connect=True)
299301
assert http_server_with_path.host == "mytrinoserver.domain/some_path"
300302
assert http_server_with_path.port == 8080
301303
assert http_server_with_path.http_scheme == constants.HTTP
302304

303-
only_hostname = Connection("mytrinoserver.domain")
305+
only_hostname = Connection("mytrinoserver.domain", defer_connect=True)
304306
assert only_hostname.host == "mytrinoserver.domain"
305307
assert only_hostname.port == 8080
306308
assert only_hostname.http_scheme == constants.HTTP
307309

308-
only_hostname_with_path = Connection("mytrinoserver.domain/some_path")
310+
only_hostname_with_path = Connection("mytrinoserver.domain/some_path", defer_connect=True)
309311
assert only_hostname_with_path.host == "mytrinoserver.domain/some_path"
310312
assert only_hostname_with_path.port == 8080
311313
assert only_hostname_with_path.http_scheme == constants.HTTP

trino/dbapi.py

+28
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types
2929
from urllib.parse import urlparse
3030

31+
from requests.exceptions import RequestException
32+
3133
try:
3234
from zoneinfo import ZoneInfo
3335
except ModuleNotFoundError:
@@ -157,6 +159,7 @@ def __init__(
157159
legacy_prepared_statements=None,
158160
roles=None,
159161
timezone=None,
162+
defer_connect=False,
160163
):
161164
# Automatically assign http_schema, port based on hostname
162165
parsed_host = urlparse(host, allow_fragments=False)
@@ -201,6 +204,31 @@ def __init__(
201204
self.legacy_primitive_types = legacy_primitive_types
202205
self.legacy_prepared_statements = legacy_prepared_statements
203206

207+
if not defer_connect:
208+
self.connect()
209+
210+
def connect(self) -> None:
211+
connection_test_request = trino.client.TrinoRequest(
212+
self.host,
213+
self.port,
214+
self._client_session,
215+
self._http_session,
216+
self.http_scheme,
217+
self.auth,
218+
self.max_attempts,
219+
self.request_timeout,
220+
verify=self._http_session.verify,
221+
)
222+
try:
223+
test_response = connection_test_request.post("<not-going-to-be-executed>")
224+
response_content = test_response.content if test_response.content else ""
225+
if not test_response.ok:
226+
raise trino.exceptions.TrinoConnectionError(
227+
"error {}: {}".format(test_response.status_code, response_content))
228+
229+
except RequestException as e:
230+
raise trino.exceptions.TrinoConnectionError("connection failed: {}".format(e))
231+
204232
@property
205233
def isolation_level(self):
206234
return self._isolation_level

0 commit comments

Comments
 (0)