Merged
Changes from all commits · 19 commits
3bfbac1  feat: Implement host-level telemetry batching to reduce rate limiting (samikshya-db, Dec 1, 2025)
5e29089  chore: Change all telemetry logging to DEBUG level (samikshya-db, Dec 1, 2025)
26b67c7  chore: Fix remaining telemetry warning log to debug (samikshya-db, Dec 1, 2025)
da13b9e  fix: Update tests to use host_url instead of session_id_hex (samikshya-db, Dec 1, 2025)
ffc7b27  fix: Revert session_id_hex in tests for functions that still use it (samikshya-db, Dec 1, 2025)
b11a461  fix: Update all Error raises and test calls to use host_url (samikshya-db, Dec 1, 2025)
60e50de  fix: Update thrift_backend.py to use host_url instead of session_id_hex (samikshya-db, Dec 1, 2025)
01ea1e1  Fix Black formatting by adjusting fmt directive placement (samikshya-db, Dec 1, 2025)
9026f37  Fix telemetry feature flag tests to set mock session host (samikshya-db, Dec 1, 2025)
86d7828  Add teardown_method to clear telemetry factory state between tests (samikshya-db, Dec 2, 2025)
91b8382  Clear feature flag context cache in teardown to fix test pollution (samikshya-db, Dec 2, 2025)
74821f8  fix: Access actual client from holder in flush worker (samikshya-db, Dec 2, 2025)
0b3dd82  Clear telemetry client cache in e2e test teardown (samikshya-db, Dec 2, 2025)
4624458  Pass session_id parameter to telemetry export methods (samikshya-db, Dec 2, 2025)
8cb66ec  Fix Black formatting in telemetry_client.py (samikshya-db, Dec 2, 2025)
69789ee  Use 'test-host' instead of 'test' for mock host in telemetry tests (samikshya-db, Dec 2, 2025)
b0aa889  Replace test-session-id with test-host in test_client.py (samikshya-db, Dec 2, 2025)
c8cfc23  Fix telemetry client lookup to use test-host in tests (samikshya-db, Dec 2, 2025)
962def5  Make session_id_hex keyword-only parameter in Error.__init__ (samikshya-db, Dec 2, 2025)
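
The final commit, 962def5, makes `session_id_hex` keyword-only on `Error.__init__`, which is what forced the test-call updates in the commits above. A minimal sketch of the resulting shape, assuming a simple `Error` base class; apart from the `session_id_hex` and `host_url` names, everything here is illustrative, not the library's actual definition:

```python
# Hypothetical sketch of the keyword-only change in 962def5; the real
# Error class in databricks-sql-python carries more state than this.
class Error(Exception):
    def __init__(self, message=None, context=None, *, session_id_hex=None, host_url=None):
        super().__init__(message)
        self.message = message
        self.context = context or {}
        self.session_id_hex = session_id_hex  # now accepted only by keyword
        self.host_url = host_url

# Error("boom", {}, "0a1b2c")                    -> TypeError (positional hex rejected)
# Error("boom", {}, host_url="dbc.example.com")  -> OK
```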
72 changes: 34 additions & 38 deletions src/databricks/sql/backend/thrift_backend.py
@@ -163,6 +163,7 @@ def __init__(
else:
raise ValueError("No valid connection settings.")

+ self._host = server_hostname
self._initialize_retry_args(kwargs)
self._use_arrow_native_complex_types = kwargs.get(
"_use_arrow_native_complex_types", True
@@ -279,14 +280,14 @@ def _initialize_retry_args(self, kwargs):
)

@staticmethod
- def _check_response_for_error(response, session_id_hex=None):
+ def _check_response_for_error(response, host_url=None):
if response.status and response.status.statusCode in [
ttypes.TStatusCode.ERROR_STATUS,
ttypes.TStatusCode.INVALID_HANDLE_STATUS,
]:
raise DatabaseError(
response.status.errorMessage,
- session_id_hex=session_id_hex,
+ host_url=host_url,
)

@staticmethod
@@ -340,7 +341,7 @@ def _handle_request_error(self, error_info, attempt, elapsed):
network_request_error = RequestError(
user_friendly_error_message,
full_error_info_context,
- self._session_id_hex,
+ self._host,
error_info.error,
)
logger.info(network_request_error.message_with_context())
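
Note that `RequestError` receives the host as its third positional argument, unlike the keyword-only `host_url=` used in the other raises in this file. A constructor consistent with that call shape might look like the sketch below; only the argument order comes from the diff, the parameter names and the base-class call are assumptions:

```python
# Hypothetical shape matching the positional call above, not the
# library's actual definition of RequestError.
class RequestError(OperationalError):
    def __init__(self, message, context=None, host_url=None, error=None):
        super().__init__(message, context, host_url=host_url)
        self.error = error  # the underlying network exception, if any
```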
@@ -461,13 +462,12 @@ def attempt_request(attempt):
errno.ECONNRESET, # | 104 | 54 |
errno.ETIMEDOUT, # | 110 | 60 |
]
+ # fmt: on

gos_name = TCLIServiceClient.GetOperationStatus.__name__
# retry on timeout. Happens a lot in Azure and it is safe as data has not been sent to server yet
if method.__name__ == gos_name or err.errno == errno.ETIMEDOUT:
retry_delay = bound_retry_delay(attempt, self._retry_delay_default)

- # fmt: on
log_string = f"{gos_name} failed with code {err.errno} and will attempt to retry"
if err.errno in info_errs:
logger.info(log_string)
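
This hunk is the substance of commit 01ea1e1: the `# fmt: on` marker moves up so that it closes the hand-aligned errno table immediately, instead of sitting several statements later where it left part of the retry branch unformatted. Black suppresses formatting for everything between a `# fmt: off`/`# fmt: on` pair, so the corrected placement looks roughly like this (the head of the `info_errs` list is elided in the diff; only its tail is visible above):

```python
import errno

# fmt: off
info_errs = [              # errno values kept as a hand-aligned table
    errno.ECONNRESET,      # | 104 | 54 |
    errno.ETIMEDOUT,       # | 110 | 60 |
]
# fmt: on
# From here on, Black formats the retry logic normally again.
```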
@@ -516,9 +516,7 @@ def attempt_request(attempt):
if not isinstance(response_or_error_info, RequestErrorInfo):
# log nothing here, presume that main request logging covers
response = response_or_error_info
- ThriftDatabricksClient._check_response_for_error(
-     response, self._session_id_hex
- )
+ ThriftDatabricksClient._check_response_for_error(response, self._host)
return response

error_info = response_or_error_info
@@ -533,7 +531,7 @@ def _check_protocol_version(self, t_open_session_resp):
"Error: expected server to use a protocol version >= "
"SPARK_CLI_SERVICE_PROTOCOL_V2, "
"instead got: {}".format(protocol_version),
- session_id_hex=self._session_id_hex,
+ host_url=self._host,
)

def _check_initial_namespace(self, catalog, schema, response):
@@ -547,15 +545,15 @@ def _check_initial_namespace(self, catalog, schema, response):
raise InvalidServerResponseError(
"Setting initial namespace not supported by the DBR version, "
"Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.",
- session_id_hex=self._session_id_hex,
+ host_url=self._host,
)

if catalog:
if not response.canUseMultipleCatalogs:
raise InvalidServerResponseError(
"Unexpected response from server: Trying to set initial catalog to {}, "
+ "but server does not support multiple catalogs.".format(catalog), # type: ignore
- session_id_hex=self._session_id_hex,
+ host_url=self._host,
)

def _check_session_configuration(self, session_configuration):
@@ -570,7 +568,7 @@ def _check_session_configuration(self, session_configuration):
TIMESTAMP_AS_STRING_CONFIG,
session_configuration[TIMESTAMP_AS_STRING_CONFIG],
),
- session_id_hex=self._session_id_hex,
+ host_url=self._host,
)

def open_session(self, session_configuration, catalog, schema) -> SessionId:
@@ -639,7 +637,7 @@ def _check_command_not_in_error_or_closed_state(
and guid_to_hex_id(op_handle.operationId.guid),
"diagnostic-info": get_operations_resp.diagnosticInfo,
},
- session_id_hex=self._session_id_hex,
+ host_url=self._host,
)
else:
raise ServerOperationError(
@@ -649,7 +647,7 @@
and guid_to_hex_id(op_handle.operationId.guid),
"diagnostic-info": None,
},
- session_id_hex=self._session_id_hex,
+ host_url=self._host,
)
elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE:
raise DatabaseError(
@@ -660,7 +658,7 @@
"operation-id": op_handle
and guid_to_hex_id(op_handle.operationId.guid)
},
- session_id_hex=self._session_id_hex,
+ host_url=self._host,
)

def _poll_for_status(self, op_handle):
@@ -683,7 +681,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
else:
raise OperationalError(
"Unsupported TRowSet instance {}".format(t_row_set),
- session_id_hex=self._session_id_hex,
+ host_url=self._host,
)
return convert_decimals_in_arrow_table(arrow_table, description), num_rows

@@ -692,7 +690,7 @@ def _get_metadata_resp(self, op_handle):
return self.make_request(self._client.GetResultSetMetadata, req)

@staticmethod
- def _hive_schema_to_arrow_schema(t_table_schema, session_id_hex=None):
+ def _hive_schema_to_arrow_schema(t_table_schema, host_url=None):
def map_type(t_type_entry):
if t_type_entry.primitiveEntry:
return {
@@ -724,7 +722,7 @@ def map_type(t_type_entry):
# even for complex types
raise OperationalError(
"Thrift protocol error: t_type_entry not a primitiveEntry",
- session_id_hex=session_id_hex,
+ host_url=host_url,
)

def convert_col(t_column_desc):
@@ -735,7 +733,7 @@ def convert_col(t_column_desc):
return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])

@staticmethod
- def _col_to_description(col, field=None, session_id_hex=None):
+ def _col_to_description(col, field=None, host_url=None):
type_entry = col.typeDesc.types[0]

if type_entry.primitiveEntry:
@@ -745,7 +743,7 @@ def _col_to_description(col, field=None, session_id_hex=None):
else:
raise OperationalError(
"Thrift protocol error: t_type_entry not a primitiveEntry",
- session_id_hex=session_id_hex,
+ host_url=host_url,
)

if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE:
@@ -759,7 +757,7 @@
raise OperationalError(
"Decimal type did not provide typeQualifier precision, scale in "
"primitiveEntry {}".format(type_entry.primitiveEntry),
- session_id_hex=session_id_hex,
+ host_url=host_url,
)
else:
precision, scale = None, None
@@ -778,9 +776,7 @@
return col.columnName, cleaned_type, None, None, precision, scale, None

@staticmethod
- def _hive_schema_to_description(
-     t_table_schema, schema_bytes=None, session_id_hex=None
- ):
+ def _hive_schema_to_description(t_table_schema, schema_bytes=None, host_url=None):
field_dict = {}
if pyarrow and schema_bytes:
try:
@@ -795,7 +791,7 @@ def _hive_schema_to_description(
ThriftDatabricksClient._col_to_description(
col,
field_dict.get(col.columnName) if field_dict else None,
- session_id_hex,
+ host_url,
)
for col in t_table_schema.columns
]
@@ -818,7 +814,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
t_result_set_metadata_resp.resultFormat
]
),
- session_id_hex=self._session_id_hex,
+ host_url=self._host,
)
direct_results = resp.directResults
has_been_closed_server_side = direct_results and direct_results.closeOperation
@@ -833,7 +829,7 @@
schema_bytes = (
t_result_set_metadata_resp.arrowSchema
or self._hive_schema_to_arrow_schema(
- t_result_set_metadata_resp.schema, self._session_id_hex
+ t_result_set_metadata_resp.schema, self._host
)
.serialize()
.to_pybytes()
@@ -844,7 +840,7 @@
description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema,
schema_bytes,
- self._session_id_hex,
+ self._host,
)

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
@@ -895,7 +891,7 @@ def get_execution_result(
schema_bytes = (
t_result_set_metadata_resp.arrowSchema
or self._hive_schema_to_arrow_schema(
- t_result_set_metadata_resp.schema, self._session_id_hex
+ t_result_set_metadata_resp.schema, self._host
)
.serialize()
.to_pybytes()
@@ -906,7 +902,7 @@
description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema,
schema_bytes,
- self._session_id_hex,
+ self._host,
)

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
@@ -971,27 +967,27 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
return state

@staticmethod
- def _check_direct_results_for_error(t_spark_direct_results, session_id_hex=None):
+ def _check_direct_results_for_error(t_spark_direct_results, host_url=None):
if t_spark_direct_results:
if t_spark_direct_results.operationStatus:
ThriftDatabricksClient._check_response_for_error(
t_spark_direct_results.operationStatus,
- session_id_hex,
+ host_url,
)
if t_spark_direct_results.resultSetMetadata:
ThriftDatabricksClient._check_response_for_error(
t_spark_direct_results.resultSetMetadata,
- session_id_hex,
+ host_url,
)
if t_spark_direct_results.resultSet:
ThriftDatabricksClient._check_response_for_error(
t_spark_direct_results.resultSet,
- session_id_hex,
+ host_url,
)
if t_spark_direct_results.closeOperation:
ThriftDatabricksClient._check_response_for_error(
t_spark_direct_results.closeOperation,
- session_id_hex,
+ host_url,
)

def execute_command(
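
The hunk above threads the host through all four direct-results checks. Purely as a sketch of what that fan-out does (the PR keeps the explicit four-branch version), an equivalent compact form:

```python
# Sketch only: a loop equivalent of the four explicit checks above.
def _check_direct_results_for_error(t_spark_direct_results, host_url=None):
    if not t_spark_direct_results:
        return
    for part in (
        t_spark_direct_results.operationStatus,
        t_spark_direct_results.resultSetMetadata,
        t_spark_direct_results.resultSet,
        t_spark_direct_results.closeOperation,
    ):
        if part:  # each sub-response carries its own TStatus
            ThriftDatabricksClient._check_response_for_error(part, host_url)
```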
@@ -1260,7 +1256,7 @@ def _handle_execute_response(self, resp, cursor):
raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}")

cursor.active_command_id = command_id
- self._check_direct_results_for_error(resp.directResults, self._session_id_hex)
+ self._check_direct_results_for_error(resp.directResults, self._host)

final_operation_state = self._wait_until_command_done(
resp.operationHandle,
@@ -1275,7 +1271,7 @@ def _handle_execute_response_async(self, resp, cursor):
raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}")

cursor.active_command_id = command_id
- self._check_direct_results_for_error(resp.directResults, self._session_id_hex)
+ self._check_direct_results_for_error(resp.directResults, self._host)

def fetch_results(
self,
@@ -1313,7 +1309,7 @@ def fetch_results(
"fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format(
expected_row_start_offset, resp.results.startRowOffset
),
- session_id_hex=self._session_id_hex,
+ host_url=self._host,
)

queue = ThriftResultSetQueueFactory.build_queue(