39
39
import copy
40
40
import functools
41
41
import os
42
+ import queue
42
43
import random
43
44
import re
44
45
import threading
45
46
import urllib .parse
47
+ from concurrent .futures import ThreadPoolExecutor
46
48
from datetime import date , datetime , time , timedelta , timezone , tzinfo
47
49
from decimal import Decimal
48
50
from time import sleep
49
- from typing import Any , Dict , Generic , List , Optional , Tuple , TypeVar , Union
51
+ from typing import Any , Callable , Dict , Generic , List , Optional , Tuple , TypeVar , Union
50
52
51
53
import pytz
52
54
import requests
@@ -666,6 +668,27 @@ def _verify_extra_credential(self, header):
666
668
raise ValueError (f"only ASCII characters are allowed in extra credential '{ key } '" )
667
669
668
670
671
+ class ResultDownloader ():
672
+ def __init__ (self ):
673
+ self .queue : queue .Queue = queue .Queue ()
674
+ self .executor : Optional [ThreadPoolExecutor ] = None
675
+
676
+ def submit (self , fetch_func : Callable [[], List [Any ]]):
677
+ assert self .executor is not None
678
+ self .executor .submit (self .download_task , fetch_func )
679
+
680
+ def download_task (self , fetch_func ):
681
+ self .queue .put (fetch_func ())
682
+
683
+ def __enter__ (self ):
684
+ self .executor = ThreadPoolExecutor (max_workers = 1 )
685
+ return self
686
+
687
+ def __exit__ (self , exc_type , exc_value , exc_traceback ):
688
+ self .executor .shutdown ()
689
+ self .executor = None
690
+
691
+
669
692
class TrinoResult (object ):
670
693
"""
671
694
Represent the result of a Trino query as an iterator on rows.
@@ -693,16 +716,21 @@ def rownumber(self) -> int:
693
716
return self ._rownumber
694
717
695
718
def __iter__ (self ):
696
- # A query only transitions to a FINISHED state when the results are fully consumed :
697
- # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
698
- while not self . _query . finished or self . _rows is not None :
699
- next_rows = self . _query . fetch () if not self ._query .finished else None
700
- for row in self ._rows :
701
- self ._rownumber += 1
702
- logger . debug ( "row %s" , row )
703
- yield row
719
+ with ResultDownloader () as result_downloader :
720
+ # A query only transitions to a FINISHED state when the results are fully consumed:
721
+ # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
722
+ result_downloader . submit ( self ._query .fetch )
723
+ while not self . _query . finished or self ._rows is not None :
724
+ next_rows = result_downloader . queue . get () if not self ._query . finished else None
725
+ if not self . _query . finished :
726
+ result_downloader . submit ( self . _query . fetch )
704
727
705
- self ._rows = next_rows
728
+ for row in self ._rows :
729
+ self ._rownumber += 1
730
+ logger .debug ("row %s" , row )
731
+ yield row
732
+
733
+ self ._rows = next_rows
706
734
707
735
708
736
class TrinoQuery (object ):
@@ -735,7 +763,7 @@ def columns(self):
735
763
while not self ._columns and not self .finished and not self .cancelled :
736
764
# Columns are not returned immediately after query is submitted.
737
765
# Continue fetching data until columns information is available and push fetched rows into buffer.
738
- self ._result .rows += self .fetch ()
766
+ self ._result .rows += self .map_rows ( self . fetch () )
739
767
return self ._columns
740
768
741
769
@property
@@ -784,7 +812,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
784
812
785
813
# Execute should block until at least one row is received or query is finished or cancelled
786
814
while not self .finished and not self .cancelled and len (self ._result .rows ) == 0 :
787
- self ._result .rows += self .fetch ()
815
+ self ._result .rows += self .map_rows ( self . fetch () )
788
816
return self ._result
789
817
790
818
def _update_state (self , status ):
@@ -796,19 +824,20 @@ def _update_state(self, status):
796
824
if status .columns :
797
825
self ._columns = status .columns
798
826
799
- def fetch (self ) -> List [List [ Any ] ]:
827
+ def fetch (self ) -> List [Any ]:
800
828
"""Continue fetching data for the current query_id"""
801
829
response = self ._request .get (self ._request .next_uri )
802
830
status = self ._request .process (response )
803
831
self ._update_state (status )
804
832
logger .debug (status )
805
833
if status .next_uri is None :
806
834
self ._finished = True
835
+ return status .rows
807
836
837
+ def map_rows (self , rows : List [List [Any ]]) -> List [List [Any ]]:
808
838
if not self ._row_mapper :
809
839
return []
810
-
811
- return self ._row_mapper .map (status .rows )
840
+ return self ._row_mapper .map (rows )
812
841
813
842
def cancel (self ) -> None :
814
843
"""Cancel the current query"""
0 commit comments