39
39
import copy
40
40
import functools
41
41
import os
42
+ import queue
42
43
import random
43
44
import re
44
45
import threading
45
46
import urllib .parse
46
47
import warnings
48
+ from concurrent .futures import ThreadPoolExecutor
47
49
from datetime import date , datetime , time , timedelta , timezone , tzinfo
48
50
from decimal import Decimal
49
51
from time import sleep
50
- from typing import Any , Dict , Generic , List , Optional , Tuple , TypeVar , Union
52
+ from typing import Any , Callable , Dict , Generic , List , Optional , Tuple , TypeVar , Union
51
53
52
54
import pytz
53
55
import requests
@@ -684,6 +686,27 @@ def _verify_extra_credential(self, header):
684
686
raise ValueError (f"only ASCII characters are allowed in extra credential '{ key } '" )
685
687
686
688
689
class _FetchFailure:
    """Envelope used to carry an exception from the worker thread to the consumer."""

    __slots__ = ("error",)

    def __init__(self, error: BaseException) -> None:
        self.error = error


class _ResultQueue(queue.Queue):
    """``queue.Queue`` whose ``get`` re-raises failures recorded by ``ResultDownloader``."""

    def get(self, block: bool = True, timeout: Optional[float] = None) -> Any:
        item = super().get(block, timeout)
        if isinstance(item, _FetchFailure):
            # Surface the background failure in the thread that is waiting for
            # the result instead of leaving it blocked on an empty queue.
            raise item.error
        return item


class ResultDownloader:
    """
    Run fetch callables on a single background thread and hand their results
    over through ``self.queue``.

    The instance must be used as a context manager: entering creates the
    one-worker executor, exiting shuts it down (waiting for a task that is
    still running).

    Results are put on ``self.queue`` in submission order, because the
    executor has exactly one worker. If a fetch callable raises, the exception
    is re-raised by the matching ``self.queue.get()`` call rather than being
    lost in the worker thread.
    """

    def __init__(self) -> None:
        # Hand-over point between the worker thread and the consumer.
        self.queue: queue.Queue = _ResultQueue()
        # Only set between __enter__ and __exit__.
        self.executor: Optional[ThreadPoolExecutor] = None

    def submit(self, fetch_func: Callable[[], List[Any]]) -> None:
        """
        Schedule ``fetch_func`` to run on the background thread.

        :param fetch_func: zero-argument callable whose return value is put on ``self.queue``.
        :raises RuntimeError: if called outside of the ``with`` block.
        """
        # An explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        if self.executor is None:
            raise RuntimeError("ResultDownloader must be entered as a context manager before submit()")
        self.executor.submit(self.download_task, fetch_func)

    def download_task(self, fetch_func: Callable[[], List[Any]]) -> None:
        """
        Run ``fetch_func`` and put its outcome on ``self.queue``.

        Exactly one item is queued per call, so a consumer doing one ``get`` per
        ``submit`` can never block forever: either the result or a recorded failure arrives.
        """
        try:
            rows = fetch_func()
        except Exception as error:  # thread boundary: re-raised by the consumer's ``queue.get()``
            self.queue.put(_FetchFailure(error))
        else:
            self.queue.put(rows)

    def __enter__(self) -> "ResultDownloader":
        self.executor = ThreadPoolExecutor(max_workers=1)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        # Returning None never suppresses an exception raised inside the ``with`` body.
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None
708
+
709
+
687
710
class TrinoResult (object ):
688
711
"""
689
712
Represent the result of a Trino query as an iterator on rows.
@@ -711,16 +734,21 @@ def rownumber(self) -> int:
711
734
return self ._rownumber
712
735
713
736
def __iter__(self):
    """
    Yield the rows of the query result, downloading the next page in the
    background while the current page is being consumed.

    ``self._rownumber`` is incremented for every row that is yielded.
    ``self._rows`` is set to ``None`` once the last page has been yielded.
    """
    with ResultDownloader() as result_downloader:
        # A query only transitions to a FINISHED state when the results are fully consumed:
        # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.

        # Whether a fetch has been submitted whose result has not been taken
        # from the queue yet. This is tracked locally on purpose:
        # ``self._query.finished`` is updated by the background fetch, so
        # re-reading it to decide whether to call ``queue.get()`` can skip a
        # page that has already been downloaded.
        fetch_in_flight = False
        if not self._query.finished:
            result_downloader.submit(self._query.fetch)
            fetch_in_flight = True

        while fetch_in_flight or self._rows is not None:
            next_rows = None
            if fetch_in_flight:
                # fetch() returns raw rows; map them the same way columns() and execute() do.
                next_rows = self._query.map_rows(result_downloader.queue.get())
                fetch_in_flight = False
                # No fetch is running at this point, so ``finished`` is stable when read here.
                if not self._query.finished:
                    result_downloader.submit(self._query.fetch)
                    fetch_in_flight = True

            # ``or ()`` guards against ``_rows`` being None while a page is still pending.
            for row in self._rows or ():
                self._rownumber += 1
                logger.debug("row %s", row)
                yield row

            self._rows = next_rows
724
752
725
753
726
754
class TrinoQuery (object ):
@@ -753,7 +781,7 @@ def columns(self):
753
781
while not self ._columns and not self .finished and not self .cancelled :
754
782
# Columns are not returned immediately after query is submitted.
755
783
# Continue fetching data until columns information is available and push fetched rows into buffer.
756
- self ._result .rows += self .fetch ()
784
+ self ._result .rows += self .map_rows ( self . fetch () )
757
785
return self ._columns
758
786
759
787
@property
@@ -802,7 +830,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
802
830
803
831
# Execute should block until at least one row is received or query is finished or cancelled
804
832
while not self .finished and not self .cancelled and len (self ._result .rows ) == 0 :
805
- self ._result .rows += self .fetch ()
833
+ self ._result .rows += self .map_rows ( self . fetch () )
806
834
return self ._result
807
835
808
836
def _update_state (self , status ):
@@ -822,11 +850,12 @@ def fetch(self) -> List[List[Any]]:
822
850
logger .debug (status )
823
851
if status .next_uri is None :
824
852
self ._finished = True
853
+ return status .rows
825
854
855
def map_rows(self, rows: List[List[Any]]) -> List[List[Any]]:
    """
    Pass ``rows`` through ``self._row_mapper``.

    :param rows: rows to convert.
    :return: the result of ``self._row_mapper.map(rows)``, or an empty list
        when no row mapper is set.
    """
    mapper = self._row_mapper
    return mapper.map(rows) if mapper else []
830
859
831
860
def cancel (self ) -> None :
832
861
"""Cancel the current query"""
0 commit comments