
Commit 5a34a4a

Authored by Matthew Kim <11141331+mattdeekay@users.noreply.github.com>
Cloud fetch queue and integration (#151)
* Cloud fetch queue and integration
* Enable cloudfetch with direct results
* Typing and style changes
* Client-settable max_download_threads
* Docstrings and comments
* Increase default buffer size bytes to 104857600
* Move max_download_threads to kwargs of ThriftBackend, fix unit tests
* Fix tests: staticmethod make_arrow_table mock not callable
* cancel_futures in shutdown() only available in python >=3.9.0
* Black linting
* Fix typing errors

Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com>
1 parent 01b7a8d commit 5a34a4a

6 files changed: +596 -136 lines changed


src/databricks/sql/client.py

Lines changed: 6 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
+DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600
 DEFAULT_ARRAY_SIZE = 100000
 
 
@@ -153,6 +153,8 @@ def read(self) -> Optional[OAuthToken]:
         # _use_arrow_native_timestamps
         # Databricks runtime will return native Arrow types for timestamps instead of Arrow strings
         # (True by default)
+        # use_cloud_fetch
+        # Enable use of cloud fetch to extract large query results in parallel via cloud storage
 
         if access_token:
             access_token_kv = {"access_token": access_token}
@@ -189,6 +191,7 @@ def read(self) -> Optional[OAuthToken]:
         self._session_handle = self.thrift_backend.open_session(
             session_configuration, catalog, schema
         )
+        self.use_cloud_fetch = kwargs.get("use_cloud_fetch", False)
         self.open = True
         logger.info("Successfully opened session " + str(self.get_session_id_hex()))
         self._cursors = []  # type: List[Cursor]
@@ -497,6 +500,7 @@ def execute(
             max_bytes=self.buffer_size_bytes,
             lz4_compression=self.connection.lz4_compression,
             cursor=self,
+            use_cloud_fetch=self.connection.use_cloud_fetch,
         )
         self.active_result_set = ResultSet(
             self.connection,
@@ -822,6 +826,7 @@ def __iter__(self):
                 break
 
     def _fill_results_buffer(self):
+        # At initialization or if the server does not have cloud fetch result links available
        results, has_more_rows = self.thrift_backend.fetch_results(
             op_handle=self.command_id,
             max_rows=self.arraysize,
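
Note: the first hunk raises DEFAULT_RESULT_BUFFER_SIZE_BYTES from 10485760 bytes (10 MiB) to 104857600 bytes (100 MiB). For context, a minimal sketch of how a caller would opt into cloud fetch, assuming the standard databricks-sql-connector connect()/cursor() API and that extra connect() keyword arguments flow through to Connection and ThriftBackend as the kwargs.get calls in this commit suggest; the hostname, HTTP path, token, and table name below are placeholders:

from databricks import sql

# Placeholder connection details; substitute real workspace values.
connection = sql.connect(
    server_hostname="<workspace-hostname>",
    http_path="<http-path>",
    access_token="<personal-access-token>",
    use_cloud_fetch=True,       # read by Connection via kwargs; defaults to False
    max_download_threads=10,    # read by ThriftBackend via kwargs; defaults to 10
)

cursor = connection.cursor()
try:
    # With cloud fetch enabled, large results can be downloaded in parallel
    # from cloud storage instead of being streamed inline over Thrift.
    cursor.execute("SELECT * FROM my_catalog.my_schema.large_table")
    rows = cursor.fetchall()
finally:
    cursor.close()
    connection.close()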

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 2 additions & 2 deletions
@@ -161,6 +161,6 @@ def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
         return True
 
     def _shutdown_manager(self):
-        # Clear download handlers and shutdown the thread pool to cancel pending futures
+        # Clear download handlers and shutdown the thread pool
         self.download_handlers = []
-        self.thread_pool.shutdown(wait=False, cancel_futures=True)
+        self.thread_pool.shutdown(wait=False)
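
This change drops cancel_futures because Executor.shutdown() only accepts that argument on Python 3.9 and later (see the commit notes above). A hedged alternative, not what this commit does, would be to gate on the interpreter version:

import sys
from concurrent.futures import ThreadPoolExecutor


def shutdown_pool(pool: ThreadPoolExecutor) -> None:
    """Shut down a download thread pool without blocking the caller."""
    if sys.version_info >= (3, 9):
        # cancel_futures was added to Executor.shutdown() in Python 3.9.
        pool.shutdown(wait=False, cancel_futures=True)
    else:
        # Older interpreters: already-queued futures will still run to completion.
        pool.shutdown(wait=False)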

src/databricks/sql/thrift_backend.py

Lines changed: 39 additions & 112 deletions
@@ -5,7 +5,6 @@
 import time
 import uuid
 import threading
-import lz4.frame
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
 from typing import List, Union
 
@@ -26,11 +25,14 @@
 )
 
 from databricks.sql.utils import (
-    ArrowQueue,
     ExecuteResponse,
     _bound,
     RequestErrorInfo,
     NoRetryReason,
+    ResultSetQueueFactory,
+    convert_arrow_based_set_to_arrow_table,
+    convert_decimals_in_arrow_table,
+    convert_column_based_set_to_arrow_table,
 )
 
 logger = logging.getLogger(__name__)
@@ -67,7 +69,6 @@
 class ThriftBackend:
     CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE
     ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE
-    BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
 
     def __init__(
         self,
@@ -115,6 +116,8 @@ def __init__(
         # _socket_timeout
         # The timeout in seconds for socket send, recv and connect operations. Should be a positive float or integer.
         # (defaults to 900)
+        # max_download_threads
+        # Number of threads for handling cloud fetch downloads. Defaults to 10
 
         port = port or 443
         if kwargs.get("_connection_uri"):
@@ -136,6 +139,9 @@ def __init__(
             "_use_arrow_native_timestamps", True
         )
 
+        # Cloud fetch
+        self.max_download_threads = kwargs.get("max_download_threads", 10)
+
         # Configure tls context
         ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
         if kwargs.get("_tls_no_verify") is True:
@@ -558,108 +564,14 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description
             (
                 arrow_table,
                 num_rows,
-            ) = ThriftBackend._convert_column_based_set_to_arrow_table(
-                t_row_set.columns, description
-            )
+            ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
         elif t_row_set.arrowBatches is not None:
-            (
-                arrow_table,
-                num_rows,
-            ) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
+            (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, schema_bytes
             )
         else:
             raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
-        return self._convert_decimals_in_arrow_table(arrow_table, description), num_rows
-
-    @staticmethod
-    def _convert_decimals_in_arrow_table(table, description):
-        for (i, col) in enumerate(table.itercolumns()):
-            if description[i][1] == "decimal":
-                decimal_col = col.to_pandas().apply(
-                    lambda v: v if v is None else Decimal(v)
-                )
-                precision, scale = description[i][4], description[i][5]
-                assert scale is not None
-                assert precision is not None
-                # Spark limits decimal to a maximum scale of 38,
-                # so 128 is guaranteed to be big enough
-                dtype = pyarrow.decimal128(precision, scale)
-                col_data = pyarrow.array(decimal_col, type=dtype)
-                field = table.field(i).with_type(dtype)
-                table = table.set_column(i, field, col_data)
-        return table
-
-    @staticmethod
-    def _convert_arrow_based_set_to_arrow_table(
-        arrow_batches, lz4_compressed, schema_bytes
-    ):
-        ba = bytearray()
-        ba += schema_bytes
-        n_rows = 0
-        if lz4_compressed:
-            for arrow_batch in arrow_batches:
-                n_rows += arrow_batch.rowCount
-                ba += lz4.frame.decompress(arrow_batch.batch)
-        else:
-            for arrow_batch in arrow_batches:
-                n_rows += arrow_batch.rowCount
-                ba += arrow_batch.batch
-        arrow_table = pyarrow.ipc.open_stream(ba).read_all()
-        return arrow_table, n_rows
-
-    @staticmethod
-    def _convert_column_based_set_to_arrow_table(columns, description):
-        arrow_table = pyarrow.Table.from_arrays(
-            [ThriftBackend._convert_column_to_arrow_array(c) for c in columns],
-            # Only use the column names from the schema, the types are determined by the
-            # physical types used in column based set, as they can differ from the
-            # mapping used in _hive_schema_to_arrow_schema.
-            names=[c[0] for c in description],
-        )
-        return arrow_table, arrow_table.num_rows
-
-    @staticmethod
-    def _convert_column_to_arrow_array(t_col):
-        """
-        Return a pyarrow array from the values in a TColumn instance.
-        Note that ColumnBasedSet has no native support for complex types, so they will be converted
-        to strings server-side.
-        """
-        field_name_to_arrow_type = {
-            "boolVal": pyarrow.bool_(),
-            "byteVal": pyarrow.int8(),
-            "i16Val": pyarrow.int16(),
-            "i32Val": pyarrow.int32(),
-            "i64Val": pyarrow.int64(),
-            "doubleVal": pyarrow.float64(),
-            "stringVal": pyarrow.string(),
-            "binaryVal": pyarrow.binary(),
-        }
-        for field in field_name_to_arrow_type.keys():
-            wrapper = getattr(t_col, field)
-            if wrapper:
-                return ThriftBackend._create_arrow_array(
-                    wrapper, field_name_to_arrow_type[field]
-                )
-
-        raise OperationalError("Empty TColumn instance {}".format(t_col))
-
-    @staticmethod
-    def _create_arrow_array(t_col_value_wrapper, arrow_type):
-        result = t_col_value_wrapper.values
-        nulls = t_col_value_wrapper.nulls  # bitfield describing which values are null
-        assert isinstance(nulls, bytes)
-
-        # The number of bits in nulls can be both larger or smaller than the number of
-        # elements in result, so take the minimum of both to iterate over.
-        length = min(len(result), len(nulls) * 8)
-
-        for i in range(length):
-            if nulls[i >> 3] & ThriftBackend.BIT_MASKS[i & 0x7]:
-                result[i] = None
-
-        return pyarrow.array(result, type=arrow_type)
+        return convert_decimals_in_arrow_table(arrow_table, description), num_rows
 
     def _get_metadata_resp(self, op_handle):
         req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle)
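
The conversion helpers removed above are re-imported from databricks.sql.utils (see the import hunk earlier in this file), so the logic survives the refactor. As a worked illustration of the null-bitfield decoding that _create_arrow_array performs, here is a self-contained sketch; the function name is illustrative, not the library's:

import pyarrow

BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]


def decode_nullable_column(values, nulls: bytes, arrow_type) -> pyarrow.Array:
    """Apply a Thrift null bitfield to a list of values and build a pyarrow array.

    Bit i of `nulls` (little-endian within each byte) marks values[i] as NULL.
    The bitfield may be shorter or longer than `values`, so iterate over the minimum.
    """
    values = list(values)
    length = min(len(values), len(nulls) * 8)
    for i in range(length):
        if nulls[i >> 3] & BIT_MASKS[i & 0x7]:
            values[i] = None
    return pyarrow.array(values, type=arrow_type)


# Example: 0b00000101 marks elements 0 and 2 as NULL.
# decode_nullable_column([10, 20, 30], b"\x05", pyarrow.int64()) -> [None, 20, None]
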
@@ -752,6 +664,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
         if t_result_set_metadata_resp.resultFormat not in [
             ttypes.TSparkRowSetType.ARROW_BASED_SET,
             ttypes.TSparkRowSetType.COLUMN_BASED_SET,
+            ttypes.TSparkRowSetType.URL_BASED_SET,
         ]:
             raise OperationalError(
                 "Expected results to be in Arrow or column based format, "
@@ -783,13 +696,14 @@ def _results_message_to_execute_response(self, resp, operation_state):
             assert direct_results.resultSet.results.startRowOffset == 0
             assert direct_results.resultSetMetadata
 
-            arrow_results, n_rows = self._create_arrow_table(
-                direct_results.resultSet.results,
-                lz4_compressed,
-                schema_bytes,
-                description,
+            arrow_queue_opt = ResultSetQueueFactory.build_queue(
+                row_set_type=t_result_set_metadata_resp.resultFormat,
+                t_row_set=direct_results.resultSet.results,
+                arrow_schema_bytes=schema_bytes,
+                max_download_threads=self.max_download_threads,
+                lz4_compressed=lz4_compressed,
+                description=description,
             )
-            arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
         else:
             arrow_queue_opt = None
         return ExecuteResponse(
@@ -843,7 +757,14 @@ def _check_direct_results_for_error(t_spark_direct_results):
         )
 
     def execute_command(
-        self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
+        self,
+        operation,
+        session_handle,
+        max_rows,
+        max_bytes,
+        lz4_compression,
+        cursor,
+        use_cloud_fetch=False,
     ):
         assert session_handle is not None
 
@@ -864,7 +785,7 @@ def execute_command(
             ),
             canReadArrowResult=True,
             canDecompressLZ4Result=lz4_compression,
-            canDownloadResult=False,
+            canDownloadResult=use_cloud_fetch,
             confOverlay={
                 # We want to receive proper Timestamp arrow types.
                 "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
@@ -993,6 +914,7 @@ def fetch_results(
             maxRows=max_rows,
             maxBytes=max_bytes,
             orientation=ttypes.TFetchOrientation.FETCH_NEXT,
+            includeResultSetMetadata=True,
         )
 
         resp = self.make_request(self._client.FetchResults, req)
@@ -1002,12 +924,17 @@ def fetch_results(
                     expected_row_start_offset, resp.results.startRowOffset
                 )
             )
-        arrow_results, n_rows = self._create_arrow_table(
-            resp.results, lz4_compressed, arrow_schema_bytes, description
+
+        queue = ResultSetQueueFactory.build_queue(
+            row_set_type=resp.resultSetMetadata.resultFormat,
+            t_row_set=resp.results,
+            arrow_schema_bytes=arrow_schema_bytes,
+            max_download_threads=self.max_download_threads,
+            lz4_compressed=lz4_compressed,
+            description=description,
         )
-        arrow_queue = ArrowQueue(arrow_results, n_rows)
 
-        return arrow_queue, resp.hasMoreRows
+        return queue, resp.hasMoreRows
 
     def close_command(self, op_handle):
         req = ttypes.TCloseOperationReq(operationHandle=op_handle)
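
Both execute_command (via _results_message_to_execute_response) and fetch_results now build their result queue through ResultSetQueueFactory.build_queue, whose implementation lives in databricks.sql.utils and is not part of this excerpt; fetch_results asks for includeResultSetMetadata=True so that resp.resultSetMetadata.resultFormat is available to drive that choice. A rough, illustrative sketch of the dispatch the two call sites imply (names and return values here are stand-ins, not the library's API):

from enum import Enum, auto


class ResultFormat(Enum):
    """Illustrative stand-in for ttypes.TSparkRowSetType in this sketch."""
    ARROW_BASED_SET = auto()
    COLUMN_BASED_SET = auto()
    URL_BASED_SET = auto()


def choose_result_queue(result_format: ResultFormat, max_download_threads: int = 10) -> str:
    """Describe which kind of result queue the backend would build for a result format."""
    if result_format in (ResultFormat.ARROW_BASED_SET, ResultFormat.COLUMN_BASED_SET):
        # Inline results: batches or columns are converted to an Arrow table and
        # served from an in-memory ArrowQueue, as before this commit.
        return "in-memory ArrowQueue"
    if result_format is ResultFormat.URL_BASED_SET:
        # Cloud fetch: the row set carries result links rather than data, and a
        # download manager pulls them with up to max_download_threads threads.
        return f"cloud fetch queue ({max_download_threads} download threads)"
    raise ValueError(f"Unsupported result format: {result_format}")


print(choose_result_queue(ResultFormat.URL_BASED_SET))  # cloud fetch queue (10 download threads)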
