Skip to content

Commit 3a5e182

Browse files
committed
Use download thread to speed up result retrieval
1 parent e4a3f0f commit 3a5e182

File tree

1 file changed

+44
-15
lines changed

1 file changed

+44
-15
lines changed

trino/client.py

+44-15
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@
3939
import copy
4040
import functools
4141
import os
42+
import queue
4243
import random
4344
import re
4445
import threading
4546
import urllib.parse
47+
from concurrent.futures import ThreadPoolExecutor
4648
from datetime import date, datetime, time, timedelta, timezone, tzinfo
4749
from decimal import Decimal
4850
from time import sleep
49-
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
51+
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
5052

5153
import pytz
5254
import requests
@@ -666,6 +668,27 @@ def _verify_extra_credential(self, header):
666668
raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")
667669

668670

671+
class ResultDownloader():
672+
def __init__(self):
673+
self.queue: queue.Queue = queue.Queue()
674+
self.executor: Optional[ThreadPoolExecutor] = None
675+
676+
def submit(self, fetch_func: Callable[[], List[Any]]):
677+
assert self.executor is not None
678+
self.executor.submit(self.download_task, fetch_func)
679+
680+
def download_task(self, fetch_func):
681+
self.queue.put(fetch_func())
682+
683+
def __enter__(self):
684+
self.executor = ThreadPoolExecutor(max_workers=1)
685+
return self
686+
687+
def __exit__(self, exc_type, exc_value, exc_traceback):
688+
self.executor.shutdown()
689+
self.executor = None
690+
691+
669692
class TrinoResult(object):
670693
"""
671694
Represent the result of a Trino query as an iterator on rows.
@@ -693,16 +716,21 @@ def rownumber(self) -> int:
693716
return self._rownumber
694717

695718
def __iter__(self):
696-
# A query only transitions to a FINISHED state when the results are fully consumed:
697-
# The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
698-
while not self._query.finished or self._rows is not None:
699-
next_rows = self._query.fetch() if not self._query.finished else None
700-
for row in self._rows:
701-
self._rownumber += 1
702-
logger.debug("row %s", row)
703-
yield row
719+
with ResultDownloader() as result_downloader:
720+
# A query only transitions to a FINISHED state when the results are fully consumed:
721+
# The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
722+
result_downloader.submit(self._query.fetch)
723+
while not self._query.finished or self._rows is not None:
724+
next_rows = result_downloader.queue.get() if not self._query.finished else None
725+
if not self._query.finished:
726+
result_downloader.submit(self._query.fetch)
704727

705-
self._rows = next_rows
728+
for row in self._rows:
729+
self._rownumber += 1
730+
logger.debug("row %s", row)
731+
yield row
732+
733+
self._rows = next_rows
706734

707735

708736
class TrinoQuery(object):
@@ -735,7 +763,7 @@ def columns(self):
735763
while not self._columns and not self.finished and not self.cancelled:
736764
# Columns are not returned immediately after query is submitted.
737765
# Continue fetching data until columns information is available and push fetched rows into buffer.
738-
self._result.rows += self.fetch()
766+
self._result.rows += self.map_rows(self.fetch())
739767
return self._columns
740768

741769
@property
@@ -784,7 +812,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
784812

785813
# Execute should block until at least one row is received or query is finished or cancelled
786814
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
787-
self._result.rows += self.fetch()
815+
self._result.rows += self.map_rows(self.fetch())
788816
return self._result
789817

790818
def _update_state(self, status):
@@ -796,19 +824,20 @@ def _update_state(self, status):
796824
if status.columns:
797825
self._columns = status.columns
798826

799-
def fetch(self) -> List[List[Any]]:
827+
def fetch(self) -> List[Any]:
800828
"""Continue fetching data for the current query_id"""
801829
response = self._request.get(self._request.next_uri)
802830
status = self._request.process(response)
803831
self._update_state(status)
804832
logger.debug(status)
805833
if status.next_uri is None:
806834
self._finished = True
835+
return status.rows
807836

837+
def map_rows(self, rows: List[List[Any]]) -> List[List[Any]]:
808838
if not self._row_mapper:
809839
return []
810-
811-
return self._row_mapper.map(status.rows)
840+
return self._row_mapper.map(rows)
812841

813842
def cancel(self) -> None:
814843
"""Cancel the current query"""

0 commit comments

Comments
 (0)