Skip to content

Commit 9b3fb71

Browse files
committed
Use download thread to speed up result retrieval
1 parent ed085c5 commit 9b3fb71

File tree

1 file changed

+43
-14
lines changed

1 file changed

+43
-14
lines changed

trino/client.py

+43-14
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,17 @@
3939
import copy
4040
import functools
4141
import os
42+
import queue
4243
import random
4344
import re
4445
import threading
4546
import urllib.parse
47+
from concurrent.futures import ThreadPoolExecutor
4648
import warnings
4749
from datetime import date, datetime, time, timedelta, timezone, tzinfo
4850
from decimal import Decimal
4951
from time import sleep
50-
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
52+
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
5153

5254
import pytz
5355
import requests
@@ -684,6 +686,27 @@ def _verify_extra_credential(self, header):
684686
raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")
685687

686688

689+
class ResultDownloader():
    """
    Run result-fetching callables on a single background thread.

    Use it as a context manager: entering creates a one-worker
    ``ThreadPoolExecutor``, leaving shuts it down. Each callable handed to
    :meth:`submit` is executed on the worker and its return value is placed
    on :attr:`queue`, from which the consumer retrieves it with
    ``queue.get()``.

    If the callable raises, the exception is carried through the queue and
    re-raised by ``queue.get()`` in the consumer thread, so a failed fetch
    surfaces as an error instead of leaving the consumer blocked forever.
    """

    class _Failure:
        """Marks a queue item as an exception raised by a fetch callable."""

        def __init__(self, error: BaseException):
            self.error = error

    class _ResultQueue(queue.Queue):
        """``queue.Queue`` whose ``get`` re-raises errors captured by the worker."""

        def get(self, block=True, timeout=None):
            item = super().get(block, timeout)
            if isinstance(item, ResultDownloader._Failure):
                raise item.error
            return item

    def __init__(self):
        # Results (or captured failures) produced by the worker thread.
        self.queue: queue.Queue = ResultDownloader._ResultQueue()
        # Only set while inside the ``with`` block.
        self.executor: Optional[ThreadPoolExecutor] = None

    def submit(self, fetch_func: Callable[[], List[Any]]):
        """
        Schedule ``fetch_func`` on the worker thread.

        :param fetch_func: zero-argument callable returning a list of rows.
        :raises RuntimeError: if called outside of the ``with`` block.
        """
        # An explicit check instead of ``assert``, which is stripped under ``python -O``.
        if self.executor is None:
            raise RuntimeError("ResultDownloader must be used as a context manager before calling submit()")
        self.executor.submit(self.download_task, fetch_func)

    def download_task(self, fetch_func):
        """Call ``fetch_func`` and publish its result, or its failure, on the queue."""
        try:
            rows = fetch_func()
        except BaseException as error:
            # Not swallowed: the error is handed to the consumer, which re-raises
            # it from ``queue.get()``. Without this the consumer would wait forever.
            self.queue.put(ResultDownloader._Failure(error))
        else:
            self.queue.put(rows)

    def __enter__(self):
        # A single worker keeps fetches strictly sequential.
        self.executor = ThreadPoolExecutor(max_workers=1)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Tolerate a missing executor so that cleanup never masks the original error.
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None
708+
709+
687710
class TrinoResult(object):
688711
"""
689712
Represent the result of a Trino query as an iterator on rows.
@@ -711,16 +734,21 @@ def rownumber(self) -> int:
711734
return self._rownumber
712735

713736
def __iter__(self):
    """
    Yield the rows of the query result, prefetching the next page in the
    background while the current page is being consumed.

    ``self._rownumber`` is incremented for every row yielded.
    """
    with ResultDownloader() as result_downloader:
        # A query only transitions to a FINISHED state when the results are fully consumed:
        # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.

        # Track locally whether a fetch is in flight. ``self._query.finished`` is
        # updated by ``fetch`` on the download thread, so reading it while a fetch
        # is running is racy: the last page could complete before the check and
        # its rows would be left in the queue, never yielded.
        fetch_pending = not self._query.finished
        if fetch_pending:
            result_downloader.submit(self._query.fetch)

        while fetch_pending or self._rows is not None:
            next_rows = None
            if fetch_pending:
                # Blocks until the in-flight fetch has completed.
                raw_rows = result_downloader.queue.get()
                # ``fetch`` returns the raw rows; map them the same way
                # ``columns`` and ``execute`` do before exposing them.
                next_rows = self._query.map_rows(raw_rows)
                # No fetch is running at this point, so ``finished`` is stable.
                fetch_pending = not self._query.finished
                if fetch_pending:
                    # Start downloading the following page while this one is consumed.
                    result_downloader.submit(self._query.fetch)

            for row in self._rows:
                self._rownumber += 1
                logger.debug("row %s", row)
                yield row

            self._rows = next_rows
724752

725753

726754
class TrinoQuery(object):
@@ -753,7 +781,7 @@ def columns(self):
753781
while not self._columns and not self.finished and not self.cancelled:
754782
# Columns are not returned immediately after query is submitted.
755783
# Continue fetching data until columns information is available and push fetched rows into buffer.
756-
self._result.rows += self.fetch()
784+
self._result.rows += self.map_rows(self.fetch())
757785
return self._columns
758786

759787
@property
@@ -802,7 +830,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
802830

803831
# Execute should block until at least one row is received or query is finished or cancelled
804832
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
805-
self._result.rows += self.fetch()
833+
self._result.rows += self.map_rows(self.fetch())
806834
return self._result
807835

808836
def _update_state(self, status):
@@ -822,11 +850,12 @@ def fetch(self) -> List[List[Any]]:
822850
logger.debug(status)
823851
if status.next_uri is None:
824852
self._finished = True
853+
return status.rows
825854

855+
def map_rows(self, rows: List[List[Any]]) -> List[List[Any]]:
    """
    Convert raw rows with the query's row mapper.

    :param rows: raw rows, as returned by ``fetch``.
    :return: the result of ``self._row_mapper.map(rows)``, or an empty list
        when no row mapper is set.
    """
    mapper = self._row_mapper
    return mapper.map(rows) if mapper else []
830859

831860
def cancel(self) -> None:
832861
"""Cancel the current query"""

0 commit comments

Comments
 (0)