@@ -5,7 +5,6 @@
 import time
 import uuid
 import threading
-import lz4.frame
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
 from typing import List, Union
 
@@ -26,11 +25,14 @@
 )
 
 from databricks.sql.utils import (
-    ArrowQueue,
     ExecuteResponse,
     _bound,
     RequestErrorInfo,
     NoRetryReason,
+    ResultSetQueueFactory,
+    convert_arrow_based_set_to_arrow_table,
+    convert_decimals_in_arrow_table,
+    convert_column_based_set_to_arrow_table,
 )
 
 logger = logging.getLogger(__name__)
@@ -67,7 +69,6 @@
 class ThriftBackend:
     CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE
     ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE
-    BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
 
     def __init__(
         self,
@@ -115,6 +116,8 @@ def __init__(
         # _socket_timeout
         #   The timeout in seconds for socket send, recv and connect operations. Should be a positive float or integer.
         #   (defaults to 900)
+        # max_download_threads
+        #   Number of threads for handling cloud fetch downloads. Defaults to 10
 
         port = port or 443
         if kwargs.get("_connection_uri"):
@@ -136,6 +139,9 @@ def __init__(
             "_use_arrow_native_timestamps", True
         )
 
+        # Cloud fetch
+        self.max_download_threads = kwargs.get("max_download_threads", 10)
+
         # Configure tls context
         ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
         if kwargs.get("_tls_no_verify") is True:
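
For reference, a minimal usage sketch of the new knob (assuming the public `databricks.sql.connect` entry point, which forwards extra kwargs down to `ThriftBackend`; host/path/token values are placeholders):

    from databricks import sql

    # Hypothetical: widen the cloud fetch download pool beyond its default of 10
    # for result sets split across many presigned files.
    connection = sql.connect(
        server_hostname="<workspace-hostname>",  # placeholder
        http_path="<warehouse-http-path>",       # placeholder
        access_token="<personal-access-token>",  # placeholder
        max_download_threads=20,                 # picked up via kwargs.get(...)
    )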
@@ -558,108 +564,14 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
             (
                 arrow_table,
                 num_rows,
-            ) = ThriftBackend._convert_column_based_set_to_arrow_table(
-                t_row_set.columns, description
-            )
+            ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
         elif t_row_set.arrowBatches is not None:
-            (
-                arrow_table,
-                num_rows,
-            ) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
+            (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, schema_bytes
             )
         else:
             raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
-        return self._convert_decimals_in_arrow_table(arrow_table, description), num_rows
-
-    @staticmethod
-    def _convert_decimals_in_arrow_table(table, description):
-        for (i, col) in enumerate(table.itercolumns()):
-            if description[i][1] == "decimal":
-                decimal_col = col.to_pandas().apply(
-                    lambda v: v if v is None else Decimal(v)
-                )
-                precision, scale = description[i][4], description[i][5]
-                assert scale is not None
-                assert precision is not None
-                # Spark limits decimal to a maximum scale of 38,
-                # so 128 is guaranteed to be big enough
-                dtype = pyarrow.decimal128(precision, scale)
-                col_data = pyarrow.array(decimal_col, type=dtype)
-                field = table.field(i).with_type(dtype)
-                table = table.set_column(i, field, col_data)
-        return table
-
-    @staticmethod
-    def _convert_arrow_based_set_to_arrow_table(
-        arrow_batches, lz4_compressed, schema_bytes
-    ):
-        ba = bytearray()
-        ba += schema_bytes
-        n_rows = 0
-        if lz4_compressed:
-            for arrow_batch in arrow_batches:
-                n_rows += arrow_batch.rowCount
-                ba += lz4.frame.decompress(arrow_batch.batch)
-        else:
-            for arrow_batch in arrow_batches:
-                n_rows += arrow_batch.rowCount
-                ba += arrow_batch.batch
-        arrow_table = pyarrow.ipc.open_stream(ba).read_all()
-        return arrow_table, n_rows
-
-    @staticmethod
-    def _convert_column_based_set_to_arrow_table(columns, description):
-        arrow_table = pyarrow.Table.from_arrays(
-            [ThriftBackend._convert_column_to_arrow_array(c) for c in columns],
-            # Only use the column names from the schema, the types are determined by the
-            # physical types used in column based set, as they can differ from the
-            # mapping used in _hive_schema_to_arrow_schema.
-            names=[c[0] for c in description],
-        )
-        return arrow_table, arrow_table.num_rows
-
-    @staticmethod
-    def _convert_column_to_arrow_array(t_col):
-        """
-        Return a pyarrow array from the values in a TColumn instance.
-        Note that ColumnBasedSet has no native support for complex types, so they will be converted
-        to strings server-side.
-        """
-        field_name_to_arrow_type = {
-            "boolVal": pyarrow.bool_(),
-            "byteVal": pyarrow.int8(),
-            "i16Val": pyarrow.int16(),
-            "i32Val": pyarrow.int32(),
-            "i64Val": pyarrow.int64(),
-            "doubleVal": pyarrow.float64(),
-            "stringVal": pyarrow.string(),
-            "binaryVal": pyarrow.binary(),
-        }
-        for field in field_name_to_arrow_type.keys():
-            wrapper = getattr(t_col, field)
-            if wrapper:
-                return ThriftBackend._create_arrow_array(
-                    wrapper, field_name_to_arrow_type[field]
-                )
-
-        raise OperationalError("Empty TColumn instance {}".format(t_col))
-
-    @staticmethod
-    def _create_arrow_array(t_col_value_wrapper, arrow_type):
-        result = t_col_value_wrapper.values
-        nulls = t_col_value_wrapper.nulls  # bitfield describing which values are null
-        assert isinstance(nulls, bytes)
-
-        # The number of bits in nulls can be both larger or smaller than the number of
-        # elements in result, so take the minimum of both to iterate over.
-        length = min(len(result), len(nulls) * 8)
-
-        for i in range(length):
-            if nulls[i >> 3] & ThriftBackend.BIT_MASKS[i & 0x7]:
-                result[i] = None
-
-        return pyarrow.array(result, type=arrow_type)
+        return convert_decimals_in_arrow_table(arrow_table, description), num_rows
 
     def _get_metadata_resp(self, op_handle):
         req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle)
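
As an aside on the helpers that moved to `databricks.sql.utils`: the removed `_create_arrow_array` decodes Thrift's packed null bitfield before building the pyarrow array. A self-contained illustration of that bit arithmetic (names here are illustrative, not part of the library API):

    # Bit i of `nulls` (LSB-first within each byte) marks values[i] as NULL.
    BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]

    def apply_nulls(values, nulls: bytes):
        # The bitfield may be shorter or longer than the value list,
        # so only iterate over the overlap.
        length = min(len(values), len(nulls) * 8)
        for i in range(length):
            if nulls[i >> 3] & BIT_MASKS[i & 0x7]:  # byte i // 8, bit i % 8
                values[i] = None
        return values

    assert apply_nulls([1, 2, 3, 4], b"\x05") == [None, 2, None, 4]  # bits 0, 2 set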
@@ -752,6 +664,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
         if t_result_set_metadata_resp.resultFormat not in [
             ttypes.TSparkRowSetType.ARROW_BASED_SET,
             ttypes.TSparkRowSetType.COLUMN_BASED_SET,
+            ttypes.TSparkRowSetType.URL_BASED_SET,
         ]:
             raise OperationalError(
                 "Expected results to be in Arrow or column based format, "
@@ -783,13 +696,14 @@ def _results_message_to_execute_response(self, resp, operation_state):
             assert direct_results.resultSet.results.startRowOffset == 0
             assert direct_results.resultSetMetadata
 
-            arrow_results, n_rows = self._create_arrow_table(
-                direct_results.resultSet.results,
-                lz4_compressed,
-                schema_bytes,
-                description,
+            arrow_queue_opt = ResultSetQueueFactory.build_queue(
+                row_set_type=t_result_set_metadata_resp.resultFormat,
+                t_row_set=direct_results.resultSet.results,
+                arrow_schema_bytes=schema_bytes,
+                max_download_threads=self.max_download_threads,
+                lz4_compressed=lz4_compressed,
+                description=description,
             )
-            arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
         else:
             arrow_queue_opt = None
         return ExecuteResponse(
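
The factory presumably picks a queue implementation from the row set format; a runnable sketch of that dispatch under stated assumptions (the enum stand-in mirrors `ttypes.TSparkRowSetType`, but the real logic lives in `databricks.sql.utils.ResultSetQueueFactory`):

    from enum import Enum

    class RowSetType(Enum):  # stand-in for ttypes.TSparkRowSetType
        ARROW_BASED_SET = 0
        COLUMN_BASED_SET = 1
        URL_BASED_SET = 2

    def build_queue_sketch(row_set_type):
        # Inline Arrow/column results stay in memory (the old ArrowQueue path);
        # URL_BASED_SET results carry presigned links that a cloud fetch queue
        # downloads on max_download_threads worker threads.
        if row_set_type in (RowSetType.ARROW_BASED_SET, RowSetType.COLUMN_BASED_SET):
            return "in-memory arrow queue"
        if row_set_type == RowSetType.URL_BASED_SET:
            return "cloud fetch download queue"
        raise ValueError("Unsupported row set type: {}".format(row_set_type))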
@@ -843,7 +757,14 @@ def _check_direct_results_for_error(t_spark_direct_results):
             )
 
     def execute_command(
-        self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
+        self,
+        operation,
+        session_handle,
+        max_rows,
+        max_bytes,
+        lz4_compression,
+        cursor,
+        use_cloud_fetch=False,
     ):
         assert session_handle is not None
@@ -864,7 +785,7 @@ def execute_command(
             ),
             canReadArrowResult=True,
             canDecompressLZ4Result=lz4_compression,
-            canDownloadResult=False,
+            canDownloadResult=use_cloud_fetch,
             confOverlay={
                 # We want to receive proper Timestamp arrow types.
                 "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
@@ -993,6 +914,7 @@ def fetch_results(
             maxRows=max_rows,
             maxBytes=max_bytes,
             orientation=ttypes.TFetchOrientation.FETCH_NEXT,
+            includeResultSetMetadata=True,
         )
 
         resp = self.make_request(self._client.FetchResults, req)
@@ -1002,12 +924,17 @@ def fetch_results(
                     expected_row_start_offset, resp.results.startRowOffset
                 )
             )
-        arrow_results, n_rows = self._create_arrow_table(
-            resp.results, lz4_compressed, arrow_schema_bytes, description
+
+        queue = ResultSetQueueFactory.build_queue(
+            row_set_type=resp.resultSetMetadata.resultFormat,
+            t_row_set=resp.results,
+            arrow_schema_bytes=arrow_schema_bytes,
+            max_download_threads=self.max_download_threads,
+            lz4_compressed=lz4_compressed,
+            description=description,
         )
-        arrow_queue = ArrowQueue(arrow_results, n_rows)
 
-        return arrow_queue, resp.hasMoreRows
+        return queue, resp.hasMoreRows
 
     def close_command(self, op_handle):
         req = ttypes.TCloseOperationReq(operationHandle=op_handle)
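
Downstream of this change, the result set layer drains whichever queue comes back the same way; a sketch of the consuming loop (argument names are inferred from the hunk above, and `remaining_rows` follows the old ArrowQueue interface, which the new queue types presumably keep):

    # Hypothetical consumption loop built on the queue/hasMoreRows contract above.
    offset = 0
    has_more_rows = True
    while has_more_rows:
        queue, has_more_rows = thrift_backend.fetch_results(
            op_handle=op_handle,
            max_rows=10000,
            max_bytes=10 * 1024 * 1024,
            expected_row_start_offset=offset,  # placeholder bookkeeping
            lz4_compressed=lz4_compressed,
            arrow_schema_bytes=schema_bytes,
            description=description,
        )
        table = queue.remaining_rows()  # pyarrow.Table for this batch
        offset += table.num_rows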