3939import copy
4040import functools
4141import os
42+ import queue
4243import random
4344import re
4445import threading
4546import urllib .parse
47+ from concurrent .futures import ThreadPoolExecutor
4648from datetime import date , datetime , time , timedelta , timezone , tzinfo
4749from decimal import Decimal
4850from time import sleep
49- from typing import Any , Dict , Generic , List , Optional , Tuple , TypeVar , Union
51+ from typing import Any , Callable , Dict , Generic , List , Optional , Tuple , TypeVar , Union
5052
5153import pytz
5254import requests
@@ -666,6 +668,27 @@ def _verify_extra_credential(self, header):
666668 raise ValueError (f"only ASCII characters are allowed in extra credential '{ key } '" )
667669
668670
671+ class ResultDownloader ():
672+ def __init__ (self ):
673+ self .queue : queue .Queue = queue .Queue ()
674+ self .executor : Optional [ThreadPoolExecutor ] = None
675+
676+ def submit (self , fetch_func : Callable [[], List [Any ]]):
677+ assert self .executor is not None
678+ self .executor .submit (self .download_task , fetch_func )
679+
680+ def download_task (self , fetch_func ):
681+ self .queue .put (fetch_func ())
682+
683+ def __enter__ (self ):
684+ self .executor = ThreadPoolExecutor (max_workers = 1 )
685+ return self
686+
687+ def __exit__ (self , exc_type , exc_value , exc_traceback ):
688+ self .executor .shutdown ()
689+ self .executor = None
690+
691+
669692class TrinoResult (object ):
670693 """
671694 Represent the result of a Trino query as an iterator on rows.
@@ -693,16 +716,21 @@ def rownumber(self) -> int:
693716 return self ._rownumber
694717
695718 def __iter__ (self ):
696- # A query only transitions to a FINISHED state when the results are fully consumed :
697- # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
698- while not self . _query . finished or self . _rows is not None :
699- next_rows = self . _query . fetch () if not self ._query .finished else None
700- for row in self ._rows :
701- self ._rownumber += 1
702- logger . debug ( "row %s" , row )
703- yield row
719+ with ResultDownloader () as result_downloader :
720+ # A query only transitions to a FINISHED state when the results are fully consumed:
721+ # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
722+ result_downloader . submit ( self ._query .fetch )
723+ while not self . _query . finished or self ._rows is not None :
724+ next_rows = result_downloader . queue . get () if not self ._query . finished else None
725+ if not self . _query . finished :
726+ result_downloader . submit ( self . _query . fetch )
704727
705- self ._rows = next_rows
728+ for row in self ._rows :
729+ self ._rownumber += 1
730+ logger .debug ("row %s" , row )
731+ yield row
732+
733+ self ._rows = next_rows
706734
707735
708736class TrinoQuery (object ):
@@ -735,7 +763,7 @@ def columns(self):
735763 while not self ._columns and not self .finished and not self .cancelled :
736764 # Columns are not returned immediately after query is submitted.
737765 # Continue fetching data until columns information is available and push fetched rows into buffer.
738- self ._result .rows += self .fetch ()
766+ self ._result .rows += self .map_rows ( self . fetch () )
739767 return self ._columns
740768
741769 @property
@@ -784,7 +812,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
784812
785813 # Execute should block until at least one row is received or query is finished or cancelled
786814 while not self .finished and not self .cancelled and len (self ._result .rows ) == 0 :
787- self ._result .rows += self .fetch ()
815+ self ._result .rows += self .map_rows ( self . fetch () )
788816 return self ._result
789817
790818 def _update_state (self , status ):
@@ -796,19 +824,20 @@ def _update_state(self, status):
796824 if status .columns :
797825 self ._columns = status .columns
798826
799- def fetch (self ) -> List [List [ Any ] ]:
827+ def fetch (self ) -> List [Any ]:
800828 """Continue fetching data for the current query_id"""
801829 response = self ._request .get (self ._request .next_uri )
802830 status = self ._request .process (response )
803831 self ._update_state (status )
804832 logger .debug (status )
805833 if status .next_uri is None :
806834 self ._finished = True
835+ return status .rows
807836
837+ def map_rows (self , rows : List [List [Any ]]) -> List [List [Any ]]:
808838 if not self ._row_mapper :
809839 return []
810-
811- return self ._row_mapper .map (status .rows )
840+ return self ._row_mapper .map (rows )
812841
813842 def cancel (self ) -> None :
814843 """Cancel the current query"""
0 commit comments