diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d2b10e718..edee02bfa 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -163,6 +163,7 @@ def __init__( else: raise ValueError("No valid connection settings.") + self._host = server_hostname self._initialize_retry_args(kwargs) self._use_arrow_native_complex_types = kwargs.get( "_use_arrow_native_complex_types", True @@ -279,14 +280,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response, session_id_hex=None): + def _check_response_for_error(response, host_url=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: raise DatabaseError( response.status.errorMessage, - session_id_hex=session_id_hex, + host_url=host_url, ) @staticmethod @@ -340,7 +341,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): network_request_error = RequestError( user_friendly_error_message, full_error_info_context, - self._session_id_hex, + self._host, error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -461,13 +462,12 @@ def attempt_request(attempt): errno.ECONNRESET, # | 104 | 54 | errno.ETIMEDOUT, # | 110 | 60 | ] + # fmt: on gos_name = TCLIServiceClient.GetOperationStatus.__name__ # retry on timeout. Happens a lot in Azure and it is safe as data has not been sent to server yet if method.__name__ == gos_name or err.errno == errno.ETIMEDOUT: retry_delay = bound_retry_delay(attempt, self._retry_delay_default) - - # fmt: on log_string = f"{gos_name} failed with code {err.errno} and will attempt to retry" if err.errno in info_errs: logger.info(log_string) @@ -516,9 +516,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftDatabricksClient._check_response_for_error( - response, self._session_id_hex - ) + ThriftDatabricksClient._check_response_for_error(response, self._host) return response error_info = response_or_error_info @@ -533,7 +531,7 @@ def _check_protocol_version(self, t_open_session_resp): "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " "instead got: {}".format(protocol_version), - session_id_hex=self._session_id_hex, + host_url=self._host, ) def _check_initial_namespace(self, catalog, schema, response): @@ -547,7 +545,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", - session_id_hex=self._session_id_hex, + host_url=self._host, ) if catalog: @@ -555,7 +553,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " + "but server does not support multiple catalogs.".format(catalog), # type: ignore - session_id_hex=self._session_id_hex, + host_url=self._host, ) def _check_session_configuration(self, session_configuration): @@ -570,7 +568,7 @@ def _check_session_configuration(self, session_configuration): TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], ), - session_id_hex=self._session_id_hex, + host_url=self._host, ) def open_session(self, session_configuration, catalog, schema) -> SessionId: @@ -639,7 +637,7 @@ def _check_command_not_in_error_or_closed_state( and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, - session_id_hex=self._session_id_hex, + host_url=self._host, ) else: raise ServerOperationError( @@ -649,7 +647,7 @@ def _check_command_not_in_error_or_closed_state( and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, - session_id_hex=self._session_id_hex, + host_url=self._host, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -660,7 +658,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and guid_to_hex_id(op_handle.operationId.guid) }, - session_id_hex=self._session_id_hex, + host_url=self._host, ) def _poll_for_status(self, op_handle): @@ -683,7 +681,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti else: raise OperationalError( "Unsupported TRowSet instance {}".format(t_row_set), - session_id_hex=self._session_id_hex, + host_url=self._host, ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows @@ -692,7 +690,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema, session_id_hex=None): + def _hive_schema_to_arrow_schema(t_table_schema, host_url=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -724,7 +722,7 @@ def map_type(t_type_entry): # even for complex types raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - session_id_hex=session_id_hex, + host_url=host_url, ) def convert_col(t_column_desc): @@ -735,7 +733,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col, field=None, session_id_hex=None): + def _col_to_description(col, field=None, host_url=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -745,7 +743,7 @@ def _col_to_description(col, field=None, session_id_hex=None): else: raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - session_id_hex=session_id_hex, + host_url=host_url, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -759,7 +757,7 @@ def _col_to_description(col, field=None, session_id_hex=None): raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " "primitiveEntry {}".format(type_entry.primitiveEntry), - session_id_hex=session_id_hex, + host_url=host_url, ) else: precision, scale = None, None @@ -778,9 +776,7 @@ def _col_to_description(col, field=None, session_id_hex=None): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description( - t_table_schema, schema_bytes=None, session_id_hex=None - ): + def _hive_schema_to_description(t_table_schema, schema_bytes=None, host_url=None): field_dict = {} if pyarrow and schema_bytes: try: @@ -795,7 +791,7 @@ def _hive_schema_to_description( ThriftDatabricksClient._col_to_description( col, field_dict.get(col.columnName) if field_dict else None, - session_id_hex, + host_url, ) for col in t_table_schema.columns ] @@ -818,7 +814,7 @@ def _results_message_to_execute_response(self, resp, operation_state): t_result_set_metadata_resp.resultFormat ] ), - session_id_hex=self._session_id_hex, + host_url=self._host, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -833,7 +829,7 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._session_id_hex + t_result_set_metadata_resp.schema, self._host ) .serialize() .to_pybytes() @@ -844,7 +840,7 @@ def _results_message_to_execute_response(self, resp, operation_state): description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, schema_bytes, - self._session_id_hex, + self._host, ) lz4_compressed = t_result_set_metadata_resp.lz4Compressed @@ -895,7 +891,7 @@ def get_execution_result( schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._session_id_hex + t_result_set_metadata_resp.schema, self._host ) .serialize() .to_pybytes() @@ -906,7 +902,7 @@ def get_execution_result( description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, schema_bytes, - self._session_id_hex, + self._host, ) lz4_compressed = t_result_set_metadata_resp.lz4Compressed @@ -971,27 +967,27 @@ def get_query_state(self, command_id: CommandId) -> CommandState: return state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results, session_id_hex=None): + def _check_direct_results_for_error(t_spark_direct_results, host_url=None): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.operationStatus, - session_id_hex, + host_url, ) if t_spark_direct_results.resultSetMetadata: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSetMetadata, - session_id_hex, + host_url, ) if t_spark_direct_results.resultSet: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSet, - session_id_hex, + host_url, ) if t_spark_direct_results.closeOperation: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.closeOperation, - session_id_hex, + host_url, ) def execute_command( @@ -1260,7 +1256,7 @@ def _handle_execute_response(self, resp, cursor): raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}") cursor.active_command_id = command_id - self._check_direct_results_for_error(resp.directResults, self._session_id_hex) + self._check_direct_results_for_error(resp.directResults, self._host) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1275,7 +1271,7 @@ def _handle_execute_response_async(self, resp, cursor): raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}") cursor.active_command_id = command_id - self._check_direct_results_for_error(resp.directResults, self._session_id_hex) + self._check_direct_results_for_error(resp.directResults, self._host) def fetch_results( self, @@ -1313,7 +1309,7 @@ def fetch_results( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset ), - session_id_hex=self._session_id_hex, + host_url=self._host, ) queue = ThriftResultSetQueueFactory.build_queue( diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a7f802dcd..c873700bc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -341,7 +341,7 @@ def read(self) -> Optional[OAuthToken]: ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex=self.get_session_id_hex() + host_url=self.session.host ) # Determine proxy usage @@ -391,6 +391,7 @@ def read(self) -> Optional[OAuthToken]: self._telemetry_client.export_initial_telemetry_log( driver_connection_params=driver_connection_params, user_agent=self.session.useragent_header, + session_id=self.get_session_id_hex(), ) def _set_use_inline_params_with_warning(self, value: Union[bool, str]): @@ -494,6 +495,7 @@ def cursor( if not self.open: raise InterfaceError( "Cannot create cursor from closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -521,7 +523,7 @@ def _close(self, close_cursors=True) -> None: except Exception as e: logger.error(f"Attempt to close session raised a local exception: {e}") - TelemetryClientFactory.close(self.get_session_id_hex()) + TelemetryClientFactory.close(host_url=self.session.host) # Close HTTP client that was created by this connection if self.http_client: @@ -546,6 +548,7 @@ def autocommit(self) -> bool: if not self.open: raise InterfaceError( "Cannot get autocommit on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -578,6 +581,7 @@ def autocommit(self, value: bool) -> None: if not self.open: raise InterfaceError( "Cannot set autocommit on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -600,6 +604,7 @@ def autocommit(self, value: bool) -> None: "operation": "set_autocommit", "autocommit_value": value, }, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -627,6 +632,7 @@ def _fetch_autocommit_state_from_server(self) -> bool: raise TransactionError( "No result returned from SET AUTOCOMMIT query", context={"operation": "fetch_autocommit"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -647,6 +653,7 @@ def _fetch_autocommit_state_from_server(self) -> bool: raise TransactionError( f"Failed to fetch autocommit state from server: {e.message}", context={**e.context, "operation": "fetch_autocommit"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -680,6 +687,7 @@ def commit(self) -> None: if not self.open: raise InterfaceError( "Cannot commit on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -692,6 +700,7 @@ def commit(self) -> None: raise TransactionError( f"Failed to commit transaction: {e.message}", context={**e.context, "operation": "commit"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -725,12 +734,14 @@ def rollback(self) -> None: if self.ignore_transactions: raise NotSupportedError( "Transactions are not supported on Databricks", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) if not self.open: raise InterfaceError( "Cannot rollback on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -743,6 +754,7 @@ def rollback(self) -> None: raise TransactionError( f"Failed to rollback transaction: {e.message}", context={**e.context, "operation": "rollback"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -767,6 +779,7 @@ def get_transaction_isolation(self) -> str: if not self.open: raise InterfaceError( "Cannot get transaction isolation on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -793,6 +806,7 @@ def set_transaction_isolation(self, level: str) -> None: if not self.open: raise InterfaceError( "Cannot set transaction isolation on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -805,6 +819,7 @@ def set_transaction_isolation(self, level: str) -> None: raise NotSupportedError( f"Setting transaction isolation level '{level}' is not supported. " f"Only {TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ} is supported.", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -857,6 +872,7 @@ def __iter__(self): else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -997,6 +1013,7 @@ def _check_not_closed(self): if not self.open: raise InterfaceError( "Attempting operation on closed cursor", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1041,6 +1058,7 @@ def _handle_staging_operation( else: raise ProgrammingError( "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1067,6 +1085,7 @@ def _handle_staging_operation( if not allow_operation: raise ProgrammingError( "Local file operations are restricted to paths within the configured staging_allowed_local_path", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1095,6 +1114,7 @@ def _handle_staging_operation( raise ProgrammingError( f"Operation {row.operation} is not supported. " + "Supported operations are GET, PUT, and REMOVE", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1110,6 +1130,7 @@ def _handle_staging_put( if local_file is None: raise ProgrammingError( "Cannot perform PUT without specifying a local_file", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1135,6 +1156,7 @@ def _handle_staging_http_response(self, r): error_text = r.data.decode() if r.data else "" raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1166,6 +1188,7 @@ def _handle_staging_put_stream( if not stream: raise ProgrammingError( "No input stream provided for streaming operation", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1187,6 +1210,7 @@ def _handle_staging_get( if local_file is None: raise ProgrammingError( "Cannot perform GET without specifying a local_file", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1201,6 +1225,7 @@ def _handle_staging_get( error_text = r.data.decode() if r.data else "" raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1222,6 +1247,7 @@ def _handle_staging_remove( error_text = r.data.decode() if r.data else "" raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1413,6 +1439,7 @@ def get_async_execution_result(self): else: raise OperationalError( f"get_execution_result failed with Operation status {operation_state}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1541,6 +1568,7 @@ def fetchall(self) -> List[Row]: else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1558,6 +1586,7 @@ def fetchone(self) -> Optional[Row]: else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1583,6 +1612,7 @@ def fetchmany(self, size: int) -> List[Row]: else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1593,6 +1623,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1603,6 +1634,7 @@ def fetchmany_arrow(self, size) -> "pyarrow.Table": else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 24844d573..f4770f3c4 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -12,20 +12,28 @@ class Error(Exception): """ def __init__( - self, message=None, context=None, session_id_hex=None, *args, **kwargs + self, + message=None, + context=None, + host_url=None, + *args, + session_id_hex=None, + **kwargs, ): super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} error_name = self.__class__.__name__ - if session_id_hex: + if host_url: from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex + host_url=host_url + ) + telemetry_client.export_failure_log( + error_name, self.message, session_id=session_id_hex ) - telemetry_client.export_failure_log(error_name, self.message) def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/circuit_breaker_manager.py b/src/databricks/sql/telemetry/circuit_breaker_manager.py index 852f0d916..a5df7371e 100644 --- a/src/databricks/sql/telemetry/circuit_breaker_manager.py +++ b/src/databricks/sql/telemetry/circuit_breaker_manager.py @@ -60,16 +60,16 @@ def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None: old_state_name = old_state.name if old_state else "None" new_state_name = new_state.name if new_state else "None" - logger.info( + logger.debug( LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name ) if new_state_name == CIRCUIT_BREAKER_STATE_OPEN: - logger.warning(LOG_CIRCUIT_BREAKER_OPENED, cb.name) + logger.debug(LOG_CIRCUIT_BREAKER_OPENED, cb.name) elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED: - logger.info(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) + logger.debug(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN: - logger.info(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) + logger.debug(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) class CircuitBreakerManager: diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 36ebee2b8..2445c25c2 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -205,13 +205,14 @@ def wrapper(self, *args, **kwargs): telemetry_client = ( TelemetryClientFactory.get_telemetry_client( - session_id_hex + host_url=connection.session.host ) ) telemetry_client.export_latency_log( latency_ms=duration_ms, sql_execution_event=sql_exec_event, sql_statement_id=telemetry_data.get("statement_id"), + session_id=session_id_hex, ) return wrapper diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index d5f5b575c..77d1a2f9c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -147,13 +147,17 @@ def __new__(cls): cls._instance = super(NoopTelemetryClient, cls).__new__(cls) return cls._instance - def export_initial_telemetry_log(self, driver_connection_params, user_agent): + def export_initial_telemetry_log( + self, driver_connection_params, user_agent, session_id=None + ): pass - def export_failure_log(self, error_name, error_message): + def export_failure_log(self, error_name, error_message, session_id=None): pass - def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + def export_latency_log( + self, latency_ms, sql_execution_event, sql_statement_id, session_id=None + ): pass def close(self): @@ -307,7 +311,7 @@ def _send_with_unified_client(self, url, data, headers, timeout=900): ) return response except Exception as e: - logger.error("Failed to send telemetry with unified client: %s", e) + logger.debug("Failed to send telemetry with unified client: %s", e) raise def _telemetry_request_callback(self, future, sent_count: int): @@ -352,19 +356,22 @@ def _telemetry_request_callback(self, future, sent_count: int): except Exception as e: logger.debug("Telemetry request failed with exception: %s", e) - def _export_telemetry_log(self, **telemetry_event_kwargs): + def _export_telemetry_log(self, session_id=None, **telemetry_event_kwargs): """ Common helper method for exporting telemetry logs. Args: + session_id: Optional session ID for this event. If not provided, uses the client's session ID. **telemetry_event_kwargs: Keyword arguments to pass to TelemetryEvent constructor """ - logger.debug("Exporting telemetry log for connection %s", self._session_id_hex) + # Use provided session_id or fall back to client's session_id + actual_session_id = session_id or self._session_id_hex + logger.debug("Exporting telemetry log for connection %s", actual_session_id) try: # Set common fields for all telemetry events event_kwargs = { - "session_id": self._session_id_hex, + "session_id": actual_session_id, "system_configuration": TelemetryHelper.get_driver_system_configuration(), "driver_connection_params": self._driver_connection_params, } @@ -387,17 +394,22 @@ def _export_telemetry_log(self, **telemetry_event_kwargs): except Exception as e: logger.debug("Failed to export telemetry log: %s", e) - def export_initial_telemetry_log(self, driver_connection_params, user_agent): + def export_initial_telemetry_log( + self, driver_connection_params, user_agent, session_id=None + ): self._driver_connection_params = driver_connection_params self._user_agent = user_agent - self._export_telemetry_log() + self._export_telemetry_log(session_id=session_id) - def export_failure_log(self, error_name, error_message): + def export_failure_log(self, error_name, error_message, session_id=None): error_info = DriverErrorInfo(error_name=error_name, stack_trace=error_message) - self._export_telemetry_log(error_info=error_info) + self._export_telemetry_log(session_id=session_id, error_info=error_info) - def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + def export_latency_log( + self, latency_ms, sql_execution_event, sql_statement_id, session_id=None + ): self._export_telemetry_log( + session_id=session_id, sql_statement_id=sql_statement_id, sql_operation=sql_execution_event, operation_latency_ms=latency_ms, @@ -409,15 +421,39 @@ def close(self): self._flush() +class _TelemetryClientHolder: + """ + Holds a telemetry client with reference counting. + Multiple connections to the same host share one client. + """ + + def __init__(self, client: BaseTelemetryClient): + self.client = client + self.refcount = 1 + + def increment(self): + """Increment reference count when a new connection uses this client""" + self.refcount += 1 + + def decrement(self): + """Decrement reference count when a connection closes""" + self.refcount -= 1 + return self.refcount + + class TelemetryClientFactory: """ Static factory class for creating and managing telemetry clients. It uses a thread pool to handle asynchronous operations and a single flush thread for all clients. + + Clients are shared at the HOST level - multiple connections to the same host + share a single TelemetryClient to enable efficient batching and reduce load + on the telemetry endpoint. """ _clients: Dict[ - str, BaseTelemetryClient - ] = {} # Map of session_id_hex -> BaseTelemetryClient + str, _TelemetryClientHolder + ] = {} # Map of host_url -> TelemetryClientHolder _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.RLock() # Thread safety for factory operations @@ -431,6 +467,22 @@ class TelemetryClientFactory: _flush_interval_seconds = 300 # 5 minutes DEFAULT_BATCH_SIZE = 100 + UNKNOWN_HOST = "unknown-host" + + @staticmethod + def getHostUrlSafely(host_url): + """ + Safely get host URL with fallback to UNKNOWN_HOST. + + Args: + host_url: The host URL to validate + + Returns: + The host_url if valid, otherwise UNKNOWN_HOST + """ + if not host_url or not isinstance(host_url, str) or not host_url.strip(): + return TelemetryClientFactory.UNKNOWN_HOST + return host_url @classmethod def _initialize(cls): @@ -464,8 +516,8 @@ def _flush_worker(cls): with cls._lock: clients_to_flush = list(cls._clients.values()) - for client in clients_to_flush: - client._flush() + for holder in clients_to_flush: + holder.client._flush() @classmethod def _stop_flush_thread(cls): @@ -506,21 +558,38 @@ def initialize_telemetry_client( batch_size, client_context, ): - """Initialize a telemetry client for a specific connection if telemetry is enabled""" + """ + Initialize a telemetry client for a specific connection if telemetry is enabled. + + Clients are shared at the HOST level - multiple connections to the same host + will share a single TelemetryClient with reference counting. + """ try: + # Safely get host_url with fallback to UNKNOWN_HOST + host_url = TelemetryClientFactory.getHostUrlSafely(host_url) with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() - if session_id_hex not in TelemetryClientFactory._clients: + if host_url in TelemetryClientFactory._clients: + # Reuse existing client for this host + holder = TelemetryClientFactory._clients[host_url] + holder.increment() logger.debug( - "Creating new TelemetryClient for connection %s", + "Reusing TelemetryClient for host %s (session %s, refcount=%d)", + host_url, + session_id_hex, + holder.refcount, + ) + else: + # Create new client for this host + logger.debug( + "Creating new TelemetryClient for host %s (session %s)", + host_url, session_id_hex, ) if telemetry_enabled: - TelemetryClientFactory._clients[ - session_id_hex - ] = TelemetryClient( + client = TelemetryClient( telemetry_enabled=telemetry_enabled, session_id_hex=session_id_hex, auth_provider=auth_provider, @@ -529,36 +598,73 @@ def initialize_telemetry_client( batch_size=batch_size, client_context=client_context, ) + TelemetryClientFactory._clients[ + host_url + ] = _TelemetryClientHolder(client) else: TelemetryClientFactory._clients[ - session_id_hex - ] = NoopTelemetryClient() + host_url + ] = _TelemetryClientHolder(NoopTelemetryClient()) except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail - TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient() + TelemetryClientFactory._clients[host_url] = _TelemetryClientHolder( + NoopTelemetryClient() + ) @staticmethod - def get_telemetry_client(session_id_hex): - """Get the telemetry client for a specific connection""" - return TelemetryClientFactory._clients.get( - session_id_hex, NoopTelemetryClient() - ) + def get_telemetry_client(host_url): + """ + Get the shared telemetry client for a specific host. + + Args: + host_url: The host URL to look up the client. If None/empty, uses UNKNOWN_HOST. + + Returns: + The shared TelemetryClient for this host, or NoopTelemetryClient if not found + """ + host_url = TelemetryClientFactory.getHostUrlSafely(host_url) + + if host_url in TelemetryClientFactory._clients: + return TelemetryClientFactory._clients[host_url].client + return NoopTelemetryClient() @staticmethod - def close(session_id_hex): - """Close and remove the telemetry client for a specific connection""" + def close(host_url): + """ + Close the telemetry client for a specific host. + + Decrements the reference count for the host's client. Only actually closes + the client when the reference count reaches zero (all connections to this host closed). + + Args: + host_url: The host URL whose client to close. If None/empty, uses UNKNOWN_HOST. + """ + host_url = TelemetryClientFactory.getHostUrlSafely(host_url) with TelemetryClientFactory._lock: - if ( - telemetry_client := TelemetryClientFactory._clients.pop( - session_id_hex, None - ) - ) is not None: + # Get the holder for this host + holder = TelemetryClientFactory._clients.get(host_url) + if holder is None: + logger.debug("No telemetry client found for host %s", host_url) + return + + # Decrement refcount + remaining_refs = holder.decrement() + logger.debug( + "Decremented refcount for host %s (refcount=%d)", + host_url, + remaining_refs, + ) + + # Only close if no more references + if remaining_refs <= 0: logger.debug( - "Removing telemetry client for connection %s", session_id_hex + "Closing telemetry client for host %s (no more references)", + host_url, ) - telemetry_client.close() + TelemetryClientFactory._clients.pop(host_url, None) + holder.client.close() # Shutdown executor if no more clients if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: @@ -597,7 +703,7 @@ def connection_failure_log( ) telemetry_client = TelemetryClientFactory.get_telemetry_client( - UNAUTH_DUMMY_SESSION_ID + host_url=host_url ) telemetry_client._driver_connection_params = DriverConnectionParameters( http_path=http_path, diff --git a/src/databricks/sql/telemetry/telemetry_push_client.py b/src/databricks/sql/telemetry/telemetry_push_client.py index 461a57738..e77910007 100644 --- a/src/databricks/sql/telemetry/telemetry_push_client.py +++ b/src/databricks/sql/telemetry/telemetry_push_client.py @@ -120,7 +120,7 @@ def _make_request_and_check_status( # Check for rate limiting or service unavailable if response.status in [429, 503]: - logger.warning( + logger.debug( "Telemetry endpoint returned %d for host %s, triggering circuit breaker", response.status, self._host, diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index d2ac4227d..546a2b8b2 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -41,6 +41,7 @@ def telemetry_setup_teardown(self): TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._stop_flush_thread() + TelemetryClientFactory._clients.clear() TelemetryClientFactory._initialized = False def test_concurrent_queries_sends_telemetry(self): diff --git a/tests/unit/test_circuit_breaker_manager.py b/tests/unit/test_circuit_breaker_manager.py index e8ed4e809..1e02556d9 100644 --- a/tests/unit/test_circuit_breaker_manager.py +++ b/tests/unit/test_circuit_breaker_manager.py @@ -157,4 +157,4 @@ def test_circuit_breaker_state_listener_transitions(self, old_state, new_state): with patch("databricks.sql.telemetry.circuit_breaker_manager.logger") as mock_logger: listener.state_change(mock_cb, mock_old_state, mock_new_state) - mock_logger.info.assert_called() + mock_logger.debug.assert_called() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index b515756e8..8f8a97eae 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -714,7 +714,7 @@ def test_autocommit_setter_wraps_database_error(self, mock_session_class): server_error = DatabaseError( "AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION", context={"sql_state": "25000"}, - session_id_hex="test-session-id", + host_url="test-host", ) mock_cursor.execute.side_effect = server_error @@ -737,7 +737,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): mock_cursor = Mock() original_error = DatabaseError( - "Original error", session_id_hex="test-session-id" + "Original error", host_url="test-host" ) mock_cursor.execute.side_effect = original_error @@ -772,7 +772,7 @@ def test_commit_wraps_database_error(self, mock_session_class): server_error = DatabaseError( "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION", context={"sql_state": "25000"}, - session_id_hex="test-session-id", + host_url="test-host", ) mock_cursor.execute.side_effect = server_error @@ -822,7 +822,7 @@ def test_rollback_wraps_database_error(self, mock_session_class): server_error = DatabaseError( "Unexpected rollback error", context={"sql_state": "HY000"}, - session_id_hex="test-session-id", + host_url="test-host", ) mock_cursor.execute.side_effect = server_error diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 96a2f87d8..e9fa16649 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -249,13 +249,13 @@ def test_client_lifecycle_flow(self): client_context=client_context, ) - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client("test-host.com") assert isinstance(client, TelemetryClient) assert client._session_id_hex == session_id_hex # Close client with patch.object(client, "close") as mock_close: - TelemetryClientFactory.close(session_id_hex) + TelemetryClientFactory.close(host_url="test-host.com") mock_close.assert_called_once() # Should get NoopTelemetryClient after close @@ -274,7 +274,7 @@ def test_disabled_telemetry_creates_noop_client(self): client_context=client_context, ) - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client("test-host.com") assert isinstance(client, NoopTelemetryClient) def test_factory_error_handling(self): @@ -297,7 +297,7 @@ def test_factory_error_handling(self): ) # Should fall back to NoopTelemetryClient - client = TelemetryClientFactory.get_telemetry_client(session_id) + client = TelemetryClientFactory.get_telemetry_client("test-host.com") assert isinstance(client, NoopTelemetryClient) def test_factory_shutdown_flow(self): @@ -325,11 +325,11 @@ def test_factory_shutdown_flow(self): assert TelemetryClientFactory._executor is not None # Close first client - factory should stay initialized - TelemetryClientFactory.close(session1) + TelemetryClientFactory.close(host_url="test-host.com") assert TelemetryClientFactory._initialized is True # Close second client - factory should shut down - TelemetryClientFactory.close(session2) + TelemetryClientFactory.close(host_url="test-host.com") assert TelemetryClientFactory._initialized is False assert TelemetryClientFactory._executor is None @@ -367,6 +367,13 @@ def test_connection_failure_sends_correct_telemetry_payload( class TestTelemetryFeatureFlag: """Tests the interaction between the telemetry feature flag and connection parameters.""" + def teardown_method(self): + """Clean up telemetry factory state after each test to prevent test pollution.""" + from databricks.sql.common.feature_flag import FeatureFlagsContextFactory + + TelemetryClientFactory._clients.clear() + FeatureFlagsContextFactory._context_map.clear() + def _mock_ff_response(self, mock_http_request, enabled: bool): """Helper method to mock feature flag response for unified HTTP client.""" mock_response = MagicMock() @@ -391,6 +398,7 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio self._mock_ff_response(mock_http_request, enabled=True) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" + mock_session_instance.host = "test-host" # Set host for telemetry client lookup mock_session_instance.auth_provider = AccessTokenAuthProvider("token") mock_session_instance.is_open = ( False # Connection starts closed for test cleanup @@ -410,7 +418,7 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio assert conn.telemetry_enabled is True mock_http_request.assert_called_once() - client = TelemetryClientFactory.get_telemetry_client("test-session-ff-true") + client = TelemetryClientFactory.get_telemetry_client("test-host") assert isinstance(client, TelemetryClient) @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") @@ -421,6 +429,7 @@ def test_telemetry_disabled_when_flag_is_false( self._mock_ff_response(mock_http_request, enabled=False) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" + mock_session_instance.host = "test-host" # Set host for telemetry client lookup mock_session_instance.auth_provider = AccessTokenAuthProvider("token") mock_session_instance.is_open = ( False # Connection starts closed for test cleanup @@ -440,7 +449,7 @@ def test_telemetry_disabled_when_flag_is_false( assert conn.telemetry_enabled is False mock_http_request.assert_called_once() - client = TelemetryClientFactory.get_telemetry_client("test-session-ff-false") + client = TelemetryClientFactory.get_telemetry_client("test-host") assert isinstance(client, NoopTelemetryClient) @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") @@ -451,6 +460,7 @@ def test_telemetry_disabled_when_flag_request_fails( mock_http_request.side_effect = Exception("Network is down") mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" + mock_session_instance.host = "test-host" # Set host for telemetry client lookup mock_session_instance.auth_provider = AccessTokenAuthProvider("token") mock_session_instance.is_open = ( False # Connection starts closed for test cleanup @@ -470,7 +480,7 @@ def test_telemetry_disabled_when_flag_request_fails( assert conn.telemetry_enabled is False mock_http_request.assert_called_once() - client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail") + client = TelemetryClientFactory.get_telemetry_client("test-host") assert isinstance(client, NoopTelemetryClient) diff --git a/tests/unit/test_telemetry_push_client.py b/tests/unit/test_telemetry_push_client.py index 0e9455e1f..6555f1d02 100644 --- a/tests/unit/test_telemetry_push_client.py +++ b/tests/unit/test_telemetry_push_client.py @@ -114,10 +114,10 @@ def test_rate_limit_error_logging(self): with pytest.raises(TelemetryRateLimitError): self.client.request(HttpMethod.POST, "https://test.com", {}) - mock_logger.warning.assert_called() - warning_args = mock_logger.warning.call_args[0] - assert "429" in str(warning_args) - assert "circuit breaker" in warning_args[0] + mock_logger.debug.assert_called() + debug_args = mock_logger.debug.call_args[0] + assert "429" in str(debug_args) + assert "circuit breaker" in debug_args[0] def test_other_error_logging(self): """Test that other errors are logged during wrapping/unwrapping."""