From 4437a2ac0b46b700a65fb8ad946c97b8f212c52e Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 8 Aug 2025 12:58:53 +0530 Subject: [PATCH 01/35] Refactor codebase to use a unified http client Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/auth.py | 4 +- src/databricks/sql/auth/authenticators.py | 2 + src/databricks/sql/auth/common.py | 61 +++-- src/databricks/sql/auth/oauth.py | 28 ++- src/databricks/sql/backend/sea/queue.py | 4 + src/databricks/sql/backend/sea/result_set.py | 1 + src/databricks/sql/client.py | 38 ++- .../sql/cloudfetch/download_manager.py | 3 + src/databricks/sql/cloudfetch/downloader.py | 79 +++--- src/databricks/sql/common/feature_flag.py | 16 +- src/databricks/sql/common/http.py | 112 --------- .../sql/common/unified_http_client.py | 226 ++++++++++++++++++ src/databricks/sql/result_set.py | 1 + src/databricks/sql/session.py | 39 ++- .../sql/telemetry/telemetry_client.py | 22 +- src/databricks/sql/utils.py | 15 +- 16 files changed, 440 insertions(+), 211 deletions(-) create mode 100644 src/databricks/sql/common/unified_http_client.py diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 3792d6d05..a8d0671b0 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -35,6 +35,7 @@ def get_auth_provider(cfg: ClientContext): cfg.oauth_client_id, cfg.oauth_scopes, cfg.auth_type, + http_client=http_client, ) elif cfg.access_token is not None: return AccessTokenAuthProvider(cfg.access_token) @@ -53,6 +54,7 @@ def get_auth_provider(cfg: ClientContext): cfg.oauth_redirect_port_range, cfg.oauth_client_id, cfg.oauth_scopes, + http_client=http_client, ) else: raise RuntimeError("No valid authentication settings!") @@ -79,7 +81,7 @@ def get_client_id_and_redirect_port(use_azure_auth: bool): ) -def get_python_sql_connector_auth_provider(hostname: str, **kwargs): +def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs): # TODO : unify all the auth mechanisms with the Python SDK auth_type = kwargs.get("auth_type") diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 26c1f3708..80f44812c 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -63,6 +63,7 @@ def __init__( redirect_port_range: List[int], client_id: str, scopes: List[str], + http_client, auth_type: str = "databricks-oauth", ): try: @@ -79,6 +80,7 @@ def __init__( port_range=redirect_port_range, client_id=client_id, idp_endpoint=idp_endpoint, + http_client=http_client, ) self._hostname = hostname self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 5cfbc37c0..262166a52 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -2,7 +2,6 @@ import logging from typing import Optional, List from urllib.parse import urlparse -from databricks.sql.common.http import DatabricksHttpClient, HttpMethod logger = logging.getLogger(__name__) @@ -36,6 +35,21 @@ def __init__( tls_client_cert_file: Optional[str] = None, oauth_persistence=None, credentials_provider=None, + # HTTP client configuration parameters + ssl_options=None, # SSLOptions type + socket_timeout: Optional[float] = None, + retry_stop_after_attempts_count: Optional[int] = None, + retry_delay_min: Optional[float] = None, + retry_delay_max: Optional[float] = None, + retry_stop_after_attempts_duration: Optional[float] = None, + 
retry_delay_default: Optional[float] = None, + retry_dangerous_codes: Optional[List[int]] = None, + http_proxy: Optional[str] = None, + proxy_username: Optional[str] = None, + proxy_password: Optional[str] = None, + pool_connections: Optional[int] = None, + pool_maxsize: Optional[int] = None, + user_agent: Optional[str] = None, ): self.hostname = hostname self.access_token = access_token @@ -51,6 +65,22 @@ def __init__( self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider + + # HTTP client configuration + self.ssl_options = ssl_options + self.socket_timeout = socket_timeout + self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30 + self.retry_delay_min = retry_delay_min or 1.0 + self.retry_delay_max = retry_delay_max or 60.0 + self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0 + self.retry_delay_default = retry_delay_default or 5.0 + self.retry_dangerous_codes = retry_dangerous_codes or [] + self.http_proxy = http_proxy + self.proxy_username = proxy_username + self.proxy_password = proxy_password + self.pool_connections = pool_connections or 10 + self.pool_maxsize = pool_maxsize or 20 + self.user_agent = user_agent def get_effective_azure_login_app_id(hostname) -> str: @@ -69,7 +99,7 @@ def get_effective_azure_login_app_id(hostname) -> str: return AzureAppId.PROD.value[1] -def get_azure_tenant_id_from_host(host: str, http_client=None) -> str: +def get_azure_tenant_id_from_host(host: str, http_client) -> str: """ Load the Azure tenant ID from the Azure Databricks login page. @@ -78,23 +108,22 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str: the Azure login page, and the tenant ID is extracted from the redirect URL. """ - if http_client is None: - http_client = DatabricksHttpClient.get_instance() - login_url = f"{host}/aad/auth" logger.debug("Loading tenant ID from %s", login_url) - with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp: - if resp.status_code // 100 != 3: + + with http_client.request_context('GET', login_url, allow_redirects=False) as resp: + if resp.status // 100 != 3: raise ValueError( - f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}" + f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}" ) - entra_id_endpoint = resp.headers.get("Location") + entra_id_endpoint = dict(resp.headers).get("Location") if entra_id_endpoint is None: raise ValueError(f"No Location header in response from {login_url}") - # The Location header has the following form: https://login.microsoftonline.com//oauth2/authorize?... - # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud). - url = urlparse(entra_id_endpoint) - path_segments = url.path.split("/") - if len(path_segments) < 2: - raise ValueError(f"Invalid path in Location header: {url.path}") - return path_segments[1] + + # The Location header has the following form: https://login.microsoftonline.com//oauth2/authorize?... + # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud). 
+ url = urlparse(entra_id_endpoint) + path_segments = url.path.split("/") + if len(path_segments) < 2: + raise ValueError(f"Invalid path in Location header: {url.path}") + return path_segments[1] diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index aa3184d88..0d67929a3 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -9,10 +9,8 @@ from typing import List, Optional import oauthlib.oauth2 -import requests from oauthlib.oauth2.rfc6749.errors import OAuth2Error -from requests.exceptions import RequestException -from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader +from databricks.sql.common.http import HttpMethod, HttpHeader from databricks.sql.common.http import OAuthResponse from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler from databricks.sql.auth.endpoint import OAuthEndpointCollection @@ -85,11 +83,13 @@ def __init__( port_range: List[int], client_id: str, idp_endpoint: OAuthEndpointCollection, + http_client, ): self.port_range = port_range self.client_id = client_id self.redirect_port = None self.idp_endpoint = idp_endpoint + self.http_client = http_client @staticmethod def __token_urlsafe(nbytes=32): @@ -103,8 +103,12 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: - response = requests.get(url=known_config_url, auth=IgnoreNetrcAuth()) - except RequestException as e: + from databricks.sql.common.unified_http_client import IgnoreNetrcAuth + response = self.http_client.request('GET', url=known_config_url) + # Convert urllib3 response to requests-like response for compatibility + response.status_code = response.status + response.json = lambda: json.loads(response.data.decode()) + except Exception as e: logger.error( f"Unable to fetch OAuth configuration from {known_config_url}.\n" "Verify it is a valid workspace URL and that OAuth is " @@ -122,7 +126,7 @@ def __fetch_well_known_config(self, hostname: str): raise RuntimeError(msg) try: return response.json() - except requests.exceptions.JSONDecodeError as e: + except Exception as e: logger.error( f"Unable to decode OAuth configuration from {known_config_url}.\n" "Verify it is a valid workspace URL and that OAuth is " @@ -209,10 +213,13 @@ def __send_token_request(token_request_url, data): "Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded", } - response = requests.post( - url=token_request_url, data=data, headers=headers, auth=IgnoreNetrcAuth() + # Use unified HTTP client + from databricks.sql.common.unified_http_client import IgnoreNetrcAuth + response = self.http_client.request( + 'POST', url=token_request_url, body=data, headers=headers ) - return response.json() + # Convert urllib3 response to dict for compatibility + return json.loads(response.data.decode()) def __send_refresh_token_request(self, hostname, refresh_token): oauth_config = self.__fetch_well_known_config(hostname) @@ -320,6 +327,7 @@ def __init__( token_url, client_id, client_secret, + http_client, extra_params: dict = {}, ): self.client_id = client_id @@ -327,7 +335,7 @@ def __init__( self.token_url = token_url self.extra_params = extra_params self.token: Optional[Token] = None - self._http_client = DatabricksHttpClient.get_instance() + self._http_client = http_client def get_token(self) -> Token: if self.token is None or self.token.is_expired(): diff --git a/src/databricks/sql/backend/sea/queue.py 
b/src/databricks/sql/backend/sea/queue.py index 130f0c5bf..4a319c442 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -50,6 +50,7 @@ def build_queue( max_download_threads: int, sea_client: SeaDatabricksClient, lz4_compressed: bool, + http_client, ) -> ResultSetQueue: """ Factory method to build a result set queue for SEA backend. @@ -94,6 +95,7 @@ def build_queue( total_chunk_count=manifest.total_chunk_count, lz4_compressed=lz4_compressed, description=description, + http_client=http_client, ) raise ProgrammingError("Invalid result format") @@ -309,6 +311,7 @@ def __init__( sea_client: SeaDatabricksClient, statement_id: str, total_chunk_count: int, + http_client, lz4_compressed: bool = False, description: List[Tuple] = [], ): @@ -337,6 +340,7 @@ def __init__( # TODO: fix these arguments when telemetry is implemented in SEA session_id_hex=None, chunk_id=0, + http_client=http_client, ) logger.debug( diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index afa70bc89..17838ed81 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -64,6 +64,7 @@ def __init__( max_download_threads=sea_client.max_download_threads, sea_client=sea_client, lz4_compressed=execute_response.lz4_compressed, + http_client=connection.session.http_client, ) # Call parent constructor with common attributes diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 73ee0e03c..295be29dc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -6,7 +6,6 @@ import pyarrow except ImportError: pyarrow = None -import requests import json import os import decimal @@ -292,6 +291,7 @@ def read(self) -> Optional[OAuthToken]: auth_provider=self.session.auth_provider, host_url=self.session.host, batch_size=self.telemetry_batch_size, + http_client=self.session.http_client, ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( @@ -744,16 +744,20 @@ def _handle_staging_put( ) with open(local_file, "rb") as fh: - r = requests.put(url=presigned_url, data=fh, headers=headers) + r = self.connection.session.http_client.request('PUT', presigned_url, body=fh.read(), headers=headers) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, 'data'): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" # fmt: off - # Design borrowed from: https://stackoverflow.com/a/2342589/5093960 - - OK = requests.codes.ok # 200 - CREATED = requests.codes.created # 201 - ACCEPTED = requests.codes.accepted # 202 - NO_CONTENT = requests.codes.no_content # 204 - + # HTTP status codes + OK = 200 + CREATED = 201 + ACCEPTED = 202 + NO_CONTENT = 204 # fmt: on if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: @@ -783,7 +787,13 @@ def _handle_staging_get( session_id_hex=self.connection.get_session_id_hex(), ) - r = requests.get(url=presigned_url, headers=headers) + r = self.connection.session.http_client.request('GET', presigned_url, headers=headers) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, 'data'): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" # response.ok verifies the status code is not between 400-600. 
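The staging PUT, GET, and DELETE handlers in client.py above each re-attach the same requests-style attributes (status_code, content, ok, text) onto the urllib3 response. A shared helper could centralize that shim; the sketch below is illustrative only and is not part of this patch (the name _adapt_urllib3_response is hypothetical).

import urllib3

def _adapt_urllib3_response(r: urllib3.HTTPResponse) -> urllib3.HTTPResponse:
    # Hypothetical helper mirroring the inline shim in the staging handlers:
    # derive requests-style fields from urllib3's status/data.
    r.status_code = r.status
    r.content = r.data
    r.ok = r.status < 400
    r.text = r.data.decode() if r.data else ""
    return r

# Illustrative usage:
# r = _adapt_urllib3_response(http_client.request("PUT", presigned_url, body=payload, headers=headers))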
# Any 2xx or 3xx will evaluate r.ok == True @@ -802,7 +812,13 @@ def _handle_staging_remove( ): """Make an HTTP DELETE request to the presigned_url""" - r = requests.delete(url=presigned_url, headers=headers) + r = self.connection.session.http_client.request('DELETE', presigned_url, headers=headers) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, 'data'): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" if not r.ok: raise OperationalError( diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 32b698bed..27265720f 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -25,6 +25,7 @@ def __init__( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, ): self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = [] self.chunk_id = chunk_id @@ -47,6 +48,7 @@ def __init__( self._ssl_options = ssl_options self.session_id_hex = session_id_hex self.statement_id = statement_id + self._http_client = http_client def get_next_downloaded_file( self, next_row_offset: int @@ -109,6 +111,7 @@ def _schedule_downloads(self): chunk_id=chunk_id, session_id_hex=self.session_id_hex, statement_id=self.statement_id, + http_client=self._http_client, ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 1331fa203..ea375fbbb 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -2,10 +2,9 @@ from dataclasses import dataclass from typing import Optional -from requests.adapters import Retry import lz4.frame import time -from databricks.sql.common.http import DatabricksHttpClient, HttpMethod +from databricks.sql.common.http import HttpMethod from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.exc import Error from databricks.sql.types import SSLOptions @@ -16,16 +15,6 @@ # TODO: Ideally, we should use a common retry policy (DatabricksRetryPolicy) for all the requests across the library. # But DatabricksRetryPolicy should be updated first - currently it can work only with Thrift requests -retryPolicy = Retry( - total=5, # max retry attempts - backoff_factor=1, # min delay, 1 second - # TODO: `backoff_max` is supported since `urllib3` v2.0.0, but we allow >= 1.26. 
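For reference, the module-level retryPolicy being removed from downloader.py expressed roughly the standalone urllib3 configuration sketched below; with this patch, downloads instead go through the shared UnifiedHttpClient, whose pool manager is configured with DatabricksRetryPolicy in _setup_pool_manager. This snippet is only a reconstruction for comparison, not code added by the patch.

import urllib3
from urllib3.util.retry import Retry

# Approximate equivalent of the removed policy: up to 5 attempts, exponential
# backoff starting at 1 second, retrying status codes 0-100, 429, 500, and
# 502-999 (501 Not Implemented is excluded).
reference_retry = Retry(
    total=5,
    backoff_factor=1,
    status_forcelist=[*range(0, 101), 429, 500, *range(502, 1000)],
)
pool = urllib3.PoolManager(retries=reference_retry)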
- # The default value (120 seconds) used since v1.26 looks reasonable enough - # backoff_max=60, # max delay, 60 seconds - # retry all status codes below 100, 429 (Too Many Requests), and all codes above 500, - # excluding 501 Not implemented - status_forcelist=[*range(0, 101), 429, 500, *range(502, 1000)], -) @dataclass @@ -73,11 +62,12 @@ def __init__( chunk_id: int, session_id_hex: Optional[str], statement_id: str, + http_client, ): self.settings = settings self.link = link self._ssl_options = ssl_options - self._http_client = DatabricksHttpClient.get_instance() + self._http_client = http_client self.chunk_id = chunk_id self.session_id_hex = session_id_hex self.statement_id = statement_id @@ -104,50 +94,47 @@ def run(self) -> DownloadedFile: start_time = time.time() - with self._http_client.execute( - method=HttpMethod.GET, + with self._http_client.request_context( + method='GET', url=self.link.fileLink, timeout=self.settings.download_timeout, - verify=self._ssl_options.tls_verify, headers=self.link.httpHeaders - # TODO: Pass cert from `self._ssl_options` ) as response: - response.raise_for_status() - - # Save (and decompress if needed) the downloaded file - compressed_data = response.content - - # Log download metrics - download_duration = time.time() - start_time - self._log_download_metrics( - self.link.fileLink, len(compressed_data), download_duration - ) - - decompressed_data = ( - ResultSetDownloadHandler._decompress_data(compressed_data) - if self.settings.is_lz4_compressed - else compressed_data - ) + if response.status >= 400: + raise Exception(f"HTTP {response.status}: {response.data.decode()}") + compressed_data = response.data + + # Log download metrics + download_duration = time.time() - start_time + self._log_download_metrics( + self.link.fileLink, len(compressed_data), download_duration + ) - # The size of the downloaded file should match the size specified from TSparkArrowResultLink - if len(decompressed_data) != self.link.bytesNum: - logger.debug( - "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( - len(decompressed_data), self.link.bytesNum - ) - ) + decompressed_data = ( + ResultSetDownloadHandler._decompress_data(compressed_data) + if self.settings.is_lz4_compressed + else compressed_data + ) + # The size of the downloaded file should match the size specified from TSparkArrowResultLink + if len(decompressed_data) != self.link.bytesNum: logger.debug( - "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( - self.link.startRowOffset, self.link.rowCount + "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( + len(decompressed_data), self.link.bytesNum ) ) - return DownloadedFile( - decompressed_data, - self.link.startRowOffset, - self.link.rowCount, + logger.debug( + "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( + self.link.startRowOffset, self.link.rowCount ) + ) + + return DownloadedFile( + decompressed_data, + self.link.startRowOffset, + self.link.rowCount, + ) def _log_download_metrics( self, url: str, bytes_downloaded: int, duration_seconds: float diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 53add9253..8e7029805 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -1,6 +1,6 @@ +import json import threading import time -import requests from dataclasses import dataclass, field from 
concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional, List, Any, TYPE_CHECKING @@ -49,7 +49,7 @@ class FeatureFlagsContext: in the background, returning stale data until the refresh completes. """ - def __init__(self, connection: "Connection", executor: ThreadPoolExecutor): + def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_client): from databricks.sql import __version__ self._connection = connection @@ -65,6 +65,9 @@ def __init__(self, connection: "Connection", executor: ThreadPoolExecutor): self._feature_flag_endpoint = ( f"https://{self._connection.session.host}{endpoint_suffix}" ) + + # Use the provided HTTP client + self._http_client = http_client def _is_refresh_needed(self) -> bool: """Checks if the cache is due for a proactive background refresh.""" @@ -105,9 +108,12 @@ def _refresh_flags(self): self._connection.session.auth_provider.add_headers(headers) headers["User-Agent"] = self._connection.session.useragent_header - response = requests.get( - self._feature_flag_endpoint, headers=headers, timeout=30 + response = self._http_client.request( + 'GET', self._feature_flag_endpoint, headers=headers, timeout=30 ) + # Add compatibility attributes for urllib3 response + response.status_code = response.status + response.json = lambda: json.loads(response.data.decode()) if response.status_code == 200: ff_response = FeatureFlagsResponse.from_dict(response.json()) @@ -159,7 +165,7 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: # Use the unique session ID as the key key = connection.get_session_id_hex() if key not in cls._context_map: - cls._context_map[key] = FeatureFlagsContext(connection, cls._executor) + cls._context_map[key] = FeatureFlagsContext(connection, cls._executor, connection.session.http_client) return cls._context_map[key] @classmethod diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index 0cd2919c0..cf76a5fba 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -38,115 +38,3 @@ class OAuthResponse: resource: str = "" access_token: str = "" refresh_token: str = "" - - -# Singleton class for common Http Client -class DatabricksHttpClient: - ## TODO: Unify all the http clients in the PySQL Connector - - _instance = None - _lock = threading.Lock() - - def __init__(self): - self.session = requests.Session() - adapter = HTTPAdapter( - pool_connections=5, - pool_maxsize=10, - max_retries=Retry(total=10, backoff_factor=0.1), - ) - self.session.mount("https://", adapter) - self.session.mount("http://", adapter) - - @classmethod - def get_instance(cls) -> "DatabricksHttpClient": - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = DatabricksHttpClient() - return cls._instance - - @contextmanager - def execute( - self, method: HttpMethod, url: str, **kwargs - ) -> Generator[requests.Response, None, None]: - logger.info("Executing HTTP request: %s with url: %s", method.value, url) - response = None - try: - response = self.session.request(method.value, url, **kwargs) - yield response - except Exception as e: - logger.error("Error executing HTTP request in DatabricksHttpClient: %s", e) - raise e - finally: - if response is not None: - response.close() - - def close(self): - self.session.close() - - -class TelemetryHTTPAdapter(HTTPAdapter): - """ - Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request. 
- This ensures the retry timer is started and the command type is set correctly, - allowing the policy to manage its state for the duration of the request retries. - """ - - def send(self, request, **kwargs): - self.max_retries.command_type = CommandType.OTHER - self.max_retries.start_retry_timer() - return super().send(request, **kwargs) - - -class TelemetryHttpClient: # TODO: Unify all the http clients in the PySQL Connector - """Singleton HTTP client for sending telemetry data.""" - - _instance: Optional["TelemetryHttpClient"] = None - _lock = threading.Lock() - - TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 - TELEMETRY_RETRY_DELAY_MIN = 1.0 - TELEMETRY_RETRY_DELAY_MAX = 10.0 - TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 - - def __init__(self): - """Initializes the session and mounts the custom retry adapter.""" - retry_policy = DatabricksRetryPolicy( - delay_min=self.TELEMETRY_RETRY_DELAY_MIN, - delay_max=self.TELEMETRY_RETRY_DELAY_MAX, - stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, - stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, - delay_default=1.0, - force_dangerous_codes=[], - ) - adapter = TelemetryHTTPAdapter(max_retries=retry_policy) - self.session = requests.Session() - self.session.mount("https://", adapter) - self.session.mount("http://", adapter) - - @classmethod - def get_instance(cls) -> "TelemetryHttpClient": - """Get the singleton instance of the TelemetryHttpClient.""" - if cls._instance is None: - with cls._lock: - if cls._instance is None: - logger.debug("Initializing singleton TelemetryHttpClient") - cls._instance = TelemetryHttpClient() - return cls._instance - - def post(self, url: str, **kwargs) -> requests.Response: - """ - Executes a POST request using the configured session. - - This is a blocking call intended to be run in a background thread. - """ - logger.debug("Executing telemetry POST request to: %s", url) - return self.session.post(url, **kwargs) - - def close(self): - """Closes the underlying requests.Session.""" - logger.debug("Closing TelemetryHttpClient session.") - self.session.close() - # Clear the instance to allow for re-initialization if needed - with TelemetryHttpClient._lock: - TelemetryHttpClient._instance = None diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py new file mode 100644 index 000000000..8c3be2bfd --- /dev/null +++ b/src/databricks/sql/common/unified_http_client.py @@ -0,0 +1,226 @@ +import logging +import ssl +import urllib.parse +from contextlib import contextmanager +from typing import Dict, Any, Optional, Generator, Union + +import urllib3 +from urllib3 import PoolManager, ProxyManager +from urllib3.util import make_headers +from urllib3.exceptions import MaxRetryError + +from databricks.sql.auth.retry import DatabricksRetryPolicy +from databricks.sql.exc import RequestError + +logger = logging.getLogger(__name__) + + +class UnifiedHttpClient: + """ + Unified HTTP client for all Databricks SQL connector HTTP operations. + + This client uses urllib3 for robust HTTP communication with retry policies, + connection pooling, SSL support, and proxy support. It replaces the various + singleton HTTP clients and direct requests usage throughout the codebase. + """ + + def __init__(self, client_context): + """ + Initialize the unified HTTP client. 
+ + Args: + client_context: ClientContext instance containing HTTP configuration + """ + self.config = client_context + self._pool_manager = None + self._setup_pool_manager() + + def _setup_pool_manager(self): + """Set up the urllib3 PoolManager with configuration from ClientContext.""" + + # SSL context setup + ssl_context = None + if self.config.ssl_options: + ssl_context = ssl.create_default_context() + + # Configure SSL verification + if not self.config.ssl_options.tls_verify: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif not self.config.ssl_options.tls_verify_hostname: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_REQUIRED + + # Load custom CA file if specified + if self.config.ssl_options.tls_trusted_ca_file: + ssl_context.load_verify_locations(self.config.ssl_options.tls_trusted_ca_file) + + # Load client certificate if specified + if (self.config.ssl_options.tls_client_cert_file and + self.config.ssl_options.tls_client_cert_key_file): + ssl_context.load_cert_chain( + self.config.ssl_options.tls_client_cert_file, + self.config.ssl_options.tls_client_cert_key_file, + self.config.ssl_options.tls_client_cert_key_password + ) + + # Create retry policy + retry_policy = DatabricksRetryPolicy( + delay_min=self.config.retry_delay_min, + delay_max=self.config.retry_delay_max, + stop_after_attempts_count=self.config.retry_stop_after_attempts_count, + stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration, + delay_default=self.config.retry_delay_default, + force_dangerous_codes=self.config.retry_dangerous_codes, + ) + + # Common pool manager kwargs + pool_kwargs = { + 'num_pools': self.config.pool_connections, + 'maxsize': self.config.pool_maxsize, + 'retries': retry_policy, + 'timeout': urllib3.Timeout( + connect=self.config.socket_timeout, + read=self.config.socket_timeout + ) if self.config.socket_timeout else None, + 'ssl_context': ssl_context, + } + + # Create proxy or regular pool manager + if self.config.http_proxy: + proxy_headers = None + if self.config.proxy_username and self.config.proxy_password: + proxy_headers = make_headers( + proxy_basic_auth=f"{self.config.proxy_username}:{self.config.proxy_password}" + ) + + self._pool_manager = ProxyManager( + self.config.http_proxy, + proxy_headers=proxy_headers, + **pool_kwargs + ) + else: + self._pool_manager = PoolManager(**pool_kwargs) + + def _prepare_headers(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """Prepare headers for the request, including User-Agent.""" + request_headers = {} + + if self.config.user_agent: + request_headers['User-Agent'] = self.config.user_agent + + if headers: + request_headers.update(headers) + + return request_headers + + @contextmanager + def request_context( + self, + method: str, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> Generator[urllib3.HTTPResponse, None, None]: + """ + Context manager for making HTTP requests with proper resource cleanup. 
+ + Args: + method: HTTP method (GET, POST, PUT, DELETE) + url: URL to request + headers: Optional headers dict + **kwargs: Additional arguments passed to urllib3 request + + Yields: + urllib3.HTTPResponse: The HTTP response object + """ + logger.debug("Making %s request to %s", method, url) + + request_headers = self._prepare_headers(headers) + response = None + + try: + response = self._pool_manager.request( + method=method, + url=url, + headers=request_headers, + **kwargs + ) + yield response + except MaxRetryError as e: + logger.error("HTTP request failed after retries: %s", e) + raise RequestError(f"HTTP request failed: {e}") + except Exception as e: + logger.error("HTTP request error: %s", e) + raise RequestError(f"HTTP request error: {e}") + finally: + if response: + response.close() + + def request(self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> urllib3.HTTPResponse: + """ + Make an HTTP request. + + Args: + method: HTTP method (GET, POST, PUT, DELETE, etc.) + url: URL to request + headers: Optional headers dict + **kwargs: Additional arguments passed to urllib3 request + + Returns: + urllib3.HTTPResponse: The HTTP response object with data pre-loaded + """ + with self.request_context(method, url, headers=headers, **kwargs) as response: + # Read the response data to ensure it's available after context exit + response._body = response.data + return response + + def upload_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> urllib3.HTTPResponse: + """ + Upload a file using PUT method. + + Args: + url: URL to upload to + file_path: Path to the file to upload + headers: Optional headers + + Returns: + urllib3.HTTPResponse: The response from the server + """ + with open(file_path, 'rb') as file_obj: + return self.request('PUT', url, body=file_obj.read(), headers=headers) + + def download_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> None: + """ + Download a file using GET method. + + Args: + url: URL to download from + file_path: Path where to save the downloaded file + headers: Optional headers + """ + response = self.request('GET', url, headers=headers) + with open(file_path, 'wb') as file_obj: + file_obj.write(response.data) + + def close(self): + """Close the underlying connection pools.""" + if self._pool_manager: + self._pool_manager.clear() + self._pool_manager = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +# Compatibility class to maintain requests-like interface for OAuth +class IgnoreNetrcAuth: + """ + Compatibility class for OAuth code that expects requests.auth.AuthBase interface. + This is a no-op auth handler since OAuth handles auth differently. 
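Taken together, the new module is driven entirely by ClientContext. A minimal usage sketch follows, assuming only the constructor arguments shown in this patch (hostname, ssl_options, user_agent) and placeholder workspace URLs; note that urllib3 expects redirect=False to suppress redirects, whereas requests' allow_redirects kwarg has no effect on a urllib3 pool manager.

from databricks.sql.auth.common import ClientContext
from databricks.sql.common.unified_http_client import UnifiedHttpClient
from databricks.sql.types import SSLOptions

# Placeholder hostname and settings; not defaults mandated by this patch.
ctx = ClientContext(
    hostname="example.cloud.databricks.com",
    ssl_options=SSLOptions(),
    user_agent="PyDatabricksSqlConnector/4.x (example)",
)
client = UnifiedHttpClient(ctx)

# One-shot call: request() pre-reads the body, so .data stays usable after the
# connection is released back to the pool.
resp = client.request("GET", "https://example.cloud.databricks.com/api/2.0/example")
print(resp.status, len(resp.data))

# Scoped call with explicit cleanup; redirect=False is urllib3's kwarg for
# leaving 3xx responses unfollowed.
with client.request_context(
    "GET", "https://example.cloud.databricks.com/aad/auth", redirect=False
) as resp:
    print(resp.status, dict(resp.headers).get("Location"))

client.close()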
+ """ + def __call__(self, request): + return request \ No newline at end of file diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 9feb6e924..77673db9a 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -244,6 +244,7 @@ def __init__( session_id_hex=connection.get_session_id_hex(), statement_id=execute_response.command_id.to_hex_guid(), chunk_id=self.num_chunks, + http_client=connection.session.http_client, ) if t_row_set.resultLinks: self.num_chunks += len(t_row_set.resultLinks) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index f1bc35bee..d0c94b6ba 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -4,6 +4,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.auth.common import ClientContext from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME @@ -11,6 +12,7 @@ from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.common.unified_http_client import UnifiedHttpClient logger = logging.getLogger(__name__) @@ -42,10 +44,6 @@ def __init__( self.schema = schema self.http_path = http_path - self.auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) - user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: user_agent_entry = kwargs.get("_user_agent_entry") @@ -77,6 +75,15 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) + # Create HTTP client configuration and unified HTTP client + self.client_context = self._build_client_context(server_hostname, **kwargs) + self.http_client = UnifiedHttpClient(self.client_context) + + # Create auth provider with HTTP client context + self.auth_provider = get_python_sql_connector_auth_provider( + server_hostname, http_client=self.http_client, **kwargs + ) + self.backend = self._create_backend( server_hostname, http_path, @@ -88,6 +95,26 @@ def __init__( self.protocol_version = None + def _build_client_context(self, server_hostname: str, **kwargs) -> ClientContext: + """Build ClientContext with HTTP configuration from kwargs.""" + return ClientContext( + hostname=server_hostname, + ssl_options=self.ssl_options, + socket_timeout=kwargs.get("_socket_timeout"), + retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), + retry_delay_min=kwargs.get("_retry_delay_min"), + retry_delay_max=kwargs.get("_retry_delay_max"), + retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"), + retry_delay_default=kwargs.get("_retry_delay_default"), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), + http_proxy=kwargs.get("http_proxy"), + proxy_username=kwargs.get("proxy_username"), + proxy_password=kwargs.get("proxy_password"), + pool_connections=kwargs.get("pool_connections"), + pool_maxsize=kwargs.get("pool_maxsize"), + user_agent=self.useragent_header, + ) + def _create_backend( self, server_hostname: str, @@ -185,3 +212,7 @@ def close(self) -> None: logger.error("Attempt to close session raised a local exception: %s", e) self.is_open = False + + # Close HTTP 
client if it exists + if hasattr(self, 'http_client') and self.http_client: + self.http_client.close() diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 55f06c8df..93cef3600 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -168,6 +168,7 @@ def __init__( host_url, executor, batch_size, + http_client, ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled @@ -180,7 +181,7 @@ def __init__( self._driver_connection_params = None self._host_url = host_url self._executor = executor - self._http_client = TelemetryHttpClient.get_instance() + self._http_client = http_client def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" @@ -228,19 +229,34 @@ def _send_telemetry(self, events): try: logger.debug("Submitting telemetry request to thread pool") + + # Use unified HTTP client future = self._executor.submit( - self._http_client.post, + self._send_with_unified_client, url, data=request.to_json(), headers=headers, timeout=900, ) + future.add_done_callback( lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count) ) except Exception as e: logger.debug("Failed to submit telemetry request: %s", e) + def _send_with_unified_client(self, url, data, headers): + """Helper method to send telemetry using the unified HTTP client.""" + try: + response = self._http_client.request('POST', url, body=data, headers=headers, timeout=900) + # Convert urllib3 response to requests-like response for compatibility + response.status_code = response.status + response.json = lambda: json.loads(response.data.decode()) if response.data else {} + return response + except Exception as e: + logger.error("Failed to send telemetry with unified client: %s", e) + raise + def _telemetry_request_callback(self, future, sent_count: int): """Callback function to handle telemetry request completion""" try: @@ -431,6 +447,7 @@ def initialize_telemetry_client( auth_provider, host_url, batch_size, + http_client, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: @@ -453,6 +470,7 @@ def initialize_telemetry_client( host_url=host_url, executor=TelemetryClientFactory._executor, batch_size=batch_size, + http_client=http_client, ) else: TelemetryClientFactory._clients[ diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index c1d89ca5c..ff48e0e91 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -64,6 +64,7 @@ def build_queue( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, lz4_compressed: bool = True, description: List[Tuple] = [], ) -> ResultSetQueue: @@ -104,15 +105,16 @@ def build_queue( elif row_set_type == TSparkRowSetType.URL_BASED_SET: return ThriftCloudFetchQueue( schema_bytes=arrow_schema_bytes, - start_row_offset=t_row_set.startRowOffset, - result_links=t_row_set.resultLinks, - lz4_compressed=lz4_compressed, - description=description, max_download_threads=max_download_threads, ssl_options=ssl_options, session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, + start_row_offset=t_row_set.startRowOffset, + result_links=t_row_set.resultLinks, + lz4_compressed=lz4_compressed, + description=description, ) else: raise AssertionError("Row set type is not valid") @@ -224,6 +226,7 @@ def __init__( session_id_hex: 
Optional[str], statement_id: str, chunk_id: int, + http_client, schema_bytes: Optional[bytes] = None, lz4_compressed: bool = True, description: List[Tuple] = [], @@ -247,6 +250,7 @@ def __init__( self.session_id_hex = session_id_hex self.statement_id = statement_id self.chunk_id = chunk_id + self._http_client = http_client # Table state self.table = None @@ -261,6 +265,7 @@ def __init__( session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, ) def next_n_rows(self, num_rows: int) -> "pyarrow.Table": @@ -370,6 +375,7 @@ def __init__( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, @@ -396,6 +402,7 @@ def __init__( session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, ) self.start_row_index = start_row_offset From 30c04a66c7abd88f455b57d78dd2ae230ff4b0cc Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 8 Aug 2025 19:04:13 +0530 Subject: [PATCH 02/35] Some more fixes and aligned tests Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/auth.py | 4 +- src/databricks/sql/auth/oauth.py | 18 -- src/databricks/sql/backend/thrift_backend.py | 10 +- src/databricks/sql/client.py | 48 +++++ src/databricks/sql/session.py | 27 +-- .../sql/telemetry/telemetry_client.py | 6 +- tests/unit/test_auth.py | 58 ++++-- tests/unit/test_cloud_fetch_queue.py | 183 ++++-------------- tests/unit/test_download_manager.py | 2 + tests/unit/test_downloader.py | 162 +++++++++------- tests/unit/test_telemetry.py | 73 +++++-- tests/unit/test_telemetry_retry.py | 88 ++++----- 12 files changed, 336 insertions(+), 343 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index a8d0671b0..cc421e69e 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -10,7 +10,7 @@ from databricks.sql.auth.common import AuthType, ClientContext -def get_auth_provider(cfg: ClientContext): +def get_auth_provider(cfg: ClientContext, http_client): if cfg.credentials_provider: return ExternalAuthProvider(cfg.credentials_provider) elif cfg.auth_type == AuthType.AZURE_SP_M2M.value: @@ -113,4 +113,4 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs) oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), ) - return get_auth_provider(cfg) + return get_auth_provider(cfg, http_client) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 0d67929a3..270287953 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -61,22 +61,6 @@ def refresh(self) -> Token: pass -class IgnoreNetrcAuth(requests.auth.AuthBase): - """This auth method is a no-op. - - We use it to force requestslib to not use .netrc to write auth headers - when making .post() requests to the oauth token endpoints, since these - don't require authentication. - - In cases where .netrc is outdated or corrupt, these requests will fail. 
- - See issue #121 - """ - - def __call__(self, r): - return r - - class OAuthManager: def __init__( self, @@ -103,7 +87,6 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: - from databricks.sql.common.unified_http_client import IgnoreNetrcAuth response = self.http_client.request('GET', url=known_config_url) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status @@ -214,7 +197,6 @@ def __send_token_request(token_request_url, data): "Content-Type": "application/x-www-form-urlencoded", } # Use unified HTTP client - from databricks.sql.common.unified_http_client import IgnoreNetrcAuth response = self.http_client.request( 'POST', url=token_request_url, body=data, headers=headers ) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index b404b1669..801632a41 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -105,6 +105,7 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, + http_client=None, **kwargs, ): # Internal arguments in **kwargs: @@ -145,10 +146,8 @@ def __init__( # Number of threads for handling cloud fetch downloads. Defaults to 10 logger.debug( - "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)", - server_hostname, - port, - http_path, + "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)" + % (server_hostname, port, http_path) ) port = port or 443 @@ -177,8 +176,8 @@ def __init__( self._max_download_threads = kwargs.get("max_download_threads", 10) self._ssl_options = ssl_options - self._auth_provider = auth_provider + self._http_client = http_client # Connector version 3 retry approach self.enable_v3_retries = kwargs.get("_enable_v3_retries", True) @@ -1292,6 +1291,7 @@ def fetch_results( session_id_hex=self._session_id_hex, statement_id=command_id.to_hex_guid(), chunk_id=chunk_id, + http_client=self._http_client, ) return ( diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 295be29dc..50f252dbc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -50,6 +50,9 @@ from databricks.sql.session import Session from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId +from databricks.sql.auth.common import ClientContext +from databricks.sql.common.unified_http_client import UnifiedHttpClient + from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, TSparkParameter, @@ -251,10 +254,14 @@ def read(self) -> Optional[OAuthToken]: "telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE ) + client_context = self._build_client_context(server_hostname, **kwargs) + http_client = UnifiedHttpClient(client_context) + try: self.session = Session( server_hostname, http_path, + http_client, http_headers, session_configuration, catalog, @@ -270,6 +277,7 @@ def read(self) -> Optional[OAuthToken]: host_url=server_hostname, http_path=http_path, port=kwargs.get("_port", 443), + http_client=http_client, user_agent=self.session.useragent_header if hasattr(self, "session") else None, @@ -342,6 +350,46 @@ def _set_use_inline_params_with_warning(self, value: Union[bool, str]): return value + def _build_client_context(self, server_hostname: str, **kwargs): + """Build ClientContext for HTTP client configuration.""" + from databricks.sql.auth.common import ClientContext + from 
databricks.sql.types import SSLOptions + + # Extract SSL options + ssl_options = SSLOptions( + tls_verify=not kwargs.get("_tls_no_verify", False), + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + # Build user agent + user_agent_entry = kwargs.get("user_agent_entry", "") + if user_agent_entry: + user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})" + else: + user_agent = f"PyDatabricksSqlConnector/{__version__}" + + return ClientContext( + hostname=server_hostname, + ssl_options=ssl_options, + socket_timeout=kwargs.get("_socket_timeout"), + retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count", 30), + retry_delay_min=kwargs.get("_retry_delay_min", 1.0), + retry_delay_max=kwargs.get("_retry_delay_max", 60.0), + retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration", 900.0), + retry_delay_default=kwargs.get("_retry_delay_default", 1.0), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []), + http_proxy=kwargs.get("_http_proxy"), + proxy_username=kwargs.get("_proxy_username"), + proxy_password=kwargs.get("_proxy_password"), + pool_connections=kwargs.get("_pool_connections", 1), + pool_maxsize=kwargs.get("_pool_maxsize", 1), + user_agent=user_agent, + ) + # The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently. def __enter__(self) -> "Connection": return self diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index d0c94b6ba..c9b4f939a 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -22,6 +22,7 @@ def __init__( self, server_hostname: str, http_path: str, + http_client: UnifiedHttpClient, http_headers: Optional[List[Tuple[str, str]]] = None, session_configuration: Optional[Dict[str, Any]] = None, catalog: Optional[str] = None, @@ -75,9 +76,8 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - # Create HTTP client configuration and unified HTTP client - self.client_context = self._build_client_context(server_hostname, **kwargs) - self.http_client = UnifiedHttpClient(self.client_context) + # Use the provided HTTP client (created in Connection) + self.http_client = http_client # Create auth provider with HTTP client context self.auth_provider = get_python_sql_connector_auth_provider( @@ -95,26 +95,6 @@ def __init__( self.protocol_version = None - def _build_client_context(self, server_hostname: str, **kwargs) -> ClientContext: - """Build ClientContext with HTTP configuration from kwargs.""" - return ClientContext( - hostname=server_hostname, - ssl_options=self.ssl_options, - socket_timeout=kwargs.get("_socket_timeout"), - retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), - retry_delay_min=kwargs.get("_retry_delay_min"), - retry_delay_max=kwargs.get("_retry_delay_max"), - retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"), - retry_delay_default=kwargs.get("_retry_delay_default"), - retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), - http_proxy=kwargs.get("http_proxy"), - proxy_username=kwargs.get("proxy_username"), - proxy_password=kwargs.get("proxy_password"), - 
pool_connections=kwargs.get("pool_connections"), - pool_maxsize=kwargs.get("pool_maxsize"), - user_agent=self.useragent_header, - ) - def _create_backend( self, server_hostname: str, @@ -142,6 +122,7 @@ def _create_backend( "http_headers": all_headers, "auth_provider": auth_provider, "ssl_options": self.ssl_options, + "http_client": self.http_client, "_use_arrow_native_complex_types": _use_arrow_native_complex_types, **kwargs, } diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 93cef3600..13c15486d 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -3,7 +3,6 @@ import logging from concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional, TYPE_CHECKING -from databricks.sql.common.http import TelemetryHttpClient from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, @@ -38,6 +37,8 @@ from databricks.sql.telemetry.utils import BaseTelemetryClient from databricks.sql.common.feature_flag import FeatureFlagsContextFactory +from src.databricks.sql.common.unified_http_client import UnifiedHttpClient + if TYPE_CHECKING: from databricks.sql.client import Connection @@ -511,7 +512,6 @@ def close(session_id_hex): try: TelemetryClientFactory._stop_flush_thread() TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryHttpClient.close() except Exception as e: logger.debug("Failed to shutdown thread pool executor: %s", e) TelemetryClientFactory._executor = None @@ -524,6 +524,7 @@ def connection_failure_log( host_url: str, http_path: str, port: int, + http_client: UnifiedHttpClient, user_agent: Optional[str] = None, ): """Send error telemetry when connection creation fails, without requiring a session""" @@ -536,6 +537,7 @@ def connection_failure_log( auth_provider=None, host_url=host_url, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=http_client, ) telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 8bf914708..2e210a9e0 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -24,8 +24,8 @@ AzureOAuthEndpointCollection, ) from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory -from databricks.sql.common.http import DatabricksHttpClient from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache +import json class Auth(unittest.TestCase): @@ -98,12 +98,14 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh): ) in params: with self.subTest(cloud_type.value): oauth_persistence = OAuthPersistenceCache() + mock_http_client = MagicMock() auth_provider = DatabricksOAuthProvider( hostname=host, oauth_persistence=oauth_persistence, redirect_port_range=[8020], client_id=client_id, scopes=scopes, + http_client=mock_http_client, auth_type=AuthType.AZURE_OAUTH.value if use_azure_auth else AuthType.DATABRICKS_OAUTH.value, @@ -142,7 +144,8 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: def test_get_python_sql_connector_auth_provider_access_token(self): hostname = "moderakh-test.cloud.databricks.com" kwargs = {"access_token": "dpi123"} - auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") headers = 
{} @@ -159,7 +162,8 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: hostname = "moderakh-test.cloud.databricks.com" kwargs = {"credentials_provider": MyProvider()} - auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider") headers = {} @@ -174,7 +178,8 @@ def test_get_python_sql_connector_auth_provider_noop(self): "_tls_client_cert_file": tls_client_cert_file, "_use_cert_as_auth": use_cert_as_auth, } - auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") def test_get_python_sql_connector_basic_auth(self): @@ -182,8 +187,9 @@ def test_get_python_sql_connector_basic_auth(self): "username": "username", "password": "password", } + mock_http_client = MagicMock() with self.assertRaises(ValueError) as e: - get_python_sql_connector_auth_provider("foo.cloud.databricks.com", **kwargs) + get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs) self.assertIn( "Username/password authentication is no longer supported", str(e.exception) ) @@ -191,7 +197,8 @@ def test_get_python_sql_connector_basic_auth(self): @patch.object(DatabricksOAuthProvider, "_initial_get_token") def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" - auth_provider = get_python_sql_connector_auth_provider(hostname) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider") self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) @@ -223,10 +230,12 @@ def status_response(response_status_code): @pytest.fixture def token_source(self): + mock_http_client = MagicMock() return ClientCredentialsTokenSource( token_url="https://token_url.com", client_id="client_id", client_secret="client_secret", + http_client=mock_http_client, ) def test_no_token_refresh__when_token_is_not_expired( @@ -249,10 +258,21 @@ def test_no_token_refresh__when_token_is_not_expired( assert mock_get_token.call_count == 1 def test_get_token_success(self, token_source, http_response): - databricks_http_client = DatabricksHttpClient.get_instance() - with patch.object( - databricks_http_client.session, "request", return_value=http_response(200) - ) as mock_request: + mock_http_client = MagicMock() + + with patch.object(token_source, "_http_client", mock_http_client): + # Create a mock response with the expected format + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "abc123", + "token_type": "Bearer", + "refresh_token": None, + } + # Mock the context manager (execute returns context manager) + mock_http_client.execute.return_value.__enter__.return_value = mock_response + mock_http_client.execute.return_value.__exit__.return_value = None + token = token_source.get_token() # Assert @@ -262,11 +282,19 @@ def test_get_token_success(self, token_source, http_response): assert token.refresh_token is None def test_get_token_failure(self, token_source, http_response): - databricks_http_client = DatabricksHttpClient.get_instance() - with 
patch.object( - databricks_http_client.session, "request", return_value=http_response(400) - ) as mock_request: - with pytest.raises(Exception) as e: + mock_http_client = MagicMock() + + with patch.object(token_source, "_http_client", mock_http_client): + # Create a mock response with error + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.text = "Bad Request" + mock_response.json.return_value = {"error": "invalid_client"} + # Mock the context manager (execute returns context manager) + mock_http_client.execute.return_value.__enter__.return_value = mock_response + mock_http_client.execute.return_value.__exit__.return_value = None + + with pytest.raises(Exception): token_source.get_token() assert "Failed to get token: 400" in str(e.value) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index faa8e2f99..0c3fc7103 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -13,6 +13,31 @@ @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") class CloudFetchQueueSuite(unittest.TestCase): + def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs): + """Helper method to create ThriftCloudFetchQueue with sensible defaults""" + # Set up defaults for commonly used parameters + defaults = { + 'max_download_threads': 10, + 'ssl_options': SSLOptions(), + 'session_id_hex': Mock(), + 'statement_id': Mock(), + 'chunk_id': 0, + 'start_row_offset': 0, + 'lz4_compressed': True, + } + + # Override defaults with any provided kwargs + defaults.update(kwargs) + + mock_http_client = MagicMock() + return utils.ThriftCloudFetchQueue( + schema_bytes=schema_bytes or MagicMock(), + result_links=result_links or [], + description=description or [], + http_client=mock_http_client, + **defaults + ) + def create_result_link( self, file_link: str = "fileLink", @@ -58,15 +83,7 @@ def get_schema_bytes(): def test_initializer_adds_links(self, mock_create_next_table): schema_bytes = MagicMock() result_links = self.create_result_links(10) - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=result_links, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, result_links=result_links) assert len(queue.download_manager._pending_links) == 10 assert len(queue.download_manager._download_tasks) == 0 @@ -74,16 +91,7 @@ def test_initializer_adds_links(self, mock_create_next_table): def test_initializer_no_links_to_add(self): schema_bytes = MagicMock() - result_links = [] - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=result_links, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, result_links=[]) assert len(queue.download_manager._pending_links) == 0 assert len(queue.download_manager._download_tasks) == 0 @@ -94,15 +102,7 @@ def test_initializer_no_links_to_add(self): return_value=None, ) def test_create_next_table_no_download(self, mock_get_next_downloaded_file): - queue = utils.ThriftCloudFetchQueue( - MagicMock(), - result_links=[], - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=MagicMock(), result_links=[]) assert queue._create_next_table() is None 
mock_get_next_downloaded_file.assert_called_with(0) @@ -117,16 +117,7 @@ def test_initializer_create_next_table_success( ): mock_create_arrow_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) expected_result = self.make_arrow_table() mock_get_next_downloaded_file.assert_called_with(0) @@ -145,16 +136,7 @@ def test_initializer_create_next_table_success( def test_next_n_rows_0_rows(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -167,16 +149,7 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): def test_next_n_rows_partial_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -190,16 +163,7 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): def test_next_n_rows_more_than_one_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -218,16 +182,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -242,17 +197,9 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): ) def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = 
self.get_schema_bytes() - description = MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + # Create description that matches the 4-column schema + description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None result = queue.next_n_rows(100) @@ -263,16 +210,7 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 queue.table_row_index = 4 @@ -285,16 +223,7 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table) def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 queue.table_row_index = 2 @@ -307,16 +236,7 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -335,16 +255,7 @@ def test_remaining_rows_multiple_tables_fully_returned( None, ] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 queue.table_row_index = 3 @@ -365,17 +276,9 @@ def test_remaining_rows_multiple_tables_fully_returned( ) def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() - description = MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - 
description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + # Create description that matches the 4-column schema + description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None result = queue.remaining_rows() diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index 6eb17a05a..1c77226a9 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -14,6 +14,7 @@ class DownloadManagerTests(unittest.TestCase): def create_download_manager( self, links, max_download_threads=10, lz4_compressed=True ): + mock_http_client = MagicMock() return download_manager.ResultFileDownloadManager( links, max_download_threads, @@ -22,6 +23,7 @@ def create_download_manager( session_id_hex=Mock(), statement_id=Mock(), chunk_id=0, + http_client=mock_http_client, ) def create_result_link( diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index c514980ee..00b1b849a 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -1,21 +1,19 @@ -from contextlib import contextmanager import unittest -from unittest.mock import Mock, patch, MagicMock - +from unittest.mock import patch, MagicMock, Mock import requests import databricks.sql.cloudfetch.downloader as downloader -from databricks.sql.common.http import DatabricksHttpClient from databricks.sql.exc import Error from databricks.sql.types import SSLOptions -def create_response(**kwargs) -> requests.Response: - result = requests.Response() +def create_mock_response(**kwargs): + """Create a mock response object for testing""" + mock_response = MagicMock() for k, v in kwargs.items(): - setattr(result, k, v) - result.close = Mock() - return result + setattr(mock_response, k, v) + mock_response.close = Mock() + return mock_response class DownloaderTests(unittest.TestCase): @@ -23,6 +21,17 @@ class DownloaderTests(unittest.TestCase): Unit tests for checking downloader logic. 
""" + def _setup_mock_http_response(self, mock_http_client, status=200, data=b""): + """Helper method to setup mock HTTP client with response context manager.""" + mock_response = MagicMock() + mock_response.status = status + mock_response.data = data + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_response + mock_context_manager.__exit__.return_value = None + mock_http_client.request_context.return_value = mock_context_manager + return mock_response + def _setup_time_mock_for_download(self, mock_time, end_time): """Helper to setup time mock that handles logging system calls.""" call_count = [0] @@ -38,6 +47,7 @@ def time_side_effect(): @patch("time.time", return_value=1000) def test_run_link_expired(self, mock_time): + mock_http_client = MagicMock() settings = Mock() result_link = Mock() # Already expired @@ -49,6 +59,7 @@ def test_run_link_expired(self, mock_time): chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), + http_client=mock_http_client, ) with self.assertRaises(Error) as context: @@ -59,6 +70,7 @@ def test_run_link_expired(self, mock_time): @patch("time.time", return_value=1000) def test_run_link_past_expiry_buffer(self, mock_time): + mock_http_client = MagicMock() settings = Mock(link_expiry_buffer_secs=5) result_link = Mock() # Within the expiry buffer time @@ -70,6 +82,7 @@ def test_run_link_past_expiry_buffer(self, mock_time): chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), + http_client=mock_http_client, ) with self.assertRaises(Error) as context: @@ -80,46 +93,45 @@ def test_run_link_past_expiry_buffer(self, mock_time): @patch("time.time", return_value=1000) def test_run_get_response_not_ok(self, mock_time): - http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() settings = Mock(link_expiry_buffer_secs=0, download_timeout=0) settings.download_timeout = 0 settings.use_proxy = False result_link = Mock(expiryTime=1001) - with patch.object( - http_client, - "execute", - return_value=create_response(status_code=404, _content=b"1234"), - ): - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) - with self.assertRaises(requests.exceptions.HTTPError) as context: - d.run() - self.assertTrue("404" in str(context.exception)) + # Setup mock HTTP response using helper method + self._setup_mock_http_response(mock_http_client, status=404, data=b"1234") + + d = downloader.ResultSetDownloadHandler( + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), + http_client=mock_http_client, + ) + with self.assertRaises(Exception) as context: + d.run() + self.assertTrue("404" in str(context.exception)) @patch("time.time") def test_run_uncompressed_successful(self, mock_time): self._setup_time_mock_for_download(mock_time, 1000.5) - http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() file_bytes = b"1234567890" * 10 settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = False settings.min_cloudfetch_download_speed = 1.0 - result_link = Mock(bytesNum=100, expiryTime=1001) - result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=abc123" + result_link = Mock(expiryTime=1001, bytesNum=len(file_bytes)) + result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789" + + # Setup mock HTTP response using helper method + 
self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes) - with patch.object( - http_client, - "execute", - return_value=create_response(status_code=200, _content=file_bytes), - ): + # Patch the log metrics method to avoid division by zero + with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): d = downloader.ResultSetDownloadHandler( settings, result_link, @@ -127,29 +139,32 @@ def test_run_uncompressed_successful(self, mock_time): chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), + http_client=mock_http_client, ) file = d.run() - - assert file.file_bytes == b"1234567890" * 10 + self.assertEqual(file.file_bytes, file_bytes) + self.assertEqual(file.start_row_offset, result_link.startRowOffset) + self.assertEqual(file.row_count, result_link.rowCount) @patch("time.time") def test_run_compressed_successful(self, mock_time): self._setup_time_mock_for_download(mock_time, 1000.2) - http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() file_bytes = b"1234567890" * 10 compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = True settings.min_cloudfetch_download_speed = 1.0 - result_link = Mock(bytesNum=100, expiryTime=1001) + result_link = Mock(expiryTime=1001, bytesNum=len(file_bytes)) result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789" - with patch.object( - http_client, - "execute", - return_value=create_response(status_code=200, _content=compressed_bytes), - ): + + # Setup mock HTTP response using helper method + self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes) + + # Mock the decompression method and log metrics to avoid issues + with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \ + patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): d = downloader.ResultSetDownloadHandler( settings, result_link, @@ -157,48 +172,53 @@ def test_run_compressed_successful(self, mock_time): chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), + http_client=mock_http_client, ) file = d.run() - - assert file.file_bytes == b"1234567890" * 10 + self.assertEqual(file.file_bytes, file_bytes) + self.assertEqual(file.start_row_offset, result_link.startRowOffset) + self.assertEqual(file.row_count, result_link.rowCount) @patch("time.time", return_value=1000) def test_download_connection_error(self, mock_time): - - http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() settings = Mock( link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True ) result_link = Mock(bytesNum=100, expiryTime=1001) - with patch.object(http_client, "execute", side_effect=ConnectionError("foo")): - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) - with self.assertRaises(ConnectionError): - d.run() + mock_http_client.request_context.side_effect = ConnectionError("foo") + + d = downloader.ResultSetDownloadHandler( + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), + http_client=mock_http_client, + ) + with self.assertRaises(ConnectionError): + d.run() @patch("time.time", return_value=1000) def test_download_timeout(self, mock_time): - 
http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() settings = Mock( link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True ) result_link = Mock(bytesNum=100, expiryTime=1001) - with patch.object(http_client, "execute", side_effect=TimeoutError("foo")): - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) - with self.assertRaises(TimeoutError): - d.run() + mock_http_client.request_context.side_effect = TimeoutError("foo") + + d = downloader.ResultSetDownloadHandler( + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), + http_client=mock_http_client, + ) + with self.assertRaises(TimeoutError): + d.run() diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index d85e41719..989b2351c 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -1,6 +1,7 @@ import uuid import pytest from unittest.mock import patch, MagicMock +import json from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, @@ -23,6 +24,7 @@ def mock_telemetry_client(): session_id = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") executor = MagicMock() + mock_http_client = MagicMock() return TelemetryClient( telemetry_enabled=True, @@ -31,6 +33,7 @@ def mock_telemetry_client(): host_url="test-host.com", executor=executor, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) @@ -72,10 +75,15 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): mock_send.assert_called_once() assert len(client._events_batch) == 0 # Batch cleared after flush - @patch("requests.post") - def test_network_request_flow(self, mock_post, mock_telemetry_client): + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") + def test_network_request_flow(self, mock_http_request, mock_telemetry_client): """Test the complete network request flow with authentication.""" - mock_post.return_value.status_code = 200 + # Mock response for unified HTTP client + mock_response = MagicMock() + mock_response.status = 200 + mock_response.status_code = 200 + mock_http_request.return_value = mock_response + client = mock_telemetry_client # Create mock events @@ -91,7 +99,7 @@ def test_network_request_flow(self, mock_post, mock_telemetry_client): args, kwargs = client._executor.submit.call_args # Verify correct function and URL - assert args[0] == client._http_client.post + assert args[0] == client._send_with_unified_client assert args[1] == "https://test-host.com/telemetry-ext" assert kwargs["headers"]["Authorization"] == "Bearer test-token" @@ -208,6 +216,7 @@ def test_client_lifecycle_flow(self): """Test complete client lifecycle: initialize -> use -> close.""" session_id_hex = "test-session" auth_provider = AccessTokenAuthProvider("token") + mock_http_client = MagicMock() # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( @@ -216,6 +225,7 @@ def test_client_lifecycle_flow(self): auth_provider=auth_provider, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -234,6 +244,7 @@ def test_client_lifecycle_flow(self): def test_disabled_telemetry_flow(self): """Test that disabled telemetry uses NoopTelemetryClient.""" session_id_hex = 
"test-session" + mock_http_client = MagicMock() TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, @@ -241,6 +252,7 @@ def test_disabled_telemetry_flow(self): auth_provider=None, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -249,6 +261,7 @@ def test_disabled_telemetry_flow(self): def test_factory_error_handling(self): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" + mock_http_client = MagicMock() # Simulate initialization error with patch( @@ -261,6 +274,7 @@ def test_factory_error_handling(self): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) # Should fall back to NoopTelemetryClient @@ -271,6 +285,7 @@ def test_factory_shutdown_flow(self): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" + mock_http_client = MagicMock() # Initialize multiple clients for session in [session1, session2]: @@ -280,6 +295,7 @@ def test_factory_shutdown_flow(self): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) # Factory should be initialized @@ -325,10 +341,11 @@ def test_connection_failure_sends_correct_telemetry_payload( class TestTelemetryFeatureFlag: """Tests the interaction between the telemetry feature flag and connection parameters.""" - def _mock_ff_response(self, mock_requests_get, enabled: bool): - """Helper to configure the mock response for the feature flag endpoint.""" + def _mock_ff_response(self, mock_http_request, enabled: bool): + """Helper method to mock feature flag response for unified HTTP client.""" mock_response = MagicMock() - mock_response.status_code = 200 + mock_response.status = 200 + mock_response.status_code = 200 # Compatibility attribute payload = { "flags": [ { @@ -339,15 +356,21 @@ def _mock_ff_response(self, mock_requests_get, enabled: bool): "ttl_seconds": 3600, } mock_response.json.return_value = payload - mock_requests_get.return_value = mock_response + mock_response.data = json.dumps(payload).encode() + mock_http_request.return_value = mock_response - @patch("databricks.sql.common.feature_flag.requests.get") - def test_telemetry_enabled_when_flag_is_true(self, mock_requests_get, MockSession): + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") + def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSession): """Telemetry should be ON when enable_telemetry=True and server flag is 'true'.""" - self._mock_ff_response(mock_requests_get, enabled=True) + self._mock_ff_response(mock_http_request, enabled=True) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + + # Set up mock HTTP client on the session + mock_http_client = MagicMock() + mock_http_client.request = mock_http_request + mock_session_instance.http_client = mock_http_client conn = sql.client.Connection( server_hostname="test", @@ -357,19 +380,24 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_requests_get, MockSessio ) assert conn.telemetry_enabled is True - mock_requests_get.assert_called_once() + mock_http_request.assert_called_once() client = 
TelemetryClientFactory.get_telemetry_client("test-session-ff-true") assert isinstance(client, TelemetryClient) - @patch("databricks.sql.common.feature_flag.requests.get") + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_telemetry_disabled_when_flag_is_false( - self, mock_requests_get, MockSession + self, mock_http_request, MockSession ): """Telemetry should be OFF when enable_telemetry=True but server flag is 'false'.""" - self._mock_ff_response(mock_requests_get, enabled=False) + self._mock_ff_response(mock_http_request, enabled=False) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + + # Set up mock HTTP client on the session + mock_http_client = MagicMock() + mock_http_client.request = mock_http_request + mock_session_instance.http_client = mock_http_client conn = sql.client.Connection( server_hostname="test", @@ -379,19 +407,24 @@ def test_telemetry_disabled_when_flag_is_false( ) assert conn.telemetry_enabled is False - mock_requests_get.assert_called_once() + mock_http_request.assert_called_once() client = TelemetryClientFactory.get_telemetry_client("test-session-ff-false") assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.common.feature_flag.requests.get") + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_telemetry_disabled_when_flag_request_fails( - self, mock_requests_get, MockSession + self, mock_http_request, MockSession ): """Telemetry should default to OFF if the feature flag network request fails.""" - mock_requests_get.side_effect = Exception("Network is down") + mock_http_request.side_effect = Exception("Network is down") mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + + # Set up mock HTTP client on the session + mock_http_client = MagicMock() + mock_http_client.request = mock_http_request + mock_session_instance.http_client = mock_http_client conn = sql.client.Connection( server_hostname="test", @@ -401,6 +434,6 @@ def test_telemetry_disabled_when_flag_request_fails( ) assert conn.telemetry_enabled is False - mock_requests_get.assert_called_once() + mock_http_request.assert_called_once() client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail") assert isinstance(client, NoopTelemetryClient) diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py index d5287deb9..f0bdddd60 100644 --- a/tests/unit/test_telemetry_retry.py +++ b/tests/unit/test_telemetry_retry.py @@ -6,27 +6,23 @@ from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.auth.retry import DatabricksRetryPolicy -PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn" +PATCH_TARGET = "databricks.sql.common.unified_http_client.UnifiedHttpClient.request" -def create_mock_conn(responses): - """Creates a mock connection object whose getresponse() method yields a series of responses.""" - mock_conn = MagicMock() - mock_http_responses = [] +def create_mock_response(responses): + """Creates mock urllib3 HTTPResponse objects for the given response specifications.""" + mock_responses = [] for resp in responses: - mock_http_response = MagicMock() - mock_http_response.status = resp.get("status") - mock_http_response.headers = resp.get("headers", {}) - body = 
resp.get("body", b"{}") - mock_http_response.fp = io.BytesIO(body) - - def release(): - mock_http_response.fp.close() - - mock_http_response.release_conn = release - mock_http_responses.append(mock_http_response) - mock_conn.getresponse.side_effect = mock_http_responses - return mock_conn + mock_response = MagicMock() + mock_response.status = resp.get("status") + mock_response.status_code = resp.get("status") # Add status_code for compatibility + mock_response.headers = resp.get("headers", {}) + mock_response.data = resp.get("body", b"{}") + mock_response.ok = resp.get("status", 200) < 400 + mock_response.text = resp.get("body", b"{}").decode() if isinstance(resp.get("body", b"{}"), bytes) else str(resp.get("body", "{}")) + mock_response.json = lambda: {} # Simple json mock + mock_responses.append(mock_response) + return mock_responses class TestTelemetryClientRetries: @@ -43,30 +39,16 @@ def setup_and_teardown(self): TelemetryClientFactory._executor = None def get_client(self, session_id, num_retries=3): - """ - Configures a client with a specific number of retries. - """ + mock_http_client = MagicMock() TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id, auth_provider=None, host_url="test.databricks.com", - batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - ) - client = TelemetryClientFactory.get_telemetry_client(session_id) - - retry_policy = DatabricksRetryPolicy( - delay_min=0.01, - delay_max=0.02, - stop_after_attempts_duration=2.0, - stop_after_attempts_count=num_retries, - delay_default=0.1, - force_dangerous_codes=[], - urllib3_kwargs={"total": num_retries}, + batch_size=1, # Use batch size of 1 to trigger immediate HTTP requests + http_client=mock_http_client, ) - adapter = client._http_client.session.adapters.get("https://") - adapter.max_retries = retry_policy - return client + return TelemetryClientFactory.get_telemetry_client(session_id) @pytest.mark.parametrize( "status_code, description", @@ -85,13 +67,19 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti client = self.get_client(f"session-{status_code}") mock_responses = [{"status": status_code}] - with patch( - PATCH_TARGET, return_value=create_mock_conn(mock_responses) - ) as mock_get_conn: + mock_response = create_mock_response(mock_responses)[0] + with patch(PATCH_TARGET, return_value=mock_response) as mock_request: client.export_failure_log("TestError", "Test message") + + # Wait a moment for async operations to complete + time.sleep(0.1) + TelemetryClientFactory.close(client._session_id_hex) + + # Wait a bit more for any final operations + time.sleep(0.1) - mock_get_conn.return_value.getresponse.assert_called_once() + mock_request.assert_called_once() def test_exceeds_retry_count_limit(self): """ @@ -103,22 +91,28 @@ def test_exceeds_retry_count_limit(self): retry_after = 1 client = self.get_client("session-exceed-limit", num_retries=num_retries) mock_responses = [ - {"status": 503, "headers": {"Retry-After": str(retry_after)}}, - {"status": 429}, + {"status": 429, "headers": {"Retry-After": str(retry_after)}}, {"status": 502}, {"status": 503}, + {"status": 200}, ] - with patch( - PATCH_TARGET, return_value=create_mock_conn(mock_responses) - ) as mock_get_conn: + mock_response_objects = create_mock_response(mock_responses) + with patch(PATCH_TARGET, side_effect=mock_response_objects) as mock_request: start_time = time.time() client.export_failure_log("TestError", "Test message") + + # Wait for async operations to complete + 
time.sleep(0.2) + TelemetryClientFactory.close(client._session_id_hex) + + # Wait for any final operations + time.sleep(0.2) + end_time = time.time() assert ( - mock_get_conn.return_value.getresponse.call_count + mock_request.call_count == expected_total_calls ) - assert end_time - start_time > retry_after From 429460082749de360c9e86e55772f093deeca05e Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 8 Aug 2025 19:25:27 +0530 Subject: [PATCH 03/35] Fix all tests Signed-off-by: Vikrant Puppala --- src/databricks/sql/client.py | 2 +- tests/unit/test_auth.py | 2 +- tests/unit/test_sea_queue.py | 23 +++++- tests/unit/test_session.py | 3 +- tests/unit/test_telemetry_retry.py | 118 ----------------------------- 5 files changed, 23 insertions(+), 125 deletions(-) delete mode 100644 tests/unit/test_telemetry_retry.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 50f252dbc..7323b939a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -443,7 +443,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - return self.session.is_open + return hasattr(self, 'session') and self.session.is_open def cursor( self, diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 2e210a9e0..333782fd8 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -294,7 +294,7 @@ def test_get_token_failure(self, token_source, http_response): mock_http_client.execute.return_value.__enter__.return_value = mock_response mock_http_client.execute.return_value.__exit__.return_value = None - with pytest.raises(Exception): + with pytest.raises(Exception) as e: token_source.get_token() assert "Failed to get token: 400" in str(e.value) diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index cbeae098b..6471cb4fd 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -7,7 +7,7 @@ """ import pytest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock from databricks.sql.backend.sea.queue import ( JsonQueue, @@ -184,6 +184,7 @@ def description(self): def test_build_queue_json_array(self, json_manifest, sample_data): """Test building a JSON array queue.""" result_data = ResultData(data=sample_data) + mock_http_client = MagicMock() queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, @@ -194,6 +195,7 @@ def test_build_queue_json_array(self, json_manifest, sample_data): max_download_threads=10, sea_client=Mock(), lz4_compressed=False, + http_client=mock_http_client, ) assert isinstance(queue, JsonQueue) @@ -217,6 +219,8 @@ def test_build_queue_arrow_stream( ] result_data = ResultData(data=None, external_links=external_links) + mock_http_client = MagicMock() + with patch( "databricks.sql.backend.sea.queue.ResultFileDownloadManager" ), patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None): @@ -229,6 +233,7 @@ def test_build_queue_arrow_stream( max_download_threads=10, sea_client=mock_sea_client, lz4_compressed=False, + http_client=mock_http_client, ) assert isinstance(queue, SeaCloudFetchQueue) @@ -236,6 +241,7 @@ def test_build_queue_arrow_stream( def test_build_queue_invalid_format(self, invalid_manifest): """Test building a queue with invalid format.""" result_data = ResultData(data=[]) + mock_http_client = MagicMock() with pytest.raises(ProgrammingError, match="Invalid result format"): 
SeaResultSetQueueFactory.build_queue( @@ -247,6 +253,7 @@ def test_build_queue_invalid_format(self, invalid_manifest): max_download_threads=10, sea_client=Mock(), lz4_compressed=False, + http_client=mock_http_client, ) @@ -339,6 +346,7 @@ def test_init_with_valid_initial_link( ): """Test initialization with valid initial link.""" # Create a queue with valid initial link + mock_http_client = MagicMock() with patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None): queue = SeaCloudFetchQueue( result_data=ResultData(external_links=[sample_external_link]), @@ -349,6 +357,7 @@ def test_init_with_valid_initial_link( total_chunk_count=1, lz4_compressed=False, description=description, + http_client=mock_http_client, ) # Verify attributes @@ -367,6 +376,7 @@ def test_init_no_initial_links( ): """Test initialization with no initial links.""" # Create a queue with empty initial links + mock_http_client = MagicMock() queue = SeaCloudFetchQueue( result_data=ResultData(external_links=[]), max_download_threads=5, @@ -376,6 +386,7 @@ def test_init_no_initial_links( total_chunk_count=0, lz4_compressed=False, description=description, + http_client=mock_http_client, ) assert queue.table is None @@ -462,7 +473,7 @@ def test_hybrid_disposition_with_attachment( # Create result data with attachment attachment_data = b"mock_arrow_data" result_data = ResultData(attachment=attachment_data) - + mock_http_client = MagicMock() # Build queue queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, @@ -473,6 +484,7 @@ def test_hybrid_disposition_with_attachment( max_download_threads=10, sea_client=mock_sea_client, lz4_compressed=False, + http_client=mock_http_client, ) # Verify ArrowQueue was created @@ -508,7 +520,8 @@ def test_hybrid_disposition_with_external_links( # Create result data with external links but no attachment result_data = ResultData(external_links=external_links, attachment=None) - # Build queue + # Build queue + mock_http_client = MagicMock() queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, manifest=arrow_manifest, @@ -518,6 +531,7 @@ def test_hybrid_disposition_with_external_links( max_download_threads=10, sea_client=mock_sea_client, lz4_compressed=False, + http_client=mock_http_client, ) # Verify SeaCloudFetchQueue was created @@ -548,7 +562,7 @@ def test_hybrid_disposition_with_compressed_attachment( # Create result data with attachment result_data = ResultData(attachment=compressed_data) - + mock_http_client = MagicMock() # Build queue with lz4_compressed=True queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, @@ -559,6 +573,7 @@ def test_hybrid_disposition_with_compressed_attachment( max_download_threads=10, sea_client=mock_sea_client, lz4_compressed=True, + http_client=mock_http_client, ) # Verify ArrowQueue was created with decompressed data diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6823b1b33..e019e05a2 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -75,8 +75,9 @@ def test_http_header_passthrough(self, mock_client_class): call_kwargs = mock_client_class.call_args[1] assert ("foo", "bar") in call_kwargs["http_headers"] + @patch("%s.client.UnifiedHttpClient" % PACKAGE_NAME) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) - def test_tls_arg_passthrough(self, mock_client_class): + def test_tls_arg_passthrough(self, mock_client_class, mock_http_client): databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, _tls_verify_hostname="hostname", diff --git 
a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py deleted file mode 100644 index f0bdddd60..000000000 --- a/tests/unit/test_telemetry_retry.py +++ /dev/null @@ -1,118 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -import io -import time - -from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory -from databricks.sql.auth.retry import DatabricksRetryPolicy - -PATCH_TARGET = "databricks.sql.common.unified_http_client.UnifiedHttpClient.request" - - -def create_mock_response(responses): - """Creates mock urllib3 HTTPResponse objects for the given response specifications.""" - mock_responses = [] - for resp in responses: - mock_response = MagicMock() - mock_response.status = resp.get("status") - mock_response.status_code = resp.get("status") # Add status_code for compatibility - mock_response.headers = resp.get("headers", {}) - mock_response.data = resp.get("body", b"{}") - mock_response.ok = resp.get("status", 200) < 400 - mock_response.text = resp.get("body", b"{}").decode() if isinstance(resp.get("body", b"{}"), bytes) else str(resp.get("body", "{}")) - mock_response.json = lambda: {} # Simple json mock - mock_responses.append(mock_response) - return mock_responses - - -class TestTelemetryClientRetries: - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - TelemetryClientFactory._initialized = False - TelemetryClientFactory._clients = {} - TelemetryClientFactory._executor = None - yield - if TelemetryClientFactory._executor: - TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryClientFactory._initialized = False - TelemetryClientFactory._clients = {} - TelemetryClientFactory._executor = None - - def get_client(self, session_id, num_retries=3): - mock_http_client = MagicMock() - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, - session_id_hex=session_id, - auth_provider=None, - host_url="test.databricks.com", - batch_size=1, # Use batch size of 1 to trigger immediate HTTP requests - http_client=mock_http_client, - ) - return TelemetryClientFactory.get_telemetry_client(session_id) - - @pytest.mark.parametrize( - "status_code, description", - [ - (401, "Unauthorized"), - (403, "Forbidden"), - (501, "Not Implemented"), - (200, "Success"), - ], - ) - def test_non_retryable_status_codes_are_not_retried(self, status_code, description): - """ - Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried. - """ - # Use the status code in the session ID for easier debugging if it fails - client = self.get_client(f"session-{status_code}") - mock_responses = [{"status": status_code}] - - mock_response = create_mock_response(mock_responses)[0] - with patch(PATCH_TARGET, return_value=mock_response) as mock_request: - client.export_failure_log("TestError", "Test message") - - # Wait a moment for async operations to complete - time.sleep(0.1) - - TelemetryClientFactory.close(client._session_id_hex) - - # Wait a bit more for any final operations - time.sleep(0.1) - - mock_request.assert_called_once() - - def test_exceeds_retry_count_limit(self): - """ - Verifies that the client retries up to the specified number of times before giving up. - Verifies that the client respects the Retry-After header and retries on 429, 502, 503. 
- """ - num_retries = 3 - expected_total_calls = num_retries + 1 - retry_after = 1 - client = self.get_client("session-exceed-limit", num_retries=num_retries) - mock_responses = [ - {"status": 429, "headers": {"Retry-After": str(retry_after)}}, - {"status": 502}, - {"status": 503}, - {"status": 200}, - ] - - mock_response_objects = create_mock_response(mock_responses) - with patch(PATCH_TARGET, side_effect=mock_response_objects) as mock_request: - start_time = time.time() - client.export_failure_log("TestError", "Test message") - - # Wait for async operations to complete - time.sleep(0.2) - - TelemetryClientFactory.close(client._session_id_hex) - - # Wait for any final operations - time.sleep(0.2) - - end_time = time.time() - - assert ( - mock_request.call_count - == expected_total_calls - ) From 31552117d01160d59980a201a5c47d7135eb4040 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 8 Aug 2025 19:27:20 +0530 Subject: [PATCH 04/35] fmt Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/common.py | 12 +- src/databricks/sql/auth/oauth.py | 4 +- src/databricks/sql/client.py | 34 ++++-- src/databricks/sql/cloudfetch/downloader.py | 4 +- src/databricks/sql/common/feature_flag.py | 12 +- .../sql/common/unified_http_client.py | 109 +++++++++--------- src/databricks/sql/session.py | 4 +- .../sql/telemetry/telemetry_client.py | 12 +- 8 files changed, 108 insertions(+), 83 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 262166a52..61b07cb91 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -65,14 +65,16 @@ def __init__( self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider - + # HTTP client configuration self.ssl_options = ssl_options self.socket_timeout = socket_timeout self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30 self.retry_delay_min = retry_delay_min or 1.0 self.retry_delay_max = retry_delay_max or 60.0 - self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0 + self.retry_stop_after_attempts_duration = ( + retry_stop_after_attempts_duration or 900.0 + ) self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy @@ -110,8 +112,8 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str: login_url = f"{host}/aad/auth" logger.debug("Loading tenant ID from %s", login_url) - - with http_client.request_context('GET', login_url, allow_redirects=False) as resp: + + with http_client.request_context("GET", login_url, allow_redirects=False) as resp: if resp.status // 100 != 3: raise ValueError( f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}" @@ -119,7 +121,7 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str: entra_id_endpoint = dict(resp.headers).get("Location") if entra_id_endpoint is None: raise ValueError(f"No Location header in response from {login_url}") - + # The Location header has the following form: https://login.microsoftonline.com//oauth2/authorize?... # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud). 
url = urlparse(entra_id_endpoint) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 270287953..7f96a2303 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -87,7 +87,7 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: - response = self.http_client.request('GET', url=known_config_url) + response = self.http_client.request("GET", url=known_config_url) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status response.json = lambda: json.loads(response.data.decode()) @@ -198,7 +198,7 @@ def __send_token_request(token_request_url, data): } # Use unified HTTP client response = self.http_client.request( - 'POST', url=token_request_url, body=data, headers=headers + "POST", url=token_request_url, body=data, headers=headers ) # Convert urllib3 response to dict for compatibility return json.loads(response.data.decode()) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 7323b939a..1a35f97da 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -354,7 +354,7 @@ def _build_client_context(self, server_hostname: str, **kwargs): """Build ClientContext for HTTP client configuration.""" from databricks.sql.auth.common import ClientContext from databricks.sql.types import SSLOptions - + # Extract SSL options ssl_options = SSLOptions( tls_verify=not kwargs.get("_tls_no_verify", False), @@ -364,22 +364,26 @@ def _build_client_context(self, server_hostname: str, **kwargs): tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - + # Build user agent user_agent_entry = kwargs.get("user_agent_entry", "") if user_agent_entry: user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})" else: user_agent = f"PyDatabricksSqlConnector/{__version__}" - + return ClientContext( hostname=server_hostname, ssl_options=ssl_options, socket_timeout=kwargs.get("_socket_timeout"), - retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count", 30), + retry_stop_after_attempts_count=kwargs.get( + "_retry_stop_after_attempts_count", 30 + ), retry_delay_min=kwargs.get("_retry_delay_min", 1.0), retry_delay_max=kwargs.get("_retry_delay_max", 60.0), - retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration", 900.0), + retry_stop_after_attempts_duration=kwargs.get( + "_retry_stop_after_attempts_duration", 900.0 + ), retry_delay_default=kwargs.get("_retry_delay_default", 1.0), retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []), http_proxy=kwargs.get("_http_proxy"), @@ -443,7 +447,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - return hasattr(self, 'session') and self.session.is_open + return hasattr(self, "session") and self.session.is_open def cursor( self, @@ -792,10 +796,12 @@ def _handle_staging_put( ) with open(local_file, "rb") as fh: - r = self.connection.session.http_client.request('PUT', presigned_url, body=fh.read(), headers=headers) + r = self.connection.session.http_client.request( + "PUT", presigned_url, body=fh.read(), headers=headers + ) # Add compatibility attributes for urllib3 response r.status_code = r.status - if hasattr(r, 'data'): + if hasattr(r, "data"): r.content = r.data 
r.ok = r.status < 400 r.text = r.data.decode() if r.data else "" @@ -835,10 +841,12 @@ def _handle_staging_get( session_id_hex=self.connection.get_session_id_hex(), ) - r = self.connection.session.http_client.request('GET', presigned_url, headers=headers) + r = self.connection.session.http_client.request( + "GET", presigned_url, headers=headers + ) # Add compatibility attributes for urllib3 response r.status_code = r.status - if hasattr(r, 'data'): + if hasattr(r, "data"): r.content = r.data r.ok = r.status < 400 r.text = r.data.decode() if r.data else "" @@ -860,10 +868,12 @@ def _handle_staging_remove( ): """Make an HTTP DELETE request to the presigned_url""" - r = self.connection.session.http_client.request('DELETE', presigned_url, headers=headers) + r = self.connection.session.http_client.request( + "DELETE", presigned_url, headers=headers + ) # Add compatibility attributes for urllib3 response r.status_code = r.status - if hasattr(r, 'data'): + if hasattr(r, "data"): r.content = r.data r.ok = r.status < 400 r.text = r.data.decode() if r.data else "" diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index ea375fbbb..cef4ca274 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -95,10 +95,10 @@ def run(self) -> DownloadedFile: start_time = time.time() with self._http_client.request_context( - method='GET', + method="GET", url=self.link.fileLink, timeout=self.settings.download_timeout, - headers=self.link.httpHeaders + headers=self.link.httpHeaders, ) as response: if response.status >= 400: raise Exception(f"HTTP {response.status}: {response.data.decode()}") diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 8e7029805..1b920b008 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -49,7 +49,9 @@ class FeatureFlagsContext: in the background, returning stale data until the refresh completes. 
""" - def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_client): + def __init__( + self, connection: "Connection", executor: ThreadPoolExecutor, http_client + ): from databricks.sql import __version__ self._connection = connection @@ -65,7 +67,7 @@ def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_ self._feature_flag_endpoint = ( f"https://{self._connection.session.host}{endpoint_suffix}" ) - + # Use the provided HTTP client self._http_client = http_client @@ -109,7 +111,7 @@ def _refresh_flags(self): headers["User-Agent"] = self._connection.session.useragent_header response = self._http_client.request( - 'GET', self._feature_flag_endpoint, headers=headers, timeout=30 + "GET", self._feature_flag_endpoint, headers=headers, timeout=30 ) # Add compatibility attributes for urllib3 response response.status_code = response.status @@ -165,7 +167,9 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: # Use the unique session ID as the key key = connection.get_session_id_hex() if key not in cls._context_map: - cls._context_map[key] = FeatureFlagsContext(connection, cls._executor, connection.session.http_client) + cls._context_map[key] = FeatureFlagsContext( + connection, cls._executor, connection.session.http_client + ) return cls._context_map[key] @classmethod diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 8c3be2bfd..a296704b4 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -18,7 +18,7 @@ class UnifiedHttpClient: """ Unified HTTP client for all Databricks SQL connector HTTP operations. - + This client uses urllib3 for robust HTTP communication with retry policies, connection pooling, SSL support, and proxy support. It replaces the various singleton HTTP clients and direct requests usage throughout the codebase. 
@@ -37,12 +37,12 @@ def __init__(self, client_context): def _setup_pool_manager(self): """Set up the urllib3 PoolManager with configuration from ClientContext.""" - + # SSL context setup ssl_context = None if self.config.ssl_options: ssl_context = ssl.create_default_context() - + # Configure SSL verification if not self.config.ssl_options.tls_verify: ssl_context.check_hostname = False @@ -50,18 +50,22 @@ def _setup_pool_manager(self): elif not self.config.ssl_options.tls_verify_hostname: ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_REQUIRED - + # Load custom CA file if specified if self.config.ssl_options.tls_trusted_ca_file: - ssl_context.load_verify_locations(self.config.ssl_options.tls_trusted_ca_file) - + ssl_context.load_verify_locations( + self.config.ssl_options.tls_trusted_ca_file + ) + # Load client certificate if specified - if (self.config.ssl_options.tls_client_cert_file and - self.config.ssl_options.tls_client_cert_key_file): + if ( + self.config.ssl_options.tls_client_cert_file + and self.config.ssl_options.tls_client_cert_key_file + ): ssl_context.load_cert_chain( self.config.ssl_options.tls_client_cert_file, self.config.ssl_options.tls_client_cert_key_file, - self.config.ssl_options.tls_client_cert_key_password + self.config.ssl_options.tls_client_cert_key_password, ) # Create retry policy @@ -76,14 +80,15 @@ def _setup_pool_manager(self): # Common pool manager kwargs pool_kwargs = { - 'num_pools': self.config.pool_connections, - 'maxsize': self.config.pool_maxsize, - 'retries': retry_policy, - 'timeout': urllib3.Timeout( - connect=self.config.socket_timeout, - read=self.config.socket_timeout - ) if self.config.socket_timeout else None, - 'ssl_context': ssl_context, + "num_pools": self.config.pool_connections, + "maxsize": self.config.pool_maxsize, + "retries": retry_policy, + "timeout": urllib3.Timeout( + connect=self.config.socket_timeout, read=self.config.socket_timeout + ) + if self.config.socket_timeout + else None, + "ssl_context": ssl_context, } # Create proxy or regular pool manager @@ -93,58 +98,51 @@ def _setup_pool_manager(self): proxy_headers = make_headers( proxy_basic_auth=f"{self.config.proxy_username}:{self.config.proxy_password}" ) - + self._pool_manager = ProxyManager( - self.config.http_proxy, - proxy_headers=proxy_headers, - **pool_kwargs + self.config.http_proxy, proxy_headers=proxy_headers, **pool_kwargs ) else: self._pool_manager = PoolManager(**pool_kwargs) - def _prepare_headers(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + def _prepare_headers( + self, headers: Optional[Dict[str, str]] = None + ) -> Dict[str, str]: """Prepare headers for the request, including User-Agent.""" request_headers = {} - + if self.config.user_agent: - request_headers['User-Agent'] = self.config.user_agent - + request_headers["User-Agent"] = self.config.user_agent + if headers: request_headers.update(headers) - + return request_headers @contextmanager def request_context( - self, - method: str, - url: str, - headers: Optional[Dict[str, str]] = None, - **kwargs + self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs ) -> Generator[urllib3.HTTPResponse, None, None]: """ Context manager for making HTTP requests with proper resource cleanup. 
- + Args: method: HTTP method (GET, POST, PUT, DELETE) url: URL to request headers: Optional headers dict **kwargs: Additional arguments passed to urllib3 request - + Yields: urllib3.HTTPResponse: The HTTP response object """ logger.debug("Making %s request to %s", method, url) - + request_headers = self._prepare_headers(headers) response = None - + try: response = self._pool_manager.request( - method=method, - url=url, - headers=request_headers, - **kwargs + method=method, url=url, headers=request_headers, **kwargs ) yield response except MaxRetryError as e: @@ -157,16 +155,18 @@ def request_context( if response: response.close() - def request(self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> urllib3.HTTPResponse: + def request( + self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs + ) -> urllib3.HTTPResponse: """ Make an HTTP request. - + Args: method: HTTP method (GET, POST, PUT, DELETE, etc.) url: URL to request headers: Optional headers dict **kwargs: Additional arguments passed to urllib3 request - + Returns: urllib3.HTTPResponse: The HTTP response object with data pre-loaded """ @@ -175,32 +175,36 @@ def request(self, method: str, url: str, headers: Optional[Dict[str, str]] = Non response._body = response.data return response - def upload_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> urllib3.HTTPResponse: + def upload_file( + self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None + ) -> urllib3.HTTPResponse: """ Upload a file using PUT method. - + Args: url: URL to upload to file_path: Path to the file to upload headers: Optional headers - + Returns: urllib3.HTTPResponse: The response from the server """ - with open(file_path, 'rb') as file_obj: - return self.request('PUT', url, body=file_obj.read(), headers=headers) + with open(file_path, "rb") as file_obj: + return self.request("PUT", url, body=file_obj.read(), headers=headers) - def download_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> None: + def download_file( + self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None + ) -> None: """ Download a file using GET method. - + Args: url: URL to download from file_path: Path where to save the downloaded file headers: Optional headers """ - response = self.request('GET', url, headers=headers) - with open(file_path, 'wb') as file_obj: + response = self.request("GET", url, headers=headers) + with open(file_path, "wb") as file_obj: file_obj.write(response.data) def close(self): @@ -222,5 +226,6 @@ class IgnoreNetrcAuth: Compatibility class for OAuth code that expects requests.auth.AuthBase interface. This is a no-op auth handler since OAuth handles auth differently. 
""" + def __call__(self, request): - return request \ No newline at end of file + return request diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index c9b4f939a..0cba8be48 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -193,7 +193,7 @@ def close(self) -> None: logger.error("Attempt to close session raised a local exception: %s", e) self.is_open = False - + # Close HTTP client if it exists - if hasattr(self, 'http_client') and self.http_client: + if hasattr(self, "http_client") and self.http_client: self.http_client.close() diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 13c15486d..2785d3cca 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -230,7 +230,7 @@ def _send_telemetry(self, events): try: logger.debug("Submitting telemetry request to thread pool") - + # Use unified HTTP client future = self._executor.submit( self._send_with_unified_client, @@ -239,7 +239,7 @@ def _send_telemetry(self, events): headers=headers, timeout=900, ) - + future.add_done_callback( lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count) ) @@ -249,10 +249,14 @@ def _send_telemetry(self, events): def _send_with_unified_client(self, url, data, headers): """Helper method to send telemetry using the unified HTTP client.""" try: - response = self._http_client.request('POST', url, body=data, headers=headers, timeout=900) + response = self._http_client.request( + "POST", url, body=data, headers=headers, timeout=900 + ) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status - response.json = lambda: json.loads(response.data.decode()) if response.data else {} + response.json = ( + lambda: json.loads(response.data.decode()) if response.data else {} + ) return response except Exception as e: logger.error("Failed to send telemetry with unified client: %s", e) From 1143838ad277f4a0309fdb40cdf682ad8e98ad9a Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Sat, 9 Aug 2025 23:45:22 +0530 Subject: [PATCH 05/35] preliminary connection closure func --- src/databricks/sql/auth/thrift_http_client.py | 7 ++++--- src/databricks/sql/backend/sea/backend.py | 21 ++++++++++++++++--- .../sql/backend/sea/utils/http_client.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 9 ++++---- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index f0daae162..a60540712 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -105,7 +105,6 @@ def startRetryTimer(self): self.retry_policy and self.retry_policy.start_retry_timer() def open(self): - # self.__pool replaces the self.__http used by the original THttpClient _pool_kwargs = {"maxsize": self.max_connections} @@ -140,11 +139,14 @@ def open(self): else: self.__pool = pool_class(self.host, self.port, **_pool_kwargs) - def close(self): + def release_connection(self): self.__resp and self.__resp.drain_conn() self.__resp and self.__resp.release_conn() self.__resp = None + def close(self): + self.__pool.close() + def read(self, sz): return self.__resp.read(sz) @@ -152,7 +154,6 @@ def isOpen(self): return self.__resp is not None def flush(self): - # Pull data out of buffer that will be sent in this request data = self.__wbuf.getvalue() self.__wbuf = BytesIO() diff 
--git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 75d2c665c..68f41084c 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -273,7 +273,7 @@ def open_session( return SessionId.from_sea_session_id(session_id) - def close_session(self, session_id: SessionId) -> None: + def _close_session(self, session_id: SessionId) -> None: """ Closes an existing session with the Databricks SQL service. @@ -285,8 +285,6 @@ def close_session(self, session_id: SessionId) -> None: OperationalError: If there's an error closing the session """ - logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) - if session_id.backend_type != BackendType.SEA: raise ValueError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() @@ -302,6 +300,23 @@ def close_session(self, session_id: SessionId) -> None: data=request_data.to_dict(), ) + def close_session(self, session_id: SessionId) -> None: + """ + Closes the session and the underlying HTTP client. + + Args: + session_id: The session identifier returned by open_session() + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error closing the session + """ + + logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) + + self._close_session(session_id) + self._http_client.close() + def _extract_description_from_manifest( self, manifest: ResultManifest ) -> List[Tuple]: diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index ef9a14353..f0aec2b2d 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -197,7 +197,7 @@ def _open(self): def close(self): """Close the connection pool.""" if self._pool: - self._pool.clear() + self._pool.close() def using_proxy(self) -> bool: """Check if proxy is being used.""" diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index b404b1669..7598f8291 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -232,7 +232,7 @@ def __init__( try: self._transport.open() except: - self._transport.close() + self._transport.release_connection() raise self._request_lock = threading.RLock() @@ -478,7 +478,7 @@ def attempt_request(attempt): ) finally: # Calling `close()` here releases the active HTTP connection back to the pool - self._transport.close() + self._transport.release_connection() return RequestErrorInfo( error=error, @@ -607,7 +607,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId: self._session_id_hex = session_id.hex_guid return session_id except: - self._transport.close() + self._transport.release_connection() raise def close_session(self, session_id: SessionId) -> None: @@ -619,7 +619,8 @@ def close_session(self, session_id: SessionId) -> None: try: self.make_request(self._client.CloseSession, req) finally: - self._transport.close() + self._transport.release_connection() + self._transport.close() def _check_command_not_in_error_or_closed_state( self, op_handle, get_operations_resp From 68cc8221457f75ac22b1ac6d877d9982f8dae2e5 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Sat, 9 Aug 2025 23:49:26 +0530 Subject: [PATCH 06/35] unit test for backend closure --- tests/unit/test_thrift_backend.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git 
a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 0cdb43f5c..c671e4900 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1406,8 +1406,12 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - def test_session_handle_respected_in_close_session(self, tcli_service_class): + @patch("databricks.sql.auth.thrift_http_client.THttpClient", autospec=True) + def test_session_handle_respected_in_close_session( + self, mock_http_client_class, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value + mock_http_client_instance = mock_http_client_class.return_value thrift_backend = ThriftDatabricksClient( "foobar", 443, @@ -1416,12 +1420,16 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) + # Manually set the mocked transport instance + thrift_backend._transport = mock_http_client_instance + session_id = SessionId.from_thrift_handle(self.session_handle) thrift_backend.close_session(session_id) self.assertEqual( tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle, ) + mock_http_client_instance.close.assert_called_once() @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_non_arrow_non_column_based_set_triggers_exception( From ef1d9fd0fdfc2af5387786d24372726f69866091 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Sun, 10 Aug 2025 00:11:38 +0530 Subject: [PATCH 07/35] remove redundant comment --- tests/unit/test_thrift_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index c671e4900..8e1a0065a 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1420,7 +1420,6 @@ def test_session_handle_respected_in_close_session( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - # Manually set the mocked transport instance thrift_backend._transport = mock_http_client_instance session_id = SessionId.from_thrift_handle(self.session_handle) From 4bb2e4b0fb4238612d9b1f5b8401706a5295113c Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Sun, 10 Aug 2025 10:36:18 +0530 Subject: [PATCH 08/35] assert SEA http client closure in unit tests --- tests/unit/test_sea_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index f604f2874..1e8da7d34 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -220,6 +220,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"), data={"session_id": "test-session-789", "warehouse_id": "abc123"}, ) + mock_http_client.close.assert_called_once() # Test close_session with invalid ID type with pytest.raises(ValueError) as excinfo: From 734dd06e131b274b6af7e5fab58dc4fabaa4902f Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Sun, 10 Aug 2025 15:56:43 +0530 Subject: [PATCH 09/35] correct docstrng --- src/databricks/sql/backend/thrift_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 7598f8291..ee6ed547e 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ 
-477,7 +477,7 @@ def attempt_request(attempt): ) ) finally: - # Calling `close()` here releases the active HTTP connection back to the pool + # Calling `release_connection()` here releases the active HTTP connection back to the pool self._transport.release_connection() return RequestErrorInfo( From d00e3c86a798bc8c7f89c9cc59fcae20ef7eff61 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 11 Aug 2025 12:02:32 +0530 Subject: [PATCH 10/35] fix e2e Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/common.py | 6 ++--- src/databricks/sql/auth/retry.py | 6 ++--- src/databricks/sql/backend/thrift_backend.py | 2 +- src/databricks/sql/client.py | 20 +++++++--------- .../sql/common/unified_http_client.py | 24 ++++++++++++++++--- tests/e2e/common/retry_test_mixins.py | 5 +++- tests/e2e/common/staging_ingestion_tests.py | 10 +++++--- tests/e2e/common/uc_volume_tests.py | 9 +++++-- tests/e2e/test_driver.py | 14 ++++++----- 9 files changed, 61 insertions(+), 35 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 61b07cb91..cec869027 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -69,12 +69,10 @@ def __init__( # HTTP client configuration self.ssl_options = ssl_options self.socket_timeout = socket_timeout - self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30 + self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5 self.retry_delay_min = retry_delay_min or 1.0 self.retry_delay_max = retry_delay_max or 60.0 - self.retry_stop_after_attempts_duration = ( - retry_stop_after_attempts_duration or 900.0 - ) + self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0 self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 368edc9a2..9c9988971 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -294,7 +294,7 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool: else: proposed_wait = self.get_backoff_time() - proposed_wait = max(proposed_wait, self.delay_max) + proposed_wait = min(proposed_wait, self.delay_max) self.check_proposed_wait(proposed_wait) logger.debug(f"Retrying after {proposed_wait} seconds") time.sleep(proposed_wait) @@ -355,8 +355,8 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: logger.info(f"Received status code {status_code} for {method} request") # Request succeeded. Don't retry. - if status_code == 200: - return False, "200 codes are not retried" + if status_code // 100 == 2: + return False, "2xx codes are not retried" if status_code == 401: return ( diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 801632a41..1a1849bb7 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -194,7 +194,7 @@ def __init__( if _max_redirects: if _max_redirects > self._retry_stop_after_attempts_count: - logger.warn( + logger.warning( "_retry_max_redirects > _retry_stop_after_attempts_count so it will have no affect!" 
) urllib3_kwargs = {"redirect": _max_redirects} diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1a35f97da..d3a72c86a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -376,21 +376,17 @@ def _build_client_context(self, server_hostname: str, **kwargs): hostname=server_hostname, ssl_options=ssl_options, socket_timeout=kwargs.get("_socket_timeout"), - retry_stop_after_attempts_count=kwargs.get( - "_retry_stop_after_attempts_count", 30 - ), - retry_delay_min=kwargs.get("_retry_delay_min", 1.0), - retry_delay_max=kwargs.get("_retry_delay_max", 60.0), - retry_stop_after_attempts_duration=kwargs.get( - "_retry_stop_after_attempts_duration", 900.0 - ), - retry_delay_default=kwargs.get("_retry_delay_default", 1.0), - retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []), + retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), + retry_delay_min=kwargs.get("_retry_delay_min"), + retry_delay_max=kwargs.get("_retry_delay_max"), + retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"), + retry_delay_default=kwargs.get("_retry_delay_default"), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), http_proxy=kwargs.get("_http_proxy"), proxy_username=kwargs.get("_proxy_username"), proxy_password=kwargs.get("_proxy_password"), - pool_connections=kwargs.get("_pool_connections", 1), - pool_maxsize=kwargs.get("_pool_maxsize", 1), + pool_connections=kwargs.get("_pool_connections"), + pool_maxsize=kwargs.get("_pool_maxsize"), user_agent=user_agent, ) diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index a296704b4..13fd9ddd2 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -9,7 +9,7 @@ from urllib3.util import make_headers from urllib3.exceptions import MaxRetryError -from databricks.sql.auth.retry import DatabricksRetryPolicy +from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType from databricks.sql.exc import RequestError logger = logging.getLogger(__name__) @@ -33,6 +33,7 @@ def __init__(self, client_context): """ self.config = client_context self._pool_manager = None + self._retry_policy = None self._setup_pool_manager() def _setup_pool_manager(self): @@ -69,7 +70,7 @@ def _setup_pool_manager(self): ) # Create retry policy - retry_policy = DatabricksRetryPolicy( + self._retry_policy = DatabricksRetryPolicy( delay_min=self.config.retry_delay_min, delay_max=self.config.retry_delay_max, stop_after_attempts_count=self.config.retry_stop_after_attempts_count, @@ -77,12 +78,17 @@ def _setup_pool_manager(self): delay_default=self.config.retry_delay_default, force_dangerous_codes=self.config.retry_dangerous_codes, ) + + # Initialize the required attributes that DatabricksRetryPolicy expects + # but doesn't initialize in its constructor + self._retry_policy._command_type = None + self._retry_policy._retry_start_time = None # Common pool manager kwargs pool_kwargs = { "num_pools": self.config.pool_connections, "maxsize": self.config.pool_maxsize, - "retries": retry_policy, + "retries": self._retry_policy, "timeout": urllib3.Timeout( connect=self.config.socket_timeout, read=self.config.socket_timeout ) @@ -119,6 +125,14 @@ def _prepare_headers( return request_headers + def _prepare_retry_policy(self): + """Set up the retry policy for the current request.""" + if isinstance(self._retry_policy, DatabricksRetryPolicy): + # Set command type 
for HTTP requests to OTHER (not database commands) + self._retry_policy.command_type = CommandType.OTHER + # Start the retry timer for duration-based retry limits + self._retry_policy.start_retry_timer() + @contextmanager def request_context( self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs @@ -138,6 +152,10 @@ def request_context( logger.debug("Making %s request to %s", method, url) request_headers = self._prepare_headers(headers) + + # Prepare retry policy for this request + self._prepare_retry_policy() + response = None try: diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index e1c32d68e..e5ff3dcb7 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -247,6 +247,7 @@ def test_retry_exponential_backoff(self, mock_send_telemetry, extra_params): """ retry_policy = self._retry_policy.copy() retry_policy["_retry_delay_min"] = 1 + retry_policy["_retry_delay_max"] = 10 time_start = time.time() with mocked_server_response( @@ -282,9 +283,11 @@ def test_retry_max_duration_not_exceeded(self, extra_params): WHEN the server sends a Retry-After header of 60 seconds THEN the connector raises a MaxRetryDurationError """ + retry_policy = self._retry_policy.copy() + retry_policy["_retry_delay_max"] = 60 with mocked_server_response(status=429, headers={"Retry-After": "60"}): with pytest.raises(RequestError) as cm: - extra_params = {**extra_params, **self._retry_policy} + extra_params = {**extra_params, **retry_policy} with self.connection(extra_params=extra_params) as conn: pass assert isinstance(cm.value.args[1], MaxRetryDurationError) diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 825f830f3..377d51ef4 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -68,15 +68,19 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # REMOVE should succeed remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv'" - - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + # Use minimal retry settings to fail fast for staging operations + extra_params = { + "staging_allowed_local_path": "/", + "_retry_stop_after_attempts_count": 1, + } + with self.connection(extra_params=extra_params) as conn: cursor = conn.cursor() cursor.execute(remove_query) # GET after REMOVE should fail with pytest.raises( - Error, match="Staging operation over HTTP was unsuccessful: 404" + Error, match="too many 404 error responses" ): cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 72e2f5020..c60e10e6d 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -68,14 +68,19 @@ def test_uc_volume_life_cycle(self, catalog, schema): remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + # Use minimal retry settings to fail fast + extra_params = { + "staging_allowed_local_path": "/", + "_retry_stop_after_attempts_count": 1, + } + with self.connection(extra_params=extra_params) as conn: cursor = conn.cursor() cursor.execute(remove_query) # GET after REMOVE should fail with pytest.raises( - Error, match="Staging operation over HTTP was unsuccessful: 404" + Error, match="too 
many 404 error responses" ): cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 3fa87b1af..53b7383e6 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -60,12 +60,14 @@ unsafe_logger.addHandler(logging.FileHandler("./tests-unsafe.log")) # manually decorate DecimalTestsMixin to need arrow support -for name in loader.getTestCaseNames(DecimalTestsMixin, "test_"): - fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( - fn - ) - setattr(DecimalTestsMixin, name, decorated) +test_loader = loader.TestLoader() +for name in test_loader.getTestCaseNames(DecimalTestsMixin): + if name.startswith("test_"): + fn = getattr(DecimalTestsMixin, name) + decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( + fn + ) + setattr(DecimalTestsMixin, name, decorated) class PySQLPytestTestCase: From 000d3a360000c0b7f3c5914d061296685224cf5f Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 11 Aug 2025 12:12:39 +0530 Subject: [PATCH 11/35] fix unit Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/common.py | 4 +++- src/databricks/sql/client.py | 8 ++++++-- src/databricks/sql/common/unified_http_client.py | 6 +++--- tests/unit/test_retry.py | 4 ++-- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index cec869027..e80fac189 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -72,7 +72,9 @@ def __init__( self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5 self.retry_delay_min = retry_delay_min or 1.0 self.retry_delay_max = retry_delay_max or 60.0 - self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0 + self.retry_stop_after_attempts_duration = ( + retry_stop_after_attempts_duration or 900.0 + ) self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index d3a72c86a..d2e94df63 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -376,10 +376,14 @@ def _build_client_context(self, server_hostname: str, **kwargs): hostname=server_hostname, ssl_options=ssl_options, socket_timeout=kwargs.get("_socket_timeout"), - retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), + retry_stop_after_attempts_count=kwargs.get( + "_retry_stop_after_attempts_count" + ), retry_delay_min=kwargs.get("_retry_delay_min"), retry_delay_max=kwargs.get("_retry_delay_max"), - retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"), + retry_stop_after_attempts_duration=kwargs.get( + "_retry_stop_after_attempts_duration" + ), retry_delay_default=kwargs.get("_retry_delay_default"), retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), http_proxy=kwargs.get("_http_proxy"), diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 13fd9ddd2..03f784ee2 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -78,7 +78,7 @@ def _setup_pool_manager(self): delay_default=self.config.retry_delay_default, force_dangerous_codes=self.config.retry_dangerous_codes, ) - 
+ # Initialize the required attributes that DatabricksRetryPolicy expects # but doesn't initialize in its constructor self._retry_policy._command_type = None @@ -152,10 +152,10 @@ def request_context( logger.debug("Making %s request to %s", method, url) request_headers = self._prepare_headers(headers) - + # Prepare retry policy for this request self._prepare_retry_policy() - + response = None try: diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 897a1d111..40096bf08 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -34,7 +34,7 @@ def test_sleep__no_retry_after(self, t_mock, retry_policy, error_history): retry_policy.history = [error_history, error_history] retry_policy.sleep(HTTPResponse(status=503)) - expected_backoff_time = max( + expected_backoff_time = min( self.calculate_backoff_time( 0, retry_policy.delay_min, retry_policy.delay_max ), @@ -57,7 +57,7 @@ def test_sleep__no_retry_after_header__multiple_retries(self, t_mock, retry_poli expected_backoff_times = [] for attempt in range(num_attempts): expected_backoff_times.append( - max( + min( self.calculate_backoff_time( attempt, retry_policy.delay_min, retry_policy.delay_max ), From cba3da70ce26696b645debd3a6d3dd523f7074f6 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 11 Aug 2025 14:29:12 +0530 Subject: [PATCH 12/35] more fixes Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/retry.py | 2 +- src/databricks/sql/backend/thrift_backend.py | 2 +- src/databricks/sql/client.py | 4 +- src/databricks/sql/cloudfetch/downloader.py | 3 - .../sql/common/unified_http_client.py | 43 ----------- .../sql/telemetry/telemetry_client.py | 72 ++++++++++++++++--- tests/e2e/common/retry_test_mixins.py | 5 +- tests/e2e/common/staging_ingestion_tests.py | 1 + tests/e2e/common/uc_volume_tests.py | 1 + 9 files changed, 68 insertions(+), 65 deletions(-) diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 9c9988971..ad8e455f1 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -294,7 +294,7 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool: else: proposed_wait = self.get_backoff_time() - proposed_wait = min(proposed_wait, self.delay_max) + proposed_wait = max(proposed_wait, self.delay_max) self.check_proposed_wait(proposed_wait) logger.debug(f"Retrying after {proposed_wait} seconds") time.sleep(proposed_wait) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 1a1849bb7..25cc8428a 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -105,7 +105,7 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, - http_client=None, + http_client, **kwargs, ): # Internal arguments in **kwargs: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index d2e94df63..74630cebc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -277,7 +277,7 @@ def read(self) -> Optional[OAuthToken]: host_url=server_hostname, http_path=http_path, port=kwargs.get("_port", 443), - http_client=http_client, + client_context=client_context, user_agent=self.session.useragent_header if hasattr(self, "session") else None, @@ -299,7 +299,7 @@ def read(self) -> Optional[OAuthToken]: auth_provider=self.session.auth_provider, host_url=self.session.host, batch_size=self.telemetry_batch_size, - http_client=self.session.http_client, + client_context=client_context, ) 
self._telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index cef4ca274..a2a7837f0 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -13,9 +13,6 @@ logger = logging.getLogger(__name__) -# TODO: Ideally, we should use a common retry policy (DatabricksRetryPolicy) for all the requests across the library. -# But DatabricksRetryPolicy should be updated first - currently it can work only with Thrift requests - @dataclass class DownloadedFile: diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 03f784ee2..bb26ae9de 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -193,38 +193,6 @@ def request( response._body = response.data return response - def upload_file( - self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None - ) -> urllib3.HTTPResponse: - """ - Upload a file using PUT method. - - Args: - url: URL to upload to - file_path: Path to the file to upload - headers: Optional headers - - Returns: - urllib3.HTTPResponse: The response from the server - """ - with open(file_path, "rb") as file_obj: - return self.request("PUT", url, body=file_obj.read(), headers=headers) - - def download_file( - self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None - ) -> None: - """ - Download a file using GET method. - - Args: - url: URL to download from - file_path: Path where to save the downloaded file - headers: Optional headers - """ - response = self.request("GET", url, headers=headers) - with open(file_path, "wb") as file_obj: - file_obj.write(response.data) - def close(self): """Close the underlying connection pools.""" if self._pool_manager: @@ -236,14 +204,3 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - - -# Compatibility class to maintain requests-like interface for OAuth -class IgnoreNetrcAuth: - """ - Compatibility class for OAuth code that expects requests.auth.AuthBase interface. - This is a no-op auth handler since OAuth handles auth differently. - """ - - def __call__(self, request): - return request diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 2785d3cca..9887b67a7 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -1,8 +1,11 @@ import threading import time import logging +import json from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Optional, TYPE_CHECKING +from concurrent.futures import Future +from datetime import datetime, timezone +from typing import List, Dict, Any, Optional, TYPE_CHECKING from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, @@ -36,8 +39,7 @@ import locale from databricks.sql.telemetry.utils import BaseTelemetryClient from databricks.sql.common.feature_flag import FeatureFlagsContextFactory - -from src.databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.unified_http_client import UnifiedHttpClient if TYPE_CHECKING: from databricks.sql.client import Connection @@ -151,6 +153,44 @@ def _flush(self): pass +class TelemetryHttpClientSingleton: + """ + Singleton HTTP client for telemetry operations. 
+ + This ensures that telemetry has its own dedicated HTTP client that + is independent of individual connection lifecycles. + """ + + _instance = None + _lock = threading.RLock() + + def __new__(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._http_client = None + cls._instance._initialized = False + return cls._instance + + def get_http_client(self, client_context): + """Get or create the singleton HTTP client.""" + if not self._initialized and client_context: + with self._lock: + if not self._initialized: + self._http_client = UnifiedHttpClient(client_context) + self._initialized = True + return self._http_client + + def close(self): + """Close the singleton HTTP client.""" + with self._lock: + if self._http_client: + self._http_client.close() + self._http_client = None + self._initialized = False + + class TelemetryClient(BaseTelemetryClient): """ Telemetry client class that handles sending telemetry events in batches to the server. @@ -169,7 +209,7 @@ def __init__( host_url, executor, batch_size, - http_client, + client_context, ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled @@ -182,7 +222,10 @@ def __init__( self._driver_connection_params = None self._host_url = host_url self._executor = executor - self._http_client = http_client + + # Use singleton HTTP client for telemetry instead of connection-specific client + self._http_client_singleton = TelemetryHttpClientSingleton() + self._http_client = self._http_client_singleton.get_http_client(client_context) def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" @@ -246,17 +289,24 @@ def _send_telemetry(self, events): except Exception as e: logger.debug("Failed to submit telemetry request: %s", e) - def _send_with_unified_client(self, url, data, headers): + def _send_with_unified_client(self, url, data, headers, timeout=900): """Helper method to send telemetry using the unified HTTP client.""" try: response = self._http_client.request( - "POST", url, body=data, headers=headers, timeout=900 + "POST", url, body=data, headers=headers, timeout=timeout ) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status + response.ok = 200 <= response.status < 300 response.json = ( lambda: json.loads(response.data.decode()) if response.data else {} ) + # Add raise_for_status method + def raise_for_status(): + if not response.ok: + raise Exception(f"HTTP {response.status_code}") + + response.raise_for_status = raise_for_status return response except Exception as e: logger.error("Failed to send telemetry with unified client: %s", e) @@ -452,7 +502,7 @@ def initialize_telemetry_client( auth_provider, host_url, batch_size, - http_client, + client_context, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: @@ -475,7 +525,7 @@ def initialize_telemetry_client( host_url=host_url, executor=TelemetryClientFactory._executor, batch_size=batch_size, - http_client=http_client, + client_context=client_context, ) else: TelemetryClientFactory._clients[ @@ -528,7 +578,7 @@ def connection_failure_log( host_url: str, http_path: str, port: int, - http_client: UnifiedHttpClient, + client_context, user_agent: Optional[str] = None, ): """Send error telemetry when connection creation fails, without requiring a session""" @@ -541,7 +591,7 @@ def connection_failure_log( 
auth_provider=None, host_url=host_url, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=http_client, + client_context=client_context, ) telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index e5ff3dcb7..e1c32d68e 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -247,7 +247,6 @@ def test_retry_exponential_backoff(self, mock_send_telemetry, extra_params): """ retry_policy = self._retry_policy.copy() retry_policy["_retry_delay_min"] = 1 - retry_policy["_retry_delay_max"] = 10 time_start = time.time() with mocked_server_response( @@ -283,11 +282,9 @@ def test_retry_max_duration_not_exceeded(self, extra_params): WHEN the server sends a Retry-After header of 60 seconds THEN the connector raises a MaxRetryDurationError """ - retry_policy = self._retry_policy.copy() - retry_policy["_retry_delay_max"] = 60 with mocked_server_response(status=429, headers={"Retry-After": "60"}): with pytest.raises(RequestError) as cm: - extra_params = {**extra_params, **retry_policy} + extra_params = {**extra_params, **self._retry_policy} with self.connection(extra_params=extra_params) as conn: pass assert isinstance(cm.value.args[1], MaxRetryDurationError) diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 377d51ef4..73aa0a113 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -72,6 +72,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): extra_params = { "staging_allowed_local_path": "/", "_retry_stop_after_attempts_count": 1, + "_retry_delay_max": 10, } with self.connection(extra_params=extra_params) as conn: cursor = conn.cursor() diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index c60e10e6d..93e63bd28 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -72,6 +72,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): extra_params = { "staging_allowed_local_path": "/", "_retry_stop_after_attempts_count": 1, + "_retry_delay_max": 10, } with self.connection(extra_params=extra_params) as conn: cursor = conn.cursor() From 2a1f719025224d9ec4ae9e53fbb1359a193948b7 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 11 Aug 2025 15:29:25 +0530 Subject: [PATCH 13/35] more fixes Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/auth.py | 6 ++- src/databricks/sql/auth/authenticators.py | 3 ++ src/databricks/sql/auth/oauth.py | 5 +- src/databricks/sql/backend/thrift_backend.py | 3 +- src/databricks/sql/client.py | 42 ++++++--------- src/databricks/sql/common/feature_flag.py | 9 ++-- tests/unit/test_auth.py | 1 + tests/unit/test_retry.py | 4 +- tests/unit/test_telemetry.py | 55 ++++++++++++++------ tests/unit/test_thrift_backend.py | 48 +++++++++++++++++ 10 files changed, 121 insertions(+), 55 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index cc421e69e..59da2a422 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -19,6 +19,7 @@ def get_auth_provider(cfg: ClientContext, http_client): cfg.hostname, cfg.azure_client_id, cfg.azure_client_secret, + http_client, cfg.azure_tenant_id, cfg.azure_workspace_resource_id, ) @@ -34,8 +35,8 @@ def get_auth_provider(cfg: ClientContext, http_client): cfg.oauth_redirect_port_range, cfg.oauth_client_id, 
cfg.oauth_scopes, + http_client, cfg.auth_type, - http_client=http_client, ) elif cfg.access_token is not None: return AccessTokenAuthProvider(cfg.access_token) @@ -54,7 +55,8 @@ def get_auth_provider(cfg: ClientContext, http_client): cfg.oauth_redirect_port_range, cfg.oauth_client_id, cfg.oauth_scopes, - http_client=http_client, + http_client, + cfg.auth_type or "databricks-oauth", ) else: raise RuntimeError("No valid authentication settings!") diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 80f44812c..66e2cbe53 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -190,6 +190,7 @@ def __init__( hostname, azure_client_id, azure_client_secret, + http_client, azure_tenant_id=None, azure_workspace_resource_id=None, ): @@ -200,6 +201,7 @@ def __init__( self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host( hostname ) + self._http_client = http_client def auth_type(self) -> str: return AuthType.AZURE_SP_M2M.value @@ -209,6 +211,7 @@ def get_token_source(self, resource: str) -> RefreshableTokenSource: token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}", client_id=self.azure_client_id, client_secret=self.azure_client_secret, + http_client=self._http_client, extra_params={"resource": resource}, ) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 7f96a2303..9fdf3955a 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -190,8 +190,7 @@ def __send_auth_code_token_request( data = f"{token_request_body}&code_verifier={verifier}" return self.__send_token_request(token_request_url, data) - @staticmethod - def __send_token_request(token_request_url, data): + def __send_token_request(self, token_request_url, data): headers = { "Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded", @@ -210,7 +209,7 @@ def __send_refresh_token_request(self, hostname, refresh_token): token_request_body = client.prepare_refresh_body( refresh_token=refresh_token, client_id=client.client_id ) - return OAuthManager.__send_token_request(token_request_url, token_request_body) + return self.__send_token_request(token_request_url, token_request_body) @staticmethod def __get_tokens_from_response(oauth_response): diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 25cc8428a..59cf69b6e 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -8,6 +8,7 @@ from typing import List, Optional, Union, Any, TYPE_CHECKING from uuid import UUID +from databricks.sql.common.unified_http_client import UnifiedHttpClient from databricks.sql.result_set import ThriftResultSet from databricks.sql.telemetry.models.event import StatementType @@ -105,7 +106,7 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, - http_client, + http_client: UnifiedHttpClient, **kwargs, ): # Internal arguments in **kwargs: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 74630cebc..634c7e261 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -799,12 +799,6 @@ def _handle_staging_put( r = self.connection.session.http_client.request( "PUT", presigned_url, body=fh.read(), headers=headers ) - # Add compatibility attributes for urllib3 response - r.status_code = r.status - if hasattr(r, "data"): - r.content = r.data 
- r.ok = r.status < 400 - r.text = r.data.decode() if r.data else "" # fmt: off # HTTP status codes @@ -814,13 +808,15 @@ def _handle_staging_put( NO_CONTENT = 204 # fmt: on - if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: + if r.status not in [OK, CREATED, NO_CONTENT, ACCEPTED]: + # Decode response data for error message + error_text = r.data.decode() if r.data else "" raise OperationalError( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", session_id_hex=self.connection.get_session_id_hex(), ) - if r.status_code == ACCEPTED: + if r.status == ACCEPTED: logger.debug( f"Response code {ACCEPTED} from server indicates ingestion command was accepted " + "but not yet applied on the server. It's possible this command may fail later." @@ -844,23 +840,19 @@ def _handle_staging_get( r = self.connection.session.http_client.request( "GET", presigned_url, headers=headers ) - # Add compatibility attributes for urllib3 response - r.status_code = r.status - if hasattr(r, "data"): - r.content = r.data - r.ok = r.status < 400 - r.text = r.data.decode() if r.data else "" # response.ok verifies the status code is not between 400-600. # Any 2xx or 3xx will evaluate r.ok == True - if not r.ok: + if r.status >= 400: + # Decode response data for error message + error_text = r.data.decode() if r.data else "" raise OperationalError( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", session_id_hex=self.connection.get_session_id_hex(), ) with open(local_file, "wb") as fp: - fp.write(r.content) + fp.write(r.data) @log_latency(StatementType.SQL) def _handle_staging_remove( @@ -871,16 +863,12 @@ def _handle_staging_remove( r = self.connection.session.http_client.request( "DELETE", presigned_url, headers=headers ) - # Add compatibility attributes for urllib3 response - r.status_code = r.status - if hasattr(r, "data"): - r.content = r.data - r.ok = r.status < 400 - r.text = r.data.decode() if r.data else "" - - if not r.ok: + + if r.status >= 400: + # Decode response data for error message + error_text = r.data.decode() if r.data else "" raise OperationalError( - f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}", + f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", session_id_hex=self.connection.get_session_id_hex(), ) diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 1b920b008..2b7e27ab3 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -113,12 +113,11 @@ def _refresh_flags(self): response = self._http_client.request( "GET", self._feature_flag_endpoint, headers=headers, timeout=30 ) - # Add compatibility attributes for urllib3 response - response.status_code = response.status - response.json = lambda: json.loads(response.data.decode()) - if response.status_code == 200: - ff_response = FeatureFlagsResponse.from_dict(response.json()) + if response.status == 200: + # Parse JSON response from urllib3 response data + response_data = json.loads(response.data.decode()) + ff_response = FeatureFlagsResponse.from_dict(response_data) self._update_cache_from_response(ff_response) else: # On failure, initialize with an empty dictionary to prevent re-blocking. 
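The adaptation applied to the feature-flag path above (handling the urllib3 response's native .status and .data instead of patching requests-style attributes onto it) is the same one used at the staging and telemetry call sites in this series. A hypothetical helper capturing the pattern, for illustration only:

    import json

    def json_or_raise(response):
        # urllib3.HTTPResponse exposes .status and .data rather than the
        # .status_code / .text / .json() surface of requests.Response.
        if response.status >= 400:
            body = response.data.decode() if response.data else ""
            raise RuntimeError(f"HTTP {response.status}: {body}")
        return json.loads(response.data.decode()) if response.data else {}
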
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 333782fd8..d574ed27e 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -306,6 +306,7 @@ def credential_provider(self): hostname="hostname", azure_client_id="client_id", azure_client_secret="client_secret", + http_client=MagicMock(), azure_tenant_id="tenant_id", ) diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 40096bf08..897a1d111 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -34,7 +34,7 @@ def test_sleep__no_retry_after(self, t_mock, retry_policy, error_history): retry_policy.history = [error_history, error_history] retry_policy.sleep(HTTPResponse(status=503)) - expected_backoff_time = min( + expected_backoff_time = max( self.calculate_backoff_time( 0, retry_policy.delay_min, retry_policy.delay_max ), @@ -57,7 +57,7 @@ def test_sleep__no_retry_after_header__multiple_retries(self, t_mock, retry_poli expected_backoff_times = [] for attempt in range(num_attempts): expected_backoff_times.append( - min( + max( self.calculate_backoff_time( attempt, retry_policy.delay_min, retry_policy.delay_max ), diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 989b2351c..0e828497f 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -19,12 +19,17 @@ @pytest.fixture -def mock_telemetry_client(): +@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") +def mock_telemetry_client(mock_singleton_class): """Create a mock telemetry client for testing.""" session_id = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") executor = MagicMock() - mock_http_client = MagicMock() + mock_client_context = MagicMock() + + # Mock the singleton to return a mock HTTP client + mock_singleton = mock_singleton_class.return_value + mock_singleton.get_http_client.return_value = MagicMock() return TelemetryClient( telemetry_enabled=True, @@ -33,7 +38,7 @@ def mock_telemetry_client(): host_url="test-host.com", executor=executor, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=mock_http_client, + client_context=mock_client_context, ) @@ -212,11 +217,16 @@ def telemetry_system_reset(self): TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False - def test_client_lifecycle_flow(self): + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") + def test_client_lifecycle_flow(self, mock_singleton_class): """Test complete client lifecycle: initialize -> use -> close.""" session_id_hex = "test-session" auth_provider = AccessTokenAuthProvider("token") - mock_http_client = MagicMock() + mock_client_context = MagicMock() + + # Mock the singleton to return a mock HTTP client + mock_singleton = mock_singleton_class.return_value + mock_singleton.get_http_client.return_value = MagicMock() # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( @@ -225,7 +235,7 @@ def test_client_lifecycle_flow(self): auth_provider=auth_provider, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=mock_http_client, + client_context=mock_client_context, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -241,10 +251,15 @@ def test_client_lifecycle_flow(self): client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) - def test_disabled_telemetry_flow(self): + 
@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") + def test_disabled_telemetry_flow(self, mock_singleton_class): """Test that disabled telemetry uses NoopTelemetryClient.""" session_id_hex = "test-session" - mock_http_client = MagicMock() + mock_client_context = MagicMock() + + # Mock the singleton to return a mock HTTP client + mock_singleton = mock_singleton_class.return_value + mock_singleton.get_http_client.return_value = MagicMock() TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, @@ -252,16 +267,21 @@ def test_disabled_telemetry_flow(self): auth_provider=None, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=mock_http_client, + client_context=mock_client_context, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) - def test_factory_error_handling(self): + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") + def test_factory_error_handling(self, mock_singleton_class): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" - mock_http_client = MagicMock() + mock_client_context = MagicMock() + + # Mock the singleton to return a mock HTTP client + mock_singleton = mock_singleton_class.return_value + mock_singleton.get_http_client.return_value = MagicMock() # Simulate initialization error with patch( @@ -274,18 +294,23 @@ def test_factory_error_handling(self): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=mock_http_client, + client_context=mock_client_context, ) # Should fall back to NoopTelemetryClient client = TelemetryClientFactory.get_telemetry_client(session_id) assert isinstance(client, NoopTelemetryClient) - def test_factory_shutdown_flow(self): + @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") + def test_factory_shutdown_flow(self, mock_singleton_class): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" - mock_http_client = MagicMock() + mock_client_context = MagicMock() + + # Mock the singleton to return a mock HTTP client + mock_singleton = mock_singleton_class.return_value + mock_singleton.get_http_client.return_value = MagicMock() # Initialize multiple clients for session in [session1, session2]: @@ -295,7 +320,7 @@ def test_factory_shutdown_flow(self): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - http_client=mock_http_client, + client_context=mock_client_context, ) # Factory should be initialized diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 0cdb43f5c..396e0e3f1 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -83,6 +83,7 @@ def test_make_request_checks_thrift_status_code(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -102,6 +103,7 @@ def _make_fake_thrift_backend(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() @@ -196,6 +198,7 @@ def test_headers_are_set(self, t_http_client_class): [("header", "value")], 
auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) t_http_client_class.return_value.setCustomHeaders.assert_called_with( {"header": "value"} @@ -243,6 +246,7 @@ def test_tls_cert_args_are_propagated( [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options, + http_client=MagicMock(), ) mock_ssl_context.load_cert_chain.assert_called_once_with( @@ -329,6 +333,7 @@ def test_tls_no_verify_is_respected( [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options, + http_client=MagicMock(), ) self.assertFalse(mock_ssl_context.check_hostname) @@ -353,6 +358,7 @@ def test_tls_verify_hostname_is_respected( [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options, + http_client=MagicMock(), ) self.assertFalse(mock_ssl_context.check_hostname) @@ -370,6 +376,7 @@ def test_port_and_host_are_respected(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], @@ -385,6 +392,7 @@ def test_host_with_https_does_not_duplicate(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], @@ -400,6 +408,7 @@ def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_cla [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], @@ -415,6 +424,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _socket_timeout=129, ) self.assertEqual( @@ -427,6 +437,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _socket_timeout=0, ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) @@ -437,6 +448,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) self.assertEqual( t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000 @@ -448,6 +460,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _socket_timeout=None, ) self.assertEqual( @@ -559,6 +572,7 @@ def test_make_request_checks_status_code(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) for code in error_codes: @@ -604,6 +618,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -647,6 +662,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) execute_response, _ = thrift_backend._handle_execute_response( @@ -691,6 +707,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -729,6 +746,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + 
http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command( @@ -772,6 +790,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command( @@ -840,6 +859,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -892,6 +912,7 @@ def test_handle_execute_response_can_handle_without_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) ( execute_response, @@ -930,6 +951,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._results_message_to_execute_response = Mock() @@ -1154,6 +1176,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) arrow_queue, has_more_results, _ = thrift_backend.fetch_results( command_id=Mock(), @@ -1183,6 +1206,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._handle_execute_response = Mock() thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) @@ -1219,6 +1243,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._handle_execute_response = Mock() thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) @@ -1252,6 +1277,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._handle_execute_response = Mock() thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) @@ -1294,6 +1320,7 @@ def test_get_tables_calls_client_and_handle_execute_response( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._handle_execute_response = Mock() thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) @@ -1340,6 +1367,7 @@ def test_get_columns_calls_client_and_handle_execute_response( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend._handle_execute_response = Mock() thrift_backend._handle_execute_response.return_value = (Mock(), Mock()) @@ -1383,6 +1411,7 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) @@ -1397,6 +1426,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) command_id = CommandId.from_thrift_handle(self.operation_handle) thrift_backend.close_command(command_id) @@ -1415,6 +1445,7 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), 
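
For readers skimming the long run of test diffs above and below: they all apply the same one-line change, namely that every Thrift backend built in a unit test now receives an explicitly injected HTTP client. A minimal sketch of the resulting construction pattern follows; the class name, its import path, and the leading positional arguments (host, port, http_path, headers) are assumptions inferred from the surrounding test calls, not text taken from this patch.

from unittest.mock import MagicMock

from databricks.sql.auth.authenticators import AuthProvider          # assumed import path
from databricks.sql.types import SSLOptions
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient  # assumed class name

backend = ThriftDatabricksClient(
    "fake-host.example.com",       # server hostname (placeholder)
    443,                           # port
    "/sql/1.0/endpoints/abc",      # http_path (placeholder)
    [],                            # http_headers
    auth_provider=AuthProvider(),
    ssl_options=SSLOptions(),
    http_client=MagicMock(),       # the HTTP client these tests now inject explicitly
)
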
+ http_client=MagicMock(), ) session_id = SessionId.from_thrift_handle(self.session_handle) thrift_backend.close_session(session_id) @@ -1470,6 +1501,7 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ -1490,6 +1522,7 @@ def test_create_arrow_table_calls_correct_conversion_method( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -1525,6 +1558,7 @@ def test_convert_arrow_based_set_to_arrow_table( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) lz4_decompress_mock.return_value = bytearray("Testing", "utf-8") @@ -1745,6 +1779,7 @@ def test_make_request_will_retry_GetOperationStatus( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, _retry_delay_default=1, ) @@ -1823,6 +1858,7 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, _retry_delay_default=1, ) @@ -1855,6 +1891,7 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(OperationalError) as cm: @@ -1884,6 +1921,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), _retry_stop_after_attempts_count=14, _retry_delay_max=0, _retry_delay_min=0, @@ -1913,6 +1951,7 @@ def test_make_request_will_read_error_message_headers_if_set( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) error_headers = [ @@ -2037,6 +2076,7 @@ def test_retry_args_passthrough(self, mock_http_client): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), **retry_delay_args, ) for arg, val in retry_delay_args.items(): @@ -2068,6 +2108,7 @@ def test_retry_args_bounding(self, mock_http_client): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), **retry_delay_args, ) retry_delay_expected_vals = { @@ -2096,6 +2137,7 @@ def test_configuration_passthrough(self, tcli_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) backend.open_session(mock_config, None, None) @@ -2114,6 +2156,7 @@ def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(databricks.sql.Error) as cm: @@ -2141,6 +2184,7 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) initial_cat_schem_args = [("cat", None), (None, "schem"), ("cat", "schem")] @@ -2172,6 +2216,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) backend.open_session({}, None, None) @@ -2191,6 +2236,7 @@ def 
test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) # If the initial catalog is set, but server returns canUseMultipleCatalogs=False, we # expect failure. If the initial catalog isn't set, then canUseMultipleCatalogs=False @@ -2237,6 +2283,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), ) with self.assertRaises(InvalidServerResponseError) as cm: @@ -2283,6 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), + http_client=MagicMock(), **complex_arg_types, ) thrift_backend.execute_command( From 1dd40a10b0227e837e5adda96816f2d54c22d58d Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 12 Aug 2025 11:38:28 +0530 Subject: [PATCH 14/35] review comments Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/auth.py | 2 +- src/databricks/sql/auth/common.py | 36 +++++----- src/databricks/sql/auth/oauth.py | 4 +- src/databricks/sql/client.py | 70 +++++-------------- src/databricks/sql/cloudfetch/downloader.py | 22 +++--- src/databricks/sql/common/feature_flag.py | 4 +- .../sql/common/unified_http_client.py | 24 +++++-- src/databricks/sql/session.py | 4 -- .../sql/telemetry/telemetry_client.py | 3 +- src/databricks/sql/utils.py | 59 ++++++++++++++-- 10 files changed, 127 insertions(+), 101 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 59da2a422..a8accac06 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -56,7 +56,7 @@ def get_auth_provider(cfg: ClientContext, http_client): cfg.oauth_client_id, cfg.oauth_scopes, http_client, - cfg.auth_type or "databricks-oauth", + cfg.auth_type or AuthType.DATABRICKS_OAUTH.value, ) else: raise RuntimeError("No valid authentication settings!") diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index e80fac189..4ae7afb0b 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -2,6 +2,8 @@ import logging from typing import Optional, List from urllib.parse import urlparse +from databricks.sql.auth.retry import DatabricksRetryPolicy +from databricks.sql.common.http import HttpMethod logger = logging.getLogger(__name__) @@ -38,17 +40,17 @@ def __init__( # HTTP client configuration parameters ssl_options=None, # SSLOptions type socket_timeout: Optional[float] = None, - retry_stop_after_attempts_count: Optional[int] = None, - retry_delay_min: Optional[float] = None, - retry_delay_max: Optional[float] = None, - retry_stop_after_attempts_duration: Optional[float] = None, - retry_delay_default: Optional[float] = None, + retry_stop_after_attempts_count: int = 5, + retry_delay_min: float = 1.0, + retry_delay_max: float = 60.0, + retry_stop_after_attempts_duration: float = 900.0, + retry_delay_default: float = 5.0, retry_dangerous_codes: Optional[List[int]] = None, http_proxy: Optional[str] = None, proxy_username: Optional[str] = None, proxy_password: Optional[str] = None, - pool_connections: Optional[int] = None, - pool_maxsize: Optional[int] = None, + pool_connections: int = 10, + pool_maxsize: int = 20, user_agent: Optional[str] = None, ): self.hostname = hostname @@ -69,19 +71,17 @@ def __init__( # HTTP client configuration self.ssl_options = ssl_options self.socket_timeout = socket_timeout - 
self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5 - self.retry_delay_min = retry_delay_min or 1.0 - self.retry_delay_max = retry_delay_max or 60.0 - self.retry_stop_after_attempts_duration = ( - retry_stop_after_attempts_duration or 900.0 - ) - self.retry_delay_default = retry_delay_default or 5.0 + self.retry_stop_after_attempts_count = retry_stop_after_attempts_count + self.retry_delay_min = retry_delay_min + self.retry_delay_max = retry_delay_max + self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration + self.retry_delay_default = retry_delay_default self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy self.proxy_username = proxy_username self.proxy_password = proxy_password - self.pool_connections = pool_connections or 10 - self.pool_maxsize = pool_maxsize or 20 + self.pool_connections = pool_connections + self.pool_maxsize = pool_maxsize self.user_agent = user_agent @@ -113,7 +113,9 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str: login_url = f"{host}/aad/auth" logger.debug("Loading tenant ID from %s", login_url) - with http_client.request_context("GET", login_url, allow_redirects=False) as resp: + with http_client.request_context( + HttpMethod.GET, login_url, allow_redirects=False + ) as resp: if resp.status // 100 != 3: raise ValueError( f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}" diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 9fdf3955a..09753c9ff 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -87,7 +87,7 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: - response = self.http_client.request("GET", url=known_config_url) + response = self.http_client.request(HttpMethod.GET, url=known_config_url) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status response.json = lambda: json.loads(response.data.decode()) @@ -197,7 +197,7 @@ def __send_token_request(self, token_request_url, data): } # Use unified HTTP client response = self.http_client.request( - "POST", url=token_request_url, body=data, headers=headers + HttpMethod.POST, url=token_request_url, body=data, headers=headers ) # Convert urllib3 response to dict for compatibility return json.loads(response.data.decode()) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 634c7e261..3cd7bcacf 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -31,6 +31,7 @@ transform_paramstyle, ColumnTable, ColumnQueue, + build_client_context, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -52,6 +53,7 @@ from databricks.sql.auth.common import ClientContext from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, @@ -254,14 +256,14 @@ def read(self) -> Optional[OAuthToken]: "telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE ) - client_context = self._build_client_context(server_hostname, **kwargs) - http_client = UnifiedHttpClient(client_context) + client_context = build_client_context(server_hostname, __version__, **kwargs) + self.http_client = UnifiedHttpClient(client_context) try: self.session = Session( server_hostname, http_path, - http_client, + 
self.http_client, http_headers, session_configuration, catalog, @@ -350,50 +352,6 @@ def _set_use_inline_params_with_warning(self, value: Union[bool, str]): return value - def _build_client_context(self, server_hostname: str, **kwargs): - """Build ClientContext for HTTP client configuration.""" - from databricks.sql.auth.common import ClientContext - from databricks.sql.types import SSLOptions - - # Extract SSL options - ssl_options = SSLOptions( - tls_verify=not kwargs.get("_tls_no_verify", False), - tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), - tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), - tls_client_cert_file=kwargs.get("_tls_client_cert_file"), - tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), - tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), - ) - - # Build user agent - user_agent_entry = kwargs.get("user_agent_entry", "") - if user_agent_entry: - user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})" - else: - user_agent = f"PyDatabricksSqlConnector/{__version__}" - - return ClientContext( - hostname=server_hostname, - ssl_options=ssl_options, - socket_timeout=kwargs.get("_socket_timeout"), - retry_stop_after_attempts_count=kwargs.get( - "_retry_stop_after_attempts_count" - ), - retry_delay_min=kwargs.get("_retry_delay_min"), - retry_delay_max=kwargs.get("_retry_delay_max"), - retry_stop_after_attempts_duration=kwargs.get( - "_retry_stop_after_attempts_duration" - ), - retry_delay_default=kwargs.get("_retry_delay_default"), - retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), - http_proxy=kwargs.get("_http_proxy"), - proxy_username=kwargs.get("_proxy_username"), - proxy_password=kwargs.get("_proxy_password"), - pool_connections=kwargs.get("_pool_connections"), - pool_maxsize=kwargs.get("_pool_maxsize"), - user_agent=user_agent, - ) - # The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently. def __enter__(self) -> "Connection": return self @@ -447,7 +405,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - return hasattr(self, "session") and self.session.is_open + return self.session.is_open def cursor( self, @@ -497,6 +455,10 @@ def _close(self, close_cursors=True) -> None: TelemetryClientFactory.close(self.get_session_id_hex()) + # Close HTTP client that was created by this connection + if self.http_client: + self.http_client.close() + def commit(self): """No-op because Databricks does not support transactions""" pass @@ -796,8 +758,8 @@ def _handle_staging_put( ) with open(local_file, "rb") as fh: - r = self.connection.session.http_client.request( - "PUT", presigned_url, body=fh.read(), headers=headers + r = self.connection.http_client.request( + HttpMethod.PUT, presigned_url, body=fh.read(), headers=headers ) # fmt: off @@ -837,8 +799,8 @@ def _handle_staging_get( session_id_hex=self.connection.get_session_id_hex(), ) - r = self.connection.session.http_client.request( - "GET", presigned_url, headers=headers + r = self.connection.http_client.request( + HttpMethod.GET, presigned_url, headers=headers ) # response.ok verifies the status code is not between 400-600. 
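
As an orientation aid, here is a small end-to-end sketch of the wiring this patch converges on: connection kwargs are folded into a ClientContext by build_client_context, a single UnifiedHttpClient is created from it and shared across the session, staging, cloud fetch and telemetry paths, and it is closed when the connection closes. Hostnames and URLs below are placeholders; this is an illustrative sketch under those assumptions, not code from the patch.

from databricks.sql import __version__
from databricks.sql.utils import build_client_context
from databricks.sql.common.unified_http_client import UnifiedHttpClient
from databricks.sql.common.http import HttpMethod

# Fold the underscore-prefixed connection kwargs into a ClientContext.
client_context = build_client_context(
    "my-workspace.cloud.databricks.com",   # placeholder hostname
    __version__,
    _socket_timeout=30,
    _pool_maxsize=20,
)
http_client = UnifiedHttpClient(client_context)

# One-shot request: the body is pre-loaded, status and headers stay readable.
resp = http_client.request(HttpMethod.GET, "https://example.com/health")  # placeholder URL
print(resp.status, len(resp.data))

# Context-manager form with explicit cleanup, as the cloud fetch downloader uses.
with http_client.request_context(HttpMethod.GET, "https://example.com/health") as r:
    body_bytes = r.data

http_client.close()
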
@@ -860,8 +822,8 @@ def _handle_staging_remove( ): """Make an HTTP DELETE request to the presigned_url""" - r = self.connection.session.http_client.request( - "DELETE", presigned_url, headers=headers + r = self.connection.http_client.request( + HttpMethod.DELETE, presigned_url, headers=headers ) if r.status >= 400: diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index a2a7837f0..e6d1c6d10 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -10,6 +10,7 @@ from databricks.sql.types import SSLOptions from databricks.sql.telemetry.latency_logger import log_latency from databricks.sql.telemetry.models.event import StatementType +from databricks.sql.common.unified_http_client import UnifiedHttpClient logger = logging.getLogger(__name__) @@ -79,9 +80,10 @@ def run(self) -> DownloadedFile: """ logger.debug( - "ResultSetDownloadHandler: starting file download, chunk id {}, offset {}, row count {}".format( - self.chunk_id, self.link.startRowOffset, self.link.rowCount - ) + "ResultSetDownloadHandler: starting file download, chunk id %s, offset %s, row count %s", + self.chunk_id, + self.link.startRowOffset, + self.link.rowCount, ) # Check if link is already expired or is expiring @@ -92,7 +94,7 @@ def run(self) -> DownloadedFile: start_time = time.time() with self._http_client.request_context( - method="GET", + method=HttpMethod.GET, url=self.link.fileLink, timeout=self.settings.download_timeout, headers=self.link.httpHeaders, @@ -116,15 +118,15 @@ def run(self) -> DownloadedFile: # The size of the downloaded file should match the size specified from TSparkArrowResultLink if len(decompressed_data) != self.link.bytesNum: logger.debug( - "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( - len(decompressed_data), self.link.bytesNum - ) + "ResultSetDownloadHandler: downloaded file size %s does not match the expected value %s", + len(decompressed_data), + self.link.bytesNum, ) logger.debug( - "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( - self.link.startRowOffset, self.link.rowCount - ) + "ResultSetDownloadHandler: successfully downloaded file, offset %s, row count %s", + self.link.startRowOffset, + self.link.rowCount, ) return DownloadedFile( diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 2b7e27ab3..8a1cf5bd5 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -5,6 +5,8 @@ from concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional, List, Any, TYPE_CHECKING +from databricks.sql.common.http import HttpMethod + if TYPE_CHECKING: from databricks.sql.client import Connection @@ -111,7 +113,7 @@ def _refresh_flags(self): headers["User-Agent"] = self._connection.session.useragent_header response = self._http_client.request( - "GET", self._feature_flag_endpoint, headers=headers, timeout=30 + HttpMethod.GET, self._feature_flag_endpoint, headers=headers, timeout=30 ) if response.status == 200: diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index bb26ae9de..62cfb3001 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -2,7 +2,7 @@ import ssl import urllib.parse from contextlib import contextmanager -from typing import Dict, Any, Optional, Generator, 
Union +from typing import Dict, Any, Optional, Generator import urllib3 from urllib3 import PoolManager, ProxyManager @@ -11,6 +11,7 @@ from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType from databricks.sql.exc import RequestError +from databricks.sql.common.http import HttpMethod logger = logging.getLogger(__name__) @@ -135,13 +136,17 @@ def _prepare_retry_policy(self): @contextmanager def request_context( - self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, ) -> Generator[urllib3.HTTPResponse, None, None]: """ Context manager for making HTTP requests with proper resource cleanup. Args: - method: HTTP method (GET, POST, PUT, DELETE) + method: HTTP method (HttpMethod.GET, HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE) url: URL to request headers: Optional headers dict **kwargs: Additional arguments passed to urllib3 request @@ -160,7 +165,7 @@ def request_context( try: response = self._pool_manager.request( - method=method, url=url, headers=request_headers, **kwargs + method=method.value, url=url, headers=request_headers, **kwargs ) yield response except MaxRetryError as e: @@ -174,22 +179,27 @@ def request_context( response.close() def request( - self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, ) -> urllib3.HTTPResponse: """ Make an HTTP request. Args: - method: HTTP method (GET, POST, PUT, DELETE, etc.) + method: HTTP method (HttpMethod.GET, HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE, etc.) url: URL to request headers: Optional headers dict **kwargs: Additional arguments passed to urllib3 request Returns: - urllib3.HTTPResponse: The HTTP response object with data pre-loaded + urllib3.HTTPResponse: The HTTP response object with data and metadata pre-loaded """ with self.request_context(method, url, headers=headers, **kwargs) as response: # Read the response data to ensure it's available after context exit + # Note: status and headers remain accessible after close(), only data needs caching response._body = response.data return response diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 0cba8be48..d8ba5d125 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -193,7 +193,3 @@ def close(self) -> None: logger.error("Attempt to close session raised a local exception: %s", e) self.is_open = False - - # Close HTTP client if it exists - if hasattr(self, "http_client") and self.http_client: - self.http_client.close() diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 9887b67a7..29935dc3a 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -40,6 +40,7 @@ from databricks.sql.telemetry.utils import BaseTelemetryClient from databricks.sql.common.feature_flag import FeatureFlagsContextFactory from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod if TYPE_CHECKING: from databricks.sql.client import Connection @@ -293,7 +294,7 @@ def _send_with_unified_client(self, url, data, headers, timeout=900): """Helper method to send telemetry using the unified HTTP client.""" try: response = self._http_client.request( - "POST", url, body=data, headers=headers, 
timeout=timeout + HttpMethod.POST, url, body=data, headers=headers, timeout=timeout ) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index ff48e0e91..7b9746df9 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -105,16 +105,16 @@ def build_queue( elif row_set_type == TSparkRowSetType.URL_BASED_SET: return ThriftCloudFetchQueue( schema_bytes=arrow_schema_bytes, + start_row_offset=t_row_set.startRowOffset, + result_links=t_row_set.resultLinks, + lz4_compressed=lz4_compressed, + description=description, max_download_threads=max_download_threads, ssl_options=ssl_options, session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, http_client=http_client, - start_row_offset=t_row_set.startRowOffset, - result_links=t_row_set.resultLinks, - lz4_compressed=lz4_compressed, - description=description, ) else: raise AssertionError("Row set type is not valid") @@ -882,3 +882,54 @@ def concat_table_chunks( return ColumnTable(result_table, table_chunks[0].column_names) else: return pyarrow.concat_tables(table_chunks, use_threads=True) + + +def build_client_context(server_hostname: str, version: str, **kwargs): + """Build ClientContext for HTTP client configuration.""" + from databricks.sql.auth.common import ClientContext + from databricks.sql.types import SSLOptions + + # Extract SSL options + ssl_options = SSLOptions( + tls_verify=not kwargs.get("_tls_no_verify", False), + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + # Build user agent + user_agent_entry = kwargs.get("user_agent_entry", "") + if user_agent_entry: + user_agent = f"PyDatabricksSqlConnector/{version} ({user_agent_entry})" + else: + user_agent = f"PyDatabricksSqlConnector/{version}" + + # Build ClientContext kwargs, excluding None values to use defaults + context_kwargs = { + "hostname": server_hostname, + "ssl_options": ssl_options, + "user_agent": user_agent, + } + + # Only add non-None values to let defaults work + for param, kwarg_key in [ + ("socket_timeout", "_socket_timeout"), + ("retry_stop_after_attempts_count", "_retry_stop_after_attempts_count"), + ("retry_delay_min", "_retry_delay_min"), + ("retry_delay_max", "_retry_delay_max"), + ("retry_stop_after_attempts_duration", "_retry_stop_after_attempts_duration"), + ("retry_delay_default", "_retry_delay_default"), + ("retry_dangerous_codes", "_retry_dangerous_codes"), + ("http_proxy", "_http_proxy"), + ("proxy_username", "_proxy_username"), + ("proxy_password", "_proxy_password"), + ("pool_connections", "_pool_connections"), + ("pool_maxsize", "_pool_maxsize"), + ]: + value = kwargs.get(kwarg_key) + if value is not None: + context_kwargs[param] = value + + return ClientContext(**context_kwargs) From 3847acac62be645cba8294c513399f75eb9bc1b2 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 12 Aug 2025 12:23:53 +0530 Subject: [PATCH 15/35] fix warnings Signed-off-by: Vikrant Puppala --- tests/unit/test_telemetry.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 0e828497f..ab07b400c 100644 --- a/tests/unit/test_telemetry.py +++ 
b/tests/unit/test_telemetry.py @@ -349,7 +349,11 @@ def test_connection_failure_sends_correct_telemetry_payload( """ error_message = "Could not connect to host" - mock_session.side_effect = Exception(error_message) + # Set up the mock to create a session instance first, then make open() fail + mock_session_instance = MagicMock() + mock_session_instance.is_open = False # Ensure cleanup is safe + mock_session_instance.open.side_effect = Exception(error_message) + mock_session.return_value = mock_session_instance try: sql.connect(server_hostname="test-host", http_path="/test-path") @@ -391,6 +395,7 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False # Connection starts closed for test cleanup # Set up mock HTTP client on the session mock_http_client = MagicMock() @@ -418,6 +423,7 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False # Connection starts closed for test cleanup # Set up mock HTTP client on the session mock_http_client = MagicMock() @@ -445,6 +451,7 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False # Connection starts closed for test cleanup # Set up mock HTTP client on the session mock_http_client = MagicMock() From d9a4797bd54632375eb3487233926c208e2abfbc Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 12 Aug 2025 12:41:42 +0530 Subject: [PATCH 16/35] fix check-types Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/common.py | 30 ++++++++++--------- src/databricks/sql/utils.py | 50 +++++++++++++------------------ 2 files changed, 37 insertions(+), 43 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 4ae7afb0b..36a0f3707 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -40,17 +40,17 @@ def __init__( # HTTP client configuration parameters ssl_options=None, # SSLOptions type socket_timeout: Optional[float] = None, - retry_stop_after_attempts_count: int = 5, - retry_delay_min: float = 1.0, - retry_delay_max: float = 60.0, - retry_stop_after_attempts_duration: float = 900.0, - retry_delay_default: float = 5.0, + retry_stop_after_attempts_count: Optional[int] = None, + retry_delay_min: Optional[float] = None, + retry_delay_max: Optional[float] = None, + retry_stop_after_attempts_duration: Optional[float] = None, + retry_delay_default: Optional[float] = None, retry_dangerous_codes: Optional[List[int]] = None, http_proxy: Optional[str] = None, proxy_username: Optional[str] = None, proxy_password: Optional[str] = None, - pool_connections: int = 10, - pool_maxsize: int = 20, + pool_connections: Optional[int] = None, + pool_maxsize: Optional[int] = None, user_agent: Optional[str] = None, ): self.hostname = hostname @@ -71,17 +71,19 @@ def __init__( # HTTP client configuration self.ssl_options = ssl_options self.socket_timeout = socket_timeout - self.retry_stop_after_attempts_count = 
retry_stop_after_attempts_count - self.retry_delay_min = retry_delay_min - self.retry_delay_max = retry_delay_max - self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration - self.retry_delay_default = retry_delay_default + self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 5 + self.retry_delay_min = retry_delay_min or 1.0 + self.retry_delay_max = retry_delay_max or 10.0 + self.retry_stop_after_attempts_duration = ( + retry_stop_after_attempts_duration or 300.0 + ) + self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy self.proxy_username = proxy_username self.proxy_password = proxy_password - self.pool_connections = pool_connections - self.pool_maxsize = pool_maxsize + self.pool_connections = pool_connections or 10 + self.pool_maxsize = pool_maxsize or 20 self.user_agent = user_agent diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 7b9746df9..ce2ba5eaf 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Sequence from dateutil import parser import datetime @@ -9,7 +9,6 @@ from collections.abc import Mapping from decimal import Decimal from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union, Sequence import re import lz4.frame @@ -906,30 +905,23 @@ def build_client_context(server_hostname: str, version: str, **kwargs): else: user_agent = f"PyDatabricksSqlConnector/{version}" - # Build ClientContext kwargs, excluding None values to use defaults - context_kwargs = { - "hostname": server_hostname, - "ssl_options": ssl_options, - "user_agent": user_agent, - } - - # Only add non-None values to let defaults work - for param, kwarg_key in [ - ("socket_timeout", "_socket_timeout"), - ("retry_stop_after_attempts_count", "_retry_stop_after_attempts_count"), - ("retry_delay_min", "_retry_delay_min"), - ("retry_delay_max", "_retry_delay_max"), - ("retry_stop_after_attempts_duration", "_retry_stop_after_attempts_duration"), - ("retry_delay_default", "_retry_delay_default"), - ("retry_dangerous_codes", "_retry_dangerous_codes"), - ("http_proxy", "_http_proxy"), - ("proxy_username", "_proxy_username"), - ("proxy_password", "_proxy_password"), - ("pool_connections", "_pool_connections"), - ("pool_maxsize", "_pool_maxsize"), - ]: - value = kwargs.get(kwarg_key) - if value is not None: - context_kwargs[param] = value - - return ClientContext(**context_kwargs) + # Explicitly construct ClientContext with proper types + return ClientContext( + hostname=server_hostname, + ssl_options=ssl_options, + user_agent=user_agent, + socket_timeout=kwargs.get("_socket_timeout"), + retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), + retry_delay_min=kwargs.get("_retry_delay_min"), + retry_delay_max=kwargs.get("_retry_delay_max"), + retry_stop_after_attempts_duration=kwargs.get( + "_retry_stop_after_attempts_duration" + ), + retry_delay_default=kwargs.get("_retry_delay_default"), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), + http_proxy=kwargs.get("_http_proxy"), + proxy_username=kwargs.get("_proxy_username"), + proxy_password=kwargs.get("_proxy_password"), + pool_connections=kwargs.get("_pool_connections"), + pool_maxsize=kwargs.get("_pool_maxsize"), + ) From ba2a3a9827478dad7e983c8d0ac4c9e49d4bcb21 Mon Sep 17 00:00:00 
2001 From: Vikrant Puppala Date: Tue, 12 Aug 2025 22:38:22 +0530 Subject: [PATCH 17/35] remove separate http client for telemetry Signed-off-by: Vikrant Puppala --- src/databricks/sql/client.py | 4 +- .../sql/telemetry/telemetry_client.py | 55 +++-------------- tests/unit/test_telemetry.py | 59 +++++-------------- 3 files changed, 26 insertions(+), 92 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3cd7bcacf..ecdf66401 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -279,7 +279,7 @@ def read(self) -> Optional[OAuthToken]: host_url=server_hostname, http_path=http_path, port=kwargs.get("_port", 443), - client_context=client_context, + http_client=self.http_client, user_agent=self.session.useragent_header if hasattr(self, "session") else None, @@ -301,7 +301,7 @@ def read(self) -> Optional[OAuthToken]: auth_provider=self.session.auth_provider, host_url=self.session.host, batch_size=self.telemetry_batch_size, - client_context=client_context, + http_client=self.http_client, ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 29935dc3a..f933885cf 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -154,44 +154,6 @@ def _flush(self): pass -class TelemetryHttpClientSingleton: - """ - Singleton HTTP client for telemetry operations. - - This ensures that telemetry has its own dedicated HTTP client that - is independent of individual connection lifecycles. - """ - - _instance = None - _lock = threading.RLock() - - def __new__(cls): - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._http_client = None - cls._instance._initialized = False - return cls._instance - - def get_http_client(self, client_context): - """Get or create the singleton HTTP client.""" - if not self._initialized and client_context: - with self._lock: - if not self._initialized: - self._http_client = UnifiedHttpClient(client_context) - self._initialized = True - return self._http_client - - def close(self): - """Close the singleton HTTP client.""" - with self._lock: - if self._http_client: - self._http_client.close() - self._http_client = None - self._initialized = False - - class TelemetryClient(BaseTelemetryClient): """ Telemetry client class that handles sending telemetry events in batches to the server. 
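
With the telemetry-specific singleton removed, the telemetry path simply borrows the connection's shared HTTP client through the factory. The sketch below mirrors the factory calls visible in this patch; the session id, token, and host are placeholders, the MagicMock stands in for the real UnifiedHttpClient, and the import paths are assumptions.

from unittest.mock import MagicMock
from databricks.sql.auth.authenticators import AccessTokenAuthProvider   # assumed import path
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory

session_id_hex = "session-123"      # placeholder
http_client = MagicMock()           # stands in for the connection's UnifiedHttpClient

TelemetryClientFactory.initialize_telemetry_client(
    telemetry_enabled=True,
    session_id_hex=session_id_hex,
    auth_provider=AccessTokenAuthProvider("token"),   # placeholder token
    host_url="test-host.com",                         # placeholder host
    batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
    http_client=http_client,
)

client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
# ... events are batched and flushed through the shared http_client ...

# Closing the last registered session also shuts the factory back down.
TelemetryClientFactory.close(session_id_hex)
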
@@ -210,7 +172,7 @@ def __init__( host_url, executor, batch_size, - client_context, + http_client, ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled @@ -224,9 +186,8 @@ def __init__( self._host_url = host_url self._executor = executor - # Use singleton HTTP client for telemetry instead of connection-specific client - self._http_client_singleton = TelemetryHttpClientSingleton() - self._http_client = self._http_client_singleton.get_http_client(client_context) + # Use the provided HTTP client directly + self._http_client = http_client def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" @@ -503,7 +464,7 @@ def initialize_telemetry_client( auth_provider, host_url, batch_size, - client_context, + http_client, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: @@ -526,7 +487,7 @@ def initialize_telemetry_client( host_url=host_url, executor=TelemetryClientFactory._executor, batch_size=batch_size, - client_context=client_context, + http_client=http_client, ) else: TelemetryClientFactory._clients[ @@ -579,10 +540,10 @@ def connection_failure_log( host_url: str, http_path: str, port: int, - client_context, + http_client, user_agent: Optional[str] = None, ): - """Send error telemetry when connection creation fails, without requiring a session""" + """Send error telemetry when connection creation fails, using existing HTTP client""" UNAUTH_DUMMY_SESSION_ID = "unauth_session_id" @@ -592,7 +553,7 @@ def connection_failure_log( auth_provider=None, host_url=host_url, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=client_context, + http_client=http_client, ) telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index ab07b400c..ee0590ff8 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -19,17 +19,12 @@ @pytest.fixture -@patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") -def mock_telemetry_client(mock_singleton_class): +def mock_telemetry_client(): """Create a mock telemetry client for testing.""" session_id = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") executor = MagicMock() - mock_client_context = MagicMock() - - # Mock the singleton to return a mock HTTP client - mock_singleton = mock_singleton_class.return_value - mock_singleton.get_http_client.return_value = MagicMock() + mock_http_client = MagicMock() return TelemetryClient( telemetry_enabled=True, @@ -38,7 +33,7 @@ def mock_telemetry_client(mock_singleton_class): host_url="test-host.com", executor=executor, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=mock_client_context, + http_client=mock_http_client, ) @@ -217,16 +212,11 @@ def telemetry_system_reset(self): TelemetryClientFactory._executor = None TelemetryClientFactory._initialized = False - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") - def test_client_lifecycle_flow(self, mock_singleton_class): + def test_client_lifecycle_flow(self): """Test complete client lifecycle: initialize -> use -> close.""" session_id_hex = "test-session" auth_provider = AccessTokenAuthProvider("token") - mock_client_context = MagicMock() - - # Mock the singleton to return a mock HTTP client - mock_singleton = mock_singleton_class.return_value - mock_singleton.get_http_client.return_value = MagicMock() + 
mock_http_client = MagicMock() # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( @@ -235,7 +225,7 @@ def test_client_lifecycle_flow(self, mock_singleton_class): auth_provider=auth_provider, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=mock_client_context, + http_client=mock_http_client, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -248,18 +238,11 @@ def test_client_lifecycle_flow(self, mock_singleton_class): mock_close.assert_called_once() # Should get NoopTelemetryClient after close - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) - assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") - def test_disabled_telemetry_flow(self, mock_singleton_class): - """Test that disabled telemetry uses NoopTelemetryClient.""" + def test_disabled_telemetry_creates_noop_client(self): + """Test that disabled telemetry creates NoopTelemetryClient.""" session_id_hex = "test-session" - mock_client_context = MagicMock() - - # Mock the singleton to return a mock HTTP client - mock_singleton = mock_singleton_class.return_value - mock_singleton.get_http_client.return_value = MagicMock() + mock_http_client = MagicMock() TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, @@ -267,21 +250,16 @@ def test_disabled_telemetry_flow(self, mock_singleton_class): auth_provider=None, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=mock_client_context, + http_client=mock_http_client, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") - def test_factory_error_handling(self, mock_singleton_class): + def test_factory_error_handling(self): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" - mock_client_context = MagicMock() - - # Mock the singleton to return a mock HTTP client - mock_singleton = mock_singleton_class.return_value - mock_singleton.get_http_client.return_value = MagicMock() + mock_http_client = MagicMock() # Simulate initialization error with patch( @@ -294,23 +272,18 @@ def test_factory_error_handling(self, mock_singleton_class): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=mock_client_context, + http_client=mock_http_client, ) # Should fall back to NoopTelemetryClient client = TelemetryClientFactory.get_telemetry_client(session_id) assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.telemetry.telemetry_client.TelemetryHttpClientSingleton") - def test_factory_shutdown_flow(self, mock_singleton_class): + def test_factory_shutdown_flow(self): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" - mock_client_context = MagicMock() - - # Mock the singleton to return a mock HTTP client - mock_singleton = mock_singleton_class.return_value - mock_singleton.get_http_client.return_value = MagicMock() + mock_http_client = MagicMock() # Initialize multiple clients for session in [session1, session2]: @@ -320,7 +293,7 @@ def test_factory_shutdown_flow(self, mock_singleton_class): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", 
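
One consequence of routing telemetry through the unified client is that callers now handle urllib3 responses directly: success is an integer status check and the JSON body has to be decoded by hand, as the callback cleanup in the next patch also does. A minimal helper illustrating that handling (the function name is ours, not the patch's):

import json

def parse_telemetry_response(response):
    """Interpret a urllib3 HTTPResponse returned by the unified HTTP client."""
    # urllib3 exposes an integer `status`, not requests' `ok`/`status_code`.
    if not (200 <= response.status < 300):
        raise RuntimeError(f"telemetry push failed with HTTP {response.status}")
    # There is no .json() helper; decode and parse the raw body explicitly.
    return json.loads(response.data.decode()) if response.data else {}
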
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - client_context=mock_client_context, + http_client=mock_http_client, ) # Factory should be initialized From d1f045ebe252883997bd1f67643da6da0cff5ada Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Tue, 12 Aug 2025 22:55:27 +0530 Subject: [PATCH 18/35] more clean up Signed-off-by: Vikrant Puppala --- src/databricks/sql/result_set.py | 2 +- .../sql/telemetry/telemetry_client.py | 24 +++++++------------ 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 77673db9a..6c4c3a43a 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -244,7 +244,7 @@ def __init__( session_id_hex=connection.get_session_id_hex(), statement_id=execute_response.command_id.to_hex_guid(), chunk_id=self.num_chunks, - http_client=connection.session.http_client, + http_client=connection.http_client, ) if t_row_set.resultLinks: self.num_chunks += len(t_row_set.resultLinks) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index f933885cf..f6ad4433d 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -257,18 +257,6 @@ def _send_with_unified_client(self, url, data, headers, timeout=900): response = self._http_client.request( HttpMethod.POST, url, body=data, headers=headers, timeout=timeout ) - # Convert urllib3 response to requests-like response for compatibility - response.status_code = response.status - response.ok = 200 <= response.status < 300 - response.json = ( - lambda: json.loads(response.data.decode()) if response.data else {} - ) - # Add raise_for_status method - def raise_for_status(): - if not response.ok: - raise Exception(f"HTTP {response.status_code}") - - response.raise_for_status = raise_for_status return response except Exception as e: logger.error("Failed to send telemetry with unified client: %s", e) @@ -279,14 +267,18 @@ def _telemetry_request_callback(self, future, sent_count: int): try: response = future.result() - if not response.ok: + # Check if response is successful (urllib3 uses response.status) + is_success = 200 <= response.status < 300 + if not is_success: logger.debug( "Telemetry request failed with status code: %s, response: %s", - response.status_code, - response.text, + response.status, + response.data.decode() if response.data else "", ) - telemetry_response = TelemetryResponse(**response.json()) + # Parse JSON response (urllib3 uses response.data) + response_data = json.loads(response.data.decode()) if response.data else {} + telemetry_response = TelemetryResponse(**response_data) logger.debug( "Pushed Telemetry logs with success count: %s, error count: %s", From 4e6623009be8c21c46088c4da420b1078dc80c59 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Wed, 13 Aug 2025 13:03:07 +0530 Subject: [PATCH 19/35] remove excess release_connection call --- src/databricks/sql/backend/thrift_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 0df35bab9..1654a1d5a 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -619,8 +619,7 @@ def close_session(self, session_id: SessionId) -> None: try: self.make_request(self._client.CloseSession, req) finally: - self._transport.release_connection() - self._transport.close() + 
self._transport.close() def _check_command_not_in_error_or_closed_state( self, op_handle, get_operations_resp From 67020f1afb82dd15527a94ce90c97e3b5c6b6e2c Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Wed, 13 Aug 2025 21:59:58 +0530 Subject: [PATCH 20/35] formatting (black) - fix some closures --- src/databricks/sql/backend/thrift_backend.py | 4 +-- src/databricks/sql/client.py | 2 ++ .../sql/telemetry/telemetry_client.py | 2 +- tests/e2e/common/staging_ingestion_tests.py | 4 +-- tests/e2e/common/uc_volume_tests.py | 4 +-- tests/e2e/test_concurrent_telemetry.py | 8 +++-- tests/e2e/test_driver.py | 6 ++-- tests/unit/test_auth.py | 32 +++++++++++------ tests/unit/test_cloud_fetch_queue.py | 36 ++++++++++++------- tests/unit/test_downloader.py | 13 ++++--- tests/unit/test_sea_queue.py | 2 +- tests/unit/test_telemetry.py | 32 +++++++++++------ tests/unit/test_thrift_backend.py | 16 ++++----- 13 files changed, 101 insertions(+), 60 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 1654a1d5a..b089eacd5 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -232,7 +232,7 @@ def __init__( try: self._transport.open() except: - self._transport.release_connection() + self._transport.close() raise self._request_lock = threading.RLock() @@ -607,7 +607,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId: self._session_id_hex = session_id.hex_guid return session_id except: - self._transport.release_connection() + self._transport.close() raise def close_session(self, session_id: SessionId) -> None: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3cd7bcacf..0d4f71ae3 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -284,6 +284,8 @@ def read(self) -> Optional[OAuthToken]: if hasattr(self, "session") else None, ) + if self.http_client: + self.http_client.close() raise e self.use_inline_params = self._set_use_inline_params_with_warning( diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 71fcc40c6..fb5c3a116 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -359,6 +359,7 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() + self._http_client.close() class TelemetryClientFactory: @@ -460,7 +461,6 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 73aa0a113..5b5f4e693 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -80,9 +80,7 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # GET after REMOVE should fail - with pytest.raises( - Error, match="too many 404 error responses" - ): + with pytest.raises(Error, match="too many 404 error responses"): cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" cursor.execute(query) diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 93e63bd28..0eeb22789 100644 --- 
a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -80,9 +80,7 @@ def test_uc_volume_life_cycle(self, catalog, schema): # GET after REMOVE should fail - with pytest.raises( - Error, match="too many 404 error responses" - ): + with pytest.raises(Error, match="too many 404 error responses"): cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index d2ac4227d..615a7245e 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -122,9 +122,13 @@ def execute_query_worker(thread_id): response = future.result() # Check status using urllib3 method (response.status instead of response.raise_for_status()) if response.status >= 400: - raise Exception(f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}") + raise Exception( + f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}" + ) # Parse JSON using urllib3 method (response.data.decode() instead of response.json()) - response_data = json.loads(response.data.decode()) if response.data else {} + response_data = ( + json.loads(response.data.decode()) if response.data else {} + ) captured_responses.append(response_data) except Exception as e: captured_exceptions.append(e) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 53b7383e6..c8ae8a0bc 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -64,9 +64,9 @@ for name in test_loader.getTestCaseNames(DecimalTestsMixin): if name.startswith("test_"): fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( - fn - ) + decorated = skipUnless( + pysql_supports_arrow(), "Decimal tests need arrow support" + )(fn) setattr(DecimalTestsMixin, name, decorated) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index a5ad7562e..e20c58d3d 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -145,7 +145,9 @@ def test_get_python_sql_connector_auth_provider_access_token(self): hostname = "moderakh-test.cloud.databricks.com" kwargs = {"access_token": "dpi123"} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") headers = {} @@ -163,7 +165,9 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: hostname = "moderakh-test.cloud.databricks.com" kwargs = {"credentials_provider": MyProvider()} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider") headers = {} @@ -179,7 +183,9 @@ def test_get_python_sql_connector_auth_provider_noop(self): "_use_cert_as_auth": use_cert_as_auth, } mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") def test_get_python_sql_connector_basic_auth(self): @@ -189,7 +195,9 @@ def 
test_get_python_sql_connector_basic_auth(self): } mock_http_client = MagicMock() with self.assertRaises(ValueError) as e: - get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs) + get_python_sql_connector_auth_provider( + "foo.cloud.databricks.com", mock_http_client, **kwargs + ) self.assertIn( "Username/password authentication is no longer supported", str(e.exception) ) @@ -198,7 +206,9 @@ def test_get_python_sql_connector_basic_auth(self): def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client + ) self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider") self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) @@ -259,16 +269,16 @@ def test_no_token_refresh__when_token_is_not_expired( def test_get_token_success(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with the expected format mock_response = MagicMock() mock_response.status = 200 mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}' - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + token = token_source.get_token() # Assert @@ -279,16 +289,16 @@ def test_get_token_success(self, token_source, http_response): def test_get_token_failure(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with error mock_response = MagicMock() mock_response.status = 400 mock_response.data.decode.return_value = "Bad Request" - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + with pytest.raises(Exception) as e: token_source.get_token() assert "Failed to get token: 400" in str(e.value) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index 0c3fc7103..aeaf5bce6 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -13,22 +13,24 @@ @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") class CloudFetchQueueSuite(unittest.TestCase): - def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs): + def create_queue( + self, schema_bytes=None, result_links=None, description=None, **kwargs + ): """Helper method to create ThriftCloudFetchQueue with sensible defaults""" # Set up defaults for commonly used parameters defaults = { - 'max_download_threads': 10, - 'ssl_options': SSLOptions(), - 'session_id_hex': Mock(), - 'statement_id': Mock(), - 'chunk_id': 0, - 'start_row_offset': 0, - 'lz4_compressed': True, + "max_download_threads": 10, + "ssl_options": SSLOptions(), + "session_id_hex": Mock(), + "statement_id": Mock(), + "chunk_id": 0, + "start_row_offset": 0, + "lz4_compressed": True, } - + # Override defaults with any provided kwargs defaults.update(kwargs) - + mock_http_client = MagicMock() return utils.ThriftCloudFetchQueue( schema_bytes=schema_bytes or MagicMock(), @@ -198,7 +200,12 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): def 
test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() # Create description that matches the 4-column schema - description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] + description = [ + ("col0", "uint32"), + ("col1", "uint32"), + ("col2", "uint32"), + ("col3", "uint32"), + ] queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None @@ -277,7 +284,12 @@ def test_remaining_rows_multiple_tables_fully_returned( def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() # Create description that matches the 4-column schema - description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] + description = [ + ("col0", "uint32"), + ("col1", "uint32"), + ("col2", "uint32"), + ("col3", "uint32"), + ] queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 00b1b849a..4d3570dc6 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -131,7 +131,7 @@ def test_run_uncompressed_successful(self, mock_time): self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes) # Patch the log metrics method to avoid division by zero - with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): + with patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"): d = downloader.ResultSetDownloadHandler( settings, result_link, @@ -160,11 +160,16 @@ def test_run_compressed_successful(self, mock_time): result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789" # Setup mock HTTP response using helper method - self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes) + self._setup_mock_http_response( + mock_http_client, status=200, data=compressed_bytes + ) # Mock the decompression method and log metrics to avoid issues - with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \ - patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): + with patch.object( + downloader.ResultSetDownloadHandler, + "_decompress_data", + return_value=file_bytes, + ), patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"): d = downloader.ResultSetDownloadHandler( settings, result_link, diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 6471cb4fd..00e6d4939 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -520,7 +520,7 @@ def test_hybrid_disposition_with_external_links( # Create result data with external links but no attachment result_data = ResultData(external_links=external_links, attachment=None) - # Build queue + # Build queue mock_http_client = MagicMock() queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 738c617bd..b8430b9fc 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -27,7 +27,9 @@ def mock_telemetry_client(): client_context = MagicMock() # Patch the _setup_pool_manager method to avoid SSL file loading - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" + ): return 
TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -85,7 +87,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client): mock_response.status = 200 mock_response.status_code = 200 mock_http_request.return_value = mock_response - + client = mock_telemetry_client # Create mock events @@ -221,7 +223,9 @@ def test_client_lifecycle_flow(self): client_context = MagicMock() # Initialize enabled client - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, @@ -289,7 +293,9 @@ def test_factory_shutdown_flow(self): client_context = MagicMock() # Initialize multiple clients - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" + ): for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, @@ -372,8 +378,10 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -400,8 +408,10 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -428,8 +438,10 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index d4d501c64..a71bce597 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -618,7 +618,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -662,7 +662,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) execute_response, _ = 
thrift_backend._handle_execute_response( @@ -707,7 +707,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -859,7 +859,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -912,7 +912,7 @@ def test_handle_execute_response_can_handle_without_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) ( execute_response, @@ -951,7 +951,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) thrift_backend._results_message_to_execute_response = Mock() @@ -2115,7 +2115,7 @@ def test_retry_args_bounding(self, mock_http_client): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **retry_delay_args, ) retry_delay_expected_vals = { @@ -2337,7 +2337,7 @@ def test_execute_command_sets_complex_type_fields_correctly( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **complex_arg_types, ) thrift_backend.execute_command( From 496d7f7e949b1e4a51c2d4eb2c2318752093eea6 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Wed, 13 Aug 2025 22:00:52 +0530 Subject: [PATCH 21/35] Revert "formatting (black) - fix some closures" This reverts commit 67020f1afb82dd15527a94ce90c97e3b5c6b6e2c. 
--- src/databricks/sql/backend/thrift_backend.py | 4 +-- src/databricks/sql/client.py | 2 -- .../sql/telemetry/telemetry_client.py | 2 +- tests/e2e/common/staging_ingestion_tests.py | 4 ++- tests/e2e/common/uc_volume_tests.py | 4 ++- tests/e2e/test_concurrent_telemetry.py | 8 ++--- tests/e2e/test_driver.py | 6 ++-- tests/unit/test_auth.py | 32 ++++++----------- tests/unit/test_cloud_fetch_queue.py | 36 +++++++------------ tests/unit/test_downloader.py | 13 +++---- tests/unit/test_sea_queue.py | 2 +- tests/unit/test_telemetry.py | 32 ++++++----------- tests/unit/test_thrift_backend.py | 16 ++++----- 13 files changed, 60 insertions(+), 101 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index b089eacd5..1654a1d5a 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -232,7 +232,7 @@ def __init__( try: self._transport.open() except: - self._transport.close() + self._transport.release_connection() raise self._request_lock = threading.RLock() @@ -607,7 +607,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId: self._session_id_hex = session_id.hex_guid return session_id except: - self._transport.close() + self._transport.release_connection() raise def close_session(self, session_id: SessionId) -> None: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0d4f71ae3..3cd7bcacf 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -284,8 +284,6 @@ def read(self) -> Optional[OAuthToken]: if hasattr(self, "session") else None, ) - if self.http_client: - self.http_client.close() raise e self.use_inline_params = self._set_use_inline_params_with_warning( diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index fb5c3a116..71fcc40c6 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -359,7 +359,6 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() - self._http_client.close() class TelemetryClientFactory: @@ -461,6 +460,7 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: + with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index 5b5f4e693..73aa0a113 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -80,7 +80,9 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # GET after REMOVE should fail - with pytest.raises(Error, match="too many 404 error responses"): + with pytest.raises( + Error, match="too many 404 error responses" + ): cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/16/file1.csv' TO '{new_temp_path}'" cursor.execute(query) diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 0eeb22789..93e63bd28 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -80,7 +80,9 @@ def test_uc_volume_life_cycle(self, catalog, schema): # GET after REMOVE should fail - with pytest.raises(Error, match="too many 404 error responses"): + with pytest.raises( + Error, match="too many 404 error responses" + ): 
cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index 615a7245e..d2ac4227d 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -122,13 +122,9 @@ def execute_query_worker(thread_id): response = future.result() # Check status using urllib3 method (response.status instead of response.raise_for_status()) if response.status >= 400: - raise Exception( - f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}" - ) + raise Exception(f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}") # Parse JSON using urllib3 method (response.data.decode() instead of response.json()) - response_data = ( - json.loads(response.data.decode()) if response.data else {} - ) + response_data = json.loads(response.data.decode()) if response.data else {} captured_responses.append(response_data) except Exception as e: captured_exceptions.append(e) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index c8ae8a0bc..53b7383e6 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -64,9 +64,9 @@ for name in test_loader.getTestCaseNames(DecimalTestsMixin): if name.startswith("test_"): fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless( - pysql_supports_arrow(), "Decimal tests need arrow support" - )(fn) + decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( + fn + ) setattr(DecimalTestsMixin, name, decorated) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index e20c58d3d..a5ad7562e 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -145,9 +145,7 @@ def test_get_python_sql_connector_auth_provider_access_token(self): hostname = "moderakh-test.cloud.databricks.com" kwargs = {"access_token": "dpi123"} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider( - hostname, mock_http_client, **kwargs - ) + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") headers = {} @@ -165,9 +163,7 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: hostname = "moderakh-test.cloud.databricks.com" kwargs = {"credentials_provider": MyProvider()} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider( - hostname, mock_http_client, **kwargs - ) + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider") headers = {} @@ -183,9 +179,7 @@ def test_get_python_sql_connector_auth_provider_noop(self): "_use_cert_as_auth": use_cert_as_auth, } mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider( - hostname, mock_http_client, **kwargs - ) + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") def test_get_python_sql_connector_basic_auth(self): @@ -195,9 +189,7 @@ def test_get_python_sql_connector_basic_auth(self): } mock_http_client = MagicMock() with self.assertRaises(ValueError) as e: - get_python_sql_connector_auth_provider( - "foo.cloud.databricks.com", mock_http_client, **kwargs - ) + get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs) 
self.assertIn( "Username/password authentication is no longer supported", str(e.exception) ) @@ -206,9 +198,7 @@ def test_get_python_sql_connector_basic_auth(self): def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider( - hostname, mock_http_client - ) + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider") self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) @@ -269,16 +259,16 @@ def test_no_token_refresh__when_token_is_not_expired( def test_get_token_success(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with the expected format mock_response = MagicMock() mock_response.status = 200 mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}' - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + token = token_source.get_token() # Assert @@ -289,16 +279,16 @@ def test_get_token_success(self, token_source, http_response): def test_get_token_failure(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with error mock_response = MagicMock() mock_response.status = 400 mock_response.data.decode.return_value = "Bad Request" - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + with pytest.raises(Exception) as e: token_source.get_token() assert "Failed to get token: 400" in str(e.value) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index aeaf5bce6..0c3fc7103 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -13,24 +13,22 @@ @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") class CloudFetchQueueSuite(unittest.TestCase): - def create_queue( - self, schema_bytes=None, result_links=None, description=None, **kwargs - ): + def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs): """Helper method to create ThriftCloudFetchQueue with sensible defaults""" # Set up defaults for commonly used parameters defaults = { - "max_download_threads": 10, - "ssl_options": SSLOptions(), - "session_id_hex": Mock(), - "statement_id": Mock(), - "chunk_id": 0, - "start_row_offset": 0, - "lz4_compressed": True, + 'max_download_threads': 10, + 'ssl_options': SSLOptions(), + 'session_id_hex': Mock(), + 'statement_id': Mock(), + 'chunk_id': 0, + 'start_row_offset': 0, + 'lz4_compressed': True, } - + # Override defaults with any provided kwargs defaults.update(kwargs) - + mock_http_client = MagicMock() return utils.ThriftCloudFetchQueue( schema_bytes=schema_bytes or MagicMock(), @@ -200,12 +198,7 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() # Create description that matches the 4-column schema - description = [ - ("col0", "uint32"), - ("col1", "uint32"), - ("col2", "uint32"), - ("col3", "uint32"), - ] + description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), 
("col3", "uint32")] queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None @@ -284,12 +277,7 @@ def test_remaining_rows_multiple_tables_fully_returned( def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() # Create description that matches the 4-column schema - description = [ - ("col0", "uint32"), - ("col1", "uint32"), - ("col2", "uint32"), - ("col3", "uint32"), - ] + description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table is None diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 4d3570dc6..00b1b849a 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -131,7 +131,7 @@ def test_run_uncompressed_successful(self, mock_time): self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes) # Patch the log metrics method to avoid division by zero - with patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"): + with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): d = downloader.ResultSetDownloadHandler( settings, result_link, @@ -160,16 +160,11 @@ def test_run_compressed_successful(self, mock_time): result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789" # Setup mock HTTP response using helper method - self._setup_mock_http_response( - mock_http_client, status=200, data=compressed_bytes - ) + self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes) # Mock the decompression method and log metrics to avoid issues - with patch.object( - downloader.ResultSetDownloadHandler, - "_decompress_data", - return_value=file_bytes, - ), patch.object(downloader.ResultSetDownloadHandler, "_log_download_metrics"): + with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \ + patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): d = downloader.ResultSetDownloadHandler( settings, result_link, diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 00e6d4939..6471cb4fd 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -520,7 +520,7 @@ def test_hybrid_disposition_with_external_links( # Create result data with external links but no attachment result_data = ResultData(external_links=external_links, attachment=None) - # Build queue + # Build queue mock_http_client = MagicMock() queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index b8430b9fc..738c617bd 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -27,9 +27,7 @@ def mock_telemetry_client(): client_context = MagicMock() # Patch the _setup_pool_manager method to avoid SSL file loading - with patch( - "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" - ): + with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -87,7 +85,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client): mock_response.status = 200 mock_response.status_code = 200 mock_http_request.return_value = mock_response - + client = mock_telemetry_client # Create mock events @@ -223,9 
+221,7 @@ def test_client_lifecycle_flow(self): client_context = MagicMock() # Initialize enabled client - with patch( - "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" - ): + with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, @@ -293,9 +289,7 @@ def test_factory_shutdown_flow(self): client_context = MagicMock() # Initialize multiple clients - with patch( - "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager" - ): + with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'): for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, @@ -378,10 +372,8 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = ( - False # Connection starts closed for test cleanup - ) - + mock_session_instance.is_open = False # Connection starts closed for test cleanup + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -408,10 +400,8 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = ( - False # Connection starts closed for test cleanup - ) - + mock_session_instance.is_open = False # Connection starts closed for test cleanup + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -438,10 +428,8 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = ( - False # Connection starts closed for test cleanup - ) - + mock_session_instance.is_open = False # Connection starts closed for test cleanup + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index a71bce597..d4d501c64 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -618,7 +618,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -662,7 +662,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) execute_response, _ = thrift_backend._handle_execute_response( @@ -707,7 +707,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -859,7 +859,7 @@ def 
test_handle_execute_response_checks_direct_results_for_error_statuses(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) with self.assertRaises(DatabaseError) as cm: @@ -912,7 +912,7 @@ def test_handle_execute_response_can_handle_without_direct_results( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) ( execute_response, @@ -951,7 +951,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), ) thrift_backend._results_message_to_execute_response = Mock() @@ -2115,7 +2115,7 @@ def test_retry_args_bounding(self, mock_http_client): [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **retry_delay_args, ) retry_delay_expected_vals = { @@ -2337,7 +2337,7 @@ def test_execute_command_sets_complex_type_fields_correctly( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **complex_arg_types, ) thrift_backend.execute_command( From 84ec33a01ca33d0580e65aa686fe51921d094dd0 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Wed, 13 Aug 2025 22:02:19 +0530 Subject: [PATCH 22/35] add more http_client closures --- src/databricks/sql/backend/thrift_backend.py | 4 ++-- src/databricks/sql/client.py | 6 ++++-- src/databricks/sql/telemetry/telemetry_client.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 1654a1d5a..b089eacd5 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -232,7 +232,7 @@ def __init__( try: self._transport.open() except: - self._transport.release_connection() + self._transport.close() raise self._request_lock = threading.RLock() @@ -607,7 +607,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId: self._session_id_hex = session_id.hex_guid return session_id except: - self._transport.release_connection() + self._transport.close() raise def close_session(self, session_id: SessionId) -> None: diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3cd7bcacf..8150b9663 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -284,6 +284,7 @@ def read(self) -> Optional[OAuthToken]: if hasattr(self, "session") else None, ) + self.http_client.close() raise e self.use_inline_params = self._set_use_inline_params_with_warning( @@ -362,8 +363,9 @@ def __exit__(self, exc_type, exc_value, traceback): def __del__(self): if self.open: logger.debug( - "Closing unclosed connection for session " - "{}".format(self.get_session_id_hex()) + "Closing unclosed connection for session " "{}".format( + self.get_session_id_hex() + ) ) try: self._close(close_cursors=False) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 71fcc40c6..fb5c3a116 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -359,6 +359,7 @@ def close(self): """Flush remaining events before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() + self._http_client.close() class TelemetryClientFactory: @@ -460,7 +461,6 @@ def 
initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: - with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() From 76ce5ce3fe083b25297e142a22b65d759fab1556 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Wed, 13 Aug 2025 22:13:20 +0530 Subject: [PATCH 23/35] remove excess close call --- src/databricks/sql/client.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 8150b9663..3cd7bcacf 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -284,7 +284,6 @@ def read(self) -> Optional[OAuthToken]: if hasattr(self, "session") else None, ) - self.http_client.close() raise e self.use_inline_params = self._set_use_inline_params_with_warning( @@ -363,9 +362,8 @@ def __exit__(self, exc_type, exc_value, traceback): def __del__(self): if self.open: logger.debug( - "Closing unclosed connection for session " "{}".format( - self.get_session_id_hex() - ) + "Closing unclosed connection for session " + "{}".format(self.get_session_id_hex()) ) try: self._close(close_cursors=False) From 4452725590dbd56d2b17dde1bd1c1e3f15c1ba3a Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Thu, 14 Aug 2025 10:42:33 +0530 Subject: [PATCH 24/35] wait for _flush before closing HTTP client --- .../sql/telemetry/telemetry_client.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index fb5c3a116..2a13d8747 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -4,6 +4,7 @@ import json from concurrent.futures import ThreadPoolExecutor from concurrent.futures import Future +from concurrent.futures import wait from datetime import datetime, timezone from typing import List, Dict, Any, Optional, TYPE_CHECKING from databricks.sql.telemetry.models.event import ( @@ -182,6 +183,7 @@ def __init__( self._user_agent = None self._events_batch = [] self._lock = threading.RLock() + self._pending_futures = set() self._driver_connection_params = None self._host_url = host_url self._executor = executor @@ -245,6 +247,9 @@ def _send_telemetry(self, events): timeout=900, ) + with self._lock: + self._pending_futures.add(future) + future.add_done_callback( lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count) ) @@ -303,6 +308,9 @@ def _telemetry_request_callback(self, future, sent_count: int): except Exception as e: logger.debug("Telemetry request failed with exception: %s", e) + finally: + with self._lock: + self._pending_futures.discard(future) def _export_telemetry_log(self, **telemetry_event_kwargs): """ @@ -356,9 +364,20 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): ) def close(self): - """Flush remaining events before closing""" + """Flush remaining events and wait for them to complete before closing""" logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._flush() + + with self._lock: + futures_to_wait_on = list(self._pending_futures) + + if futures_to_wait_on: + logger.debug( + "Waiting for %s pending telemetry requests to complete.", + len(futures_to_wait_on), + ) + wait(futures_to_wait_on) + self._http_client.close() From d90ac80a693e0c52ddb40551093225c4d0af5c60 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Thu, 14 Aug 2025 10:53:19 +0530 Subject: [PATCH 
25/35] make close() async --- src/databricks/sql/telemetry/telemetry_client.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 2a13d8747..7245b64d5 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -364,8 +364,15 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): ) def close(self): - """Flush remaining events and wait for them to complete before closing""" - logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) + """Schedules the client to be closed in the background.""" + logger.debug( + "Scheduling background closure for TelemetryClient of connection %s", + self._session_id_hex, + ) + self._executor.submit(self._close_and_wait) + + def _close_and_wait(self): + """Flush remaining events and wait for them to complete before closing.""" self._flush() with self._lock: From c78bace62b05f6d47cae47dbbc2cb4c9a31ef5de Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Tue, 28 Oct 2025 17:46:44 +0530 Subject: [PATCH 26/35] simplify close_session (remove secondary _close_session invocation) --- src/databricks/sql/backend/sea/backend.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 68f41084c..3394911d0 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -273,7 +273,7 @@ def open_session( return SessionId.from_sea_session_id(session_id) - def _close_session(self, session_id: SessionId) -> None: + def close_session(self, session_id: SessionId) -> None: """ Closes an existing session with the Databricks SQL service. @@ -285,9 +285,9 @@ def _close_session(self, session_id: SessionId) -> None: OperationalError: If there's an error closing the session """ - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() + if sea_session_id is None: + raise ValueError("Not a valid SEA session ID") request_data = DeleteSessionRequest( warehouse_id=self.warehouse_id, @@ -300,21 +300,7 @@ def _close_session(self, session_id: SessionId) -> None: data=request_data.to_dict(), ) - def close_session(self, session_id: SessionId) -> None: - """ - Closes the session and the underlying HTTP client. 
- - Args: - session_id: The session identifier returned by open_session() - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error closing the session - """ - - logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) - - self._close_session(session_id) + # close the HTTP client self._http_client.close() def _extract_description_from_manifest( From bedfc06aa2e14cdc3e675f1546795b6d582b53c6 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Tue, 28 Oct 2025 17:55:30 +0530 Subject: [PATCH 27/35] simplify changes --- src/databricks/sql/auth/thrift_http_client.py | 1 + src/databricks/sql/backend/sea/backend.py | 2 ++ src/databricks/sql/telemetry/telemetry_client.py | 10 +++++----- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 31a19870d..34dd078bb 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -112,6 +112,7 @@ def startRetryTimer(self): self.retry_policy and self.retry_policy.start_retry_timer() def open(self): + # self.__pool replaces the self.__http used by the original THttpClient _pool_kwargs = {"maxsize": self.max_connections} diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 3394911d0..5f10f2df4 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -285,6 +285,8 @@ def close_session(self, session_id: SessionId) -> None: OperationalError: If there's an error closing the session """ + logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) + sea_session_id = session_id.to_sea_session_id() if sea_session_id is None: raise ValueError("Not a valid SEA session ID") diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 7245b64d5..c2b17f25c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -364,7 +364,7 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): ) def close(self): - """Schedules the client to be closed in the background.""" + """Schedules client closure.""" logger.debug( "Scheduling background closure for TelemetryClient of connection %s", self._session_id_hex, @@ -376,14 +376,14 @@ def _close_and_wait(self): self._flush() with self._lock: - futures_to_wait_on = list(self._pending_futures) + pending_events = list(self._pending_futures) - if futures_to_wait_on: + if pending_events: logger.debug( "Waiting for %s pending telemetry requests to complete.", - len(futures_to_wait_on), + len(pending_events), ) - wait(futures_to_wait_on) + wait(pending_events) self._http_client.close() From a36353b7938867311798f49bc628743dd373623b Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Tue, 28 Oct 2025 17:59:20 +0530 Subject: [PATCH 28/35] simplify diff --- src/databricks/sql/auth/thrift_http_client.py | 1 + src/databricks/sql/telemetry/telemetry_client.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 34dd078bb..1501e91a1 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -162,6 +162,7 @@ def isOpen(self): return self.__resp is not None def flush(self): + # Pull data out of buffer that will be sent 
in this request data = self.__wbuf.getvalue() self.__wbuf = BytesIO() diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index c2b17f25c..98dc3f851 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -364,9 +364,9 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): ) def close(self): - """Schedules client closure.""" + """Schedule client closure.""" logger.debug( - "Scheduling background closure for TelemetryClient of connection %s", + "Scheduling closure for TelemetryClient of connection %s", self._session_id_hex, ) self._executor.submit(self._close_and_wait) From 76fb62340d491281374c34f30b97432d603fd00c Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Tue, 28 Oct 2025 18:07:53 +0530 Subject: [PATCH 29/35] simplify imports, log TelemetryClient closure --- src/databricks/sql/telemetry/telemetry_client.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 98dc3f851..ca02fdd9a 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -2,9 +2,7 @@ import time import logging import json -from concurrent.futures import ThreadPoolExecutor -from concurrent.futures import Future -from concurrent.futures import wait +from concurrent.futures import ThreadPoolExecutor, wait from datetime import datetime, timezone from typing import List, Dict, Any, Optional, TYPE_CHECKING from databricks.sql.telemetry.models.event import ( @@ -385,6 +383,7 @@ def _close_and_wait(self): ) wait(pending_events) + logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) self._http_client.close() From 66192b4bda01f0684cd5c6a458a7a9b9b3103a99 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Tue, 28 Oct 2025 18:08:48 +0530 Subject: [PATCH 30/35] simplify diff --- src/databricks/sql/telemetry/telemetry_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index ca02fdd9a..c7c4289ec 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -486,6 +486,7 @@ def initialize_telemetry_client( ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: + with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() From 7b4f6cbe6d51acf793e51d2f1ae629ca0db699d9 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Mon, 1 Dec 2025 01:14:11 +0530 Subject: [PATCH 31/35] Revert "Merge branch 'main' into close-conn" This reverts commit 171439523fc346b7a9243644c2e77fabbc16fb76, reversing changes made to 66192b4bda01f0684cd5c6a458a7a9b9b3103a99. 
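For orientation before the large revert that follows, here is a condensed, illustrative sketch of the telemetry shutdown pattern that patches 24-30 above converge on: upload futures are tracked under a lock, close() only schedules a background drain on the shared executor, and the unified HTTP client is closed once the drain completes. This is a simplified editor's sketch under those assumptions; the class and method names are stand-ins, not the connector's actual API.

```python
import threading
from concurrent.futures import ThreadPoolExecutor, wait


class TelemetryShutdownSketch:
    """Illustrative only: mirrors the close/drain flow shown in the diffs above."""

    def __init__(self, http_client, executor: ThreadPoolExecutor):
        self._http_client = http_client      # assumed to expose request() and close()
        self._executor = executor
        self._lock = threading.RLock()
        self._pending_futures = set()

    def _send(self, payload):
        # Submit the upload and track its future until the done-callback fires.
        future = self._executor.submit(
            self._http_client.request, "POST", "/telemetry-ext", body=payload
        )
        with self._lock:
            self._pending_futures.add(future)
        future.add_done_callback(self._on_done)

    def _on_done(self, future):
        with self._lock:
            self._pending_futures.discard(future)

    def close(self):
        # Non-blocking: schedule the drain instead of waiting inline.
        self._executor.submit(self._close_and_wait)

    def _close_and_wait(self):
        with self._lock:
            pending = list(self._pending_futures)
        if pending:
            wait(pending)          # let in-flight uploads finish first
        self._http_client.close()  # then release the pooled connections
```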
--- .github/workflows/daily-telemetry-e2e.yml | 87 --- .github/workflows/integration.yml | 8 +- CHANGELOG.md | 8 - README.md | 6 - TRANSACTIONS.md | 387 ------------ examples/README.md | 1 - examples/transactions.py | 47 -- poetry.lock | 36 +- pyproject.toml | 3 +- src/databricks/sql/__init__.py | 5 +- src/databricks/sql/auth/common.py | 2 - src/databricks/sql/backend/sea/backend.py | 2 +- src/databricks/sql/client.py | 343 +--------- src/databricks/sql/common/feature_flag.py | 8 +- .../sql/common/unified_http_client.py | 52 +- src/databricks/sql/exc.py | 38 -- src/databricks/sql/session.py | 21 - .../sql/telemetry/circuit_breaker_manager.py | 112 ---- .../sql/telemetry/latency_logger.py | 289 +++++---- src/databricks/sql/telemetry/models/event.py | 109 +--- .../sql/telemetry/telemetry_client.py | 126 ++-- .../sql/telemetry/telemetry_push_client.py | 201 ------ src/databricks/sql/utils.py | 3 - tests/e2e/test_circuit_breaker.py | 232 ------- tests/e2e/test_telemetry_e2e.py | 343 ---------- tests/e2e/test_transactions.py | 598 ------------------ .../unit/test_circuit_breaker_http_client.py | 208 ------ tests/unit/test_circuit_breaker_manager.py | 160 ----- tests/unit/test_client.py | 477 +------------- tests/unit/test_telemetry.py | 465 +------------- tests/unit/test_telemetry_push_client.py | 213 ------- .../test_telemetry_request_error_handling.py | 96 --- tests/unit/test_unified_http_client.py | 136 ---- 33 files changed, 227 insertions(+), 4595 deletions(-) delete mode 100644 .github/workflows/daily-telemetry-e2e.yml delete mode 100644 TRANSACTIONS.md delete mode 100644 examples/transactions.py delete mode 100644 src/databricks/sql/telemetry/circuit_breaker_manager.py delete mode 100644 src/databricks/sql/telemetry/telemetry_push_client.py delete mode 100644 tests/e2e/test_circuit_breaker.py delete mode 100644 tests/e2e/test_telemetry_e2e.py delete mode 100644 tests/e2e/test_transactions.py delete mode 100644 tests/unit/test_circuit_breaker_http_client.py delete mode 100644 tests/unit/test_circuit_breaker_manager.py delete mode 100644 tests/unit/test_telemetry_push_client.py delete mode 100644 tests/unit/test_telemetry_request_error_handling.py delete mode 100644 tests/unit/test_unified_http_client.py diff --git a/.github/workflows/daily-telemetry-e2e.yml b/.github/workflows/daily-telemetry-e2e.yml deleted file mode 100644 index 3d61cf177..000000000 --- a/.github/workflows/daily-telemetry-e2e.yml +++ /dev/null @@ -1,87 +0,0 @@ -name: Daily Telemetry E2E Tests - -on: - schedule: - - cron: '0 0 * * 0' # Run every Sunday at midnight UTC - - workflow_dispatch: # Allow manual triggering - inputs: - test_pattern: - description: 'Test pattern to run (default: tests/e2e/test_telemetry_e2e.py)' - required: false - default: 'tests/e2e/test_telemetry_e2e.py' - type: string - -jobs: - telemetry-e2e-tests: - runs-on: ubuntu-latest - environment: azure-prod - - env: - DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} - DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} - DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} - DATABRICKS_CATALOG: peco - DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} - - steps: - #---------------------------------------------- - # check-out repo and set-up python - #---------------------------------------------- - - name: Check out repository - uses: actions/checkout@v4 - - - name: Set up python - id: setup-python - uses: actions/setup-python@v5 - with: - python-version: "3.10" - - #---------------------------------------------- - # ----- install 
& configure poetry ----- - #---------------------------------------------- - - name: Install Poetry - uses: snok/install-poetry@v1 - with: - virtualenvs-create: true - virtualenvs-in-project: true - installer-parallel: true - - #---------------------------------------------- - # load cached venv if cache exists - #---------------------------------------------- - - name: Load cached venv - id: cached-poetry-dependencies - uses: actions/cache@v4 - with: - path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} - - #---------------------------------------------- - # install dependencies if cache does not exist - #---------------------------------------------- - - name: Install dependencies - run: poetry install --no-interaction --all-extras - - #---------------------------------------------- - # run telemetry E2E tests - #---------------------------------------------- - - name: Run telemetry E2E tests - run: | - TEST_PATTERN="${{ github.event.inputs.test_pattern || 'tests/e2e/test_telemetry_e2e.py' }}" - echo "Running tests: $TEST_PATTERN" - poetry run python -m pytest $TEST_PATTERN -v -s - - #---------------------------------------------- - # upload test results on failure - #---------------------------------------------- - - name: Upload test results on failure - if: failure() - uses: actions/upload-artifact@v4 - with: - name: telemetry-test-results - path: | - .pytest_cache/ - tests-unsafe.log - retention-days: 7 - diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index ad5369997..9c9e30a24 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -54,9 +54,5 @@ jobs: #---------------------------------------------- # run test suite #---------------------------------------------- - - name: Run e2e tests (excluding daily-only tests) - run: | - # Exclude telemetry E2E tests from PR runs (run daily instead) - poetry run python -m pytest tests/e2e \ - --ignore=tests/e2e/test_telemetry_e2e.py \ - -n auto \ No newline at end of file + - name: Run e2e tests + run: poetry run python -m pytest tests/e2e -n auto \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b902e976..1fa6bfb66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,5 @@ # Release History -# 4.2.1 (2025-11-20) -- Ignore transactions by default (databricks/databricks-sql-python#711 by @jayantsing-db) - -# 4.2.0 (2025-11-14) -- Add multi-statement transaction support (databricks/databricks-sql-python#704 by @jayantsing-db) -- Add a workflow to parallelise the E2E tests (databricks/databricks-sql-python#697 by @msrathore-db) -- Bring Python telemetry event model consistent with JDBC (databricks/databricks-sql-python#701 by @nikhilsuri-db) - # 4.1.4 (2025-10-15) - Add support for Token Federation (databricks/databricks-sql-python#691 by @madhav-db) - Add metric view support (databricks/databricks-sql-python#688 by @shivam2680) diff --git a/README.md b/README.md index ec82a3637..d57efda1f 100644 --- a/README.md +++ b/README.md @@ -67,12 +67,6 @@ or to a Databricks Runtime interactive cluster (e.g. /sql/protocolv1/o/123456789 > to authenticate the target Databricks user account and needs to open the browser for authentication. So it > can only run on the user's machine. -## Transaction Support - -The connector supports multi-statement transactions with manual commit/rollback control. 
Set `connection.autocommit = False` to disable autocommit mode, then use `connection.commit()` and `connection.rollback()` to control transactions. - -For detailed documentation, examples, and best practices, see **[TRANSACTIONS.md](TRANSACTIONS.md)**. - ## SQLAlchemy Starting from `databricks-sql-connector` version 4.0.0 SQLAlchemy support has been extracted to a new library `databricks-sqlalchemy`. diff --git a/TRANSACTIONS.md b/TRANSACTIONS.md deleted file mode 100644 index 590c298c0..000000000 --- a/TRANSACTIONS.md +++ /dev/null @@ -1,387 +0,0 @@ -# Transaction Support - -The Databricks SQL Connector for Python supports multi-statement transactions (MST). This allows you to group multiple SQL statements into atomic units that either succeed completely or fail completely. - -## Autocommit Behavior - -By default, every SQL statement executes in its own transaction and commits immediately (autocommit mode). This is the standard behavior for most database connectors. - -```python -from databricks import sql - -connection = sql.connect( - server_hostname="your-workspace.cloud.databricks.com", - http_path="/sql/1.0/warehouses/abc123" -) - -# Default: autocommit is True -print(connection.autocommit) # True - -# Each statement commits immediately -cursor = connection.cursor() -cursor.execute("INSERT INTO my_table VALUES (1, 'data')") -# Already committed - data is visible to other connections -``` - -To use explicit transactions, disable autocommit: - -```python -connection.autocommit = False - -# Now statements are grouped into a transaction -cursor = connection.cursor() -cursor.execute("INSERT INTO my_table VALUES (1, 'data')") -# Not committed yet - must call connection.commit() - -connection.commit() # Now it's visible -``` - -## Basic Transaction Operations - -### Committing Changes - -When autocommit is disabled, you must explicitly commit your changes: - -```python -connection.autocommit = False -cursor = connection.cursor() - -try: - cursor.execute("INSERT INTO orders VALUES (1, 100.00)") - cursor.execute("INSERT INTO order_items VALUES (1, 'Widget', 2)") - connection.commit() # Both inserts succeed together -except Exception as e: - connection.rollback() # Neither insert is saved - raise -finally: - connection.autocommit = True # Restore default state -``` - -### Rolling Back Changes - -Use `rollback()` to discard all changes made in the current transaction: - -```python -connection.autocommit = False -cursor = connection.cursor() - -cursor.execute("INSERT INTO accounts VALUES (1, 1000)") -cursor.execute("UPDATE accounts SET balance = balance - 500 WHERE id = 1") - -# Changed your mind? -connection.rollback() # All changes discarded -``` - -Note: Calling `rollback()` when autocommit is enabled is safe (it's a no-op), but calling `commit()` will raise a `TransactionError`. - -### Sequential Transactions - -After a commit or rollback, a new transaction starts automatically: - -```python -connection.autocommit = False - -# First transaction -cursor.execute("INSERT INTO logs VALUES (1, 'event1')") -connection.commit() - -# Second transaction starts automatically -cursor.execute("INSERT INTO logs VALUES (2, 'event2')") -connection.rollback() # Only the second insert is discarded -``` - -## Multi-Table Transactions - -Transactions span multiple tables atomically. 
Either all changes are committed, or all are rolled back: - -```python -connection.autocommit = False -cursor = connection.cursor() - -try: - # Insert into multiple tables - cursor.execute("INSERT INTO customers VALUES (1, 'Alice')") - cursor.execute("INSERT INTO orders VALUES (1, 1, 100.00)") - cursor.execute("INSERT INTO shipments VALUES (1, 1, 'pending')") - - connection.commit() # All three inserts succeed atomically -except Exception as e: - connection.rollback() # All three inserts are discarded - raise -finally: - connection.autocommit = True # Restore default state -``` - -This is particularly useful for maintaining data consistency across related tables. - -## Transaction Isolation - -Databricks uses **Snapshot Isolation** (mapped to `REPEATABLE_READ` in standard SQL terminology). This means: - -- **Repeatable reads**: Once you read data in a transaction, subsequent reads will see the same data (even if other transactions modify it) -- **Atomic commits**: Changes are visible to other connections only after commit -- **Write serializability within a single table**: Concurrent writes to the same table will cause conflicts -- **Snapshot isolation across tables**: Concurrent writes to different tables can succeed - -### Getting the Isolation Level - -```python -level = connection.get_transaction_isolation() -print(level) # Output: REPEATABLE_READ -``` - -### Setting the Isolation Level - -Currently, only `REPEATABLE_READ` is supported: - -```python -from databricks import sql - -# Using the constant -connection.set_transaction_isolation(sql.TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ) - -# Or using a string -connection.set_transaction_isolation("REPEATABLE_READ") - -# Other levels will raise NotSupportedError -connection.set_transaction_isolation("READ_COMMITTED") # Raises NotSupportedError -``` - -### What Repeatable Read Means in Practice - -Within a transaction, you'll always see a consistent snapshot of the data: - -```python -connection.autocommit = False -cursor = connection.cursor() - -# First read -cursor.execute("SELECT balance FROM accounts WHERE id = 1") -balance1 = cursor.fetchone()[0] # Returns 1000 - -# Another connection updates the balance -# (In a separate connection: UPDATE accounts SET balance = 500 WHERE id = 1) - -# Second read in the same transaction -cursor.execute("SELECT balance FROM accounts WHERE id = 1") -balance2 = cursor.fetchone()[0] # Still returns 1000 (repeatable read!) - -connection.commit() - -# After commit, new transactions will see the updated value (500) -``` - -## Error Handling - -### Setting Autocommit During a Transaction - -You cannot change autocommit mode while a transaction is active: - -```python -connection.autocommit = False -cursor = connection.cursor() - -try: - cursor.execute("INSERT INTO logs VALUES (1, 'data')") - - # This will raise TransactionError - connection.autocommit = True # Error: transaction is active - -except sql.TransactionError as e: - print(f"Cannot change autocommit: {e}") - connection.rollback() # Clean up the transaction -finally: - connection.autocommit = True # Now it's safe to restore -``` - -### Committing Without an Active Transaction - -If autocommit is enabled, there's no active transaction, so calling `commit()` will fail: - -```python -connection.autocommit = True # Default - -try: - connection.commit() # Raises TransactionError -except sql.TransactionError as e: - print(f"No active transaction: {e}") -``` - -However, `rollback()` is safe in this case (it's a no-op). 
- -### Recovering from Query Failures - -If a statement fails during a transaction, roll back and start a new transaction: - -```python -connection.autocommit = False -cursor = connection.cursor() - -try: - cursor.execute("INSERT INTO valid_table VALUES (1, 'data')") - cursor.execute("INSERT INTO nonexistent_table VALUES (2, 'data')") # Fails - connection.commit() -except Exception as e: - connection.rollback() # Discard the partial transaction - - # Log the error (with autocommit still disabled) - try: - cursor.execute("INSERT INTO error_log VALUES (1, 'Query failed')") - connection.commit() - except Exception: - connection.rollback() -finally: - connection.autocommit = True # Restore default state -``` - -## Querying Server State - -By default, the `autocommit` property returns a cached value for performance. If you need to query the server each time (for instance, when strong consistency is required): - -```python -connection = sql.connect( - server_hostname="your-workspace.cloud.databricks.com", - http_path="/sql/1.0/warehouses/abc123", - fetch_autocommit_from_server=True -) - -# Each access queries the server -state = connection.autocommit # Executes "SET AUTOCOMMIT" query -``` - -This is generally not needed for normal usage. - -## Write Conflicts - -### Within a Single Table - -Databricks enforces **write serializability** within a single table. If two transactions try to modify the same table concurrently, one will fail: - -```python -# Connection 1 -conn1.autocommit = False -cursor1 = conn1.cursor() -cursor1.execute("INSERT INTO accounts VALUES (1, 100)") - -# Connection 2 (concurrent) -conn2.autocommit = False -cursor2 = conn2.cursor() -cursor2.execute("INSERT INTO accounts VALUES (2, 200)") - -# First commit succeeds -conn1.commit() # OK - -# Second commit fails with concurrent write conflict -try: - conn2.commit() # Raises error about concurrent writes -except Exception as e: - conn2.rollback() - print(f"Concurrent write detected: {e}") -``` - -This happens even when the rows being modified are different. The conflict detection is at the table level. - -### Across Multiple Tables - -Concurrent writes to *different* tables can succeed. Each table tracks its own write conflicts independently: - -```python -# Connection 1: writes to table_a -conn1.autocommit = False -cursor1 = conn1.cursor() -cursor1.execute("INSERT INTO table_a VALUES (1, 'data')") - -# Connection 2: writes to table_b (different table) -conn2.autocommit = False -cursor2 = conn2.cursor() -cursor2.execute("INSERT INTO table_b VALUES (1, 'data')") - -# Both commits succeed (different tables) -conn1.commit() # OK -conn2.commit() # Also OK -``` - -## Best Practices - -1. **Keep transactions short**: Long-running transactions can cause conflicts with other connections. Commit as soon as your atomic unit of work is complete. - -2. **Always handle exceptions**: Wrap transaction code in try/except/finally and call `rollback()` on errors. - -```python -connection.autocommit = False -cursor = connection.cursor() - -try: - cursor.execute("INSERT INTO table1 VALUES (1, 'data')") - cursor.execute("UPDATE table2 SET status = 'updated'") - connection.commit() -except Exception as e: - connection.rollback() - logger.error(f"Transaction failed: {e}") - raise -finally: - connection.autocommit = True # Restore default state -``` - -3. 
**Use context managers**: If you're writing helper functions, consider using a context manager pattern: - -```python -from contextlib import contextmanager - -@contextmanager -def transaction(connection): - connection.autocommit = False - try: - yield connection - connection.commit() - except Exception: - connection.rollback() - raise - finally: - connection.autocommit = True - -# Usage -with transaction(connection): - cursor = connection.cursor() - cursor.execute("INSERT INTO logs VALUES (1, 'message')") - # Auto-commits on success, auto-rolls back on exception -``` - -4. **Reset autocommit when done**: Use a `finally` block to restore autocommit to `True`. This is especially important if the connection is reused or part of a connection pool: - -```python -connection.autocommit = False -try: - # ... transaction code ... - connection.commit() -except Exception: - connection.rollback() - raise -finally: - connection.autocommit = True # Restore to default state -``` - -5. **Be aware of isolation semantics**: Remember that repeatable read means you see a snapshot from the start of your transaction. If you need to see recent changes from other transactions, commit your current transaction and start a new one. - -## Requirements - -To use transactions, you need: -- A Databricks SQL warehouse that supports Multi-Statement Transactions (MST) -- Tables created with the `delta.feature.catalogOwned-preview` table property: - -```sql -CREATE TABLE my_table (id INT, value STRING) -USING DELTA -TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') -``` - -## Related APIs - -- `connection.autocommit` - Get or set autocommit mode (boolean) -- `connection.commit()` - Commit the current transaction -- `connection.rollback()` - Roll back the current transaction -- `connection.get_transaction_isolation()` - Get the isolation level (returns `"REPEATABLE_READ"`) -- `connection.set_transaction_isolation(level)` - Validate/set isolation level (only `"REPEATABLE_READ"` supported) -- `sql.TransactionError` - Exception raised for transaction-specific errors - -All of these are extensions to [PEP 249](https://www.python.org/dev/peps/pep-0249/) (Python Database API Specification v2.0). diff --git a/examples/README.md b/examples/README.md index f52dede1d..d73c58a6b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -31,7 +31,6 @@ To run all of these examples you can clone the entire repository to your disk. O - **`query_execute.py`** connects to the `samples` database of your default catalog, runs a small query, and prints the result to screen. - **`insert_data.py`** adds a tables called `squares` to your default catalog and inserts one hundred rows of example data. Then it fetches this data and prints it to the screen. -- **`transactions.py`** demonstrates multi-statement transaction support with explicit commit/rollback control. Shows how to group multiple SQL statements into an atomic unit that either succeeds completely or fails completely. - **`query_cancel.py`** shows how to cancel a query assuming that you can access the `Cursor` executing that query from a different thread. This is necessary because `databricks-sql-connector` does not yet implement an asynchronous API; calling `.execute()` blocks the current thread until execution completes. Therefore, the connector can't cancel queries from the same thread where they began. 
- **`interactive_oauth.py`** shows the simplest example of authenticating by OAuth (no need for a PAT generated in the DBSQL UI) while Bring Your Own IDP is in public preview. When you run the script it will open a browser window so you can authenticate. Afterward, the script fetches some sample data from Databricks and prints it to the screen. For this script, the OAuth token is not persisted which means you need to authenticate every time you run the script. - **`m2m_oauth.py`** shows the simplest example of authenticating by using OAuth M2M (machine-to-machine) for service principal. diff --git a/examples/transactions.py b/examples/transactions.py deleted file mode 100644 index 6f58dbd2d..000000000 --- a/examples/transactions.py +++ /dev/null @@ -1,47 +0,0 @@ -from databricks import sql -import os - -with sql.connect( - server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path=os.getenv("DATABRICKS_HTTP_PATH"), - access_token=os.getenv("DATABRICKS_TOKEN"), -) as connection: - - # Disable autocommit to use explicit transactions - connection.autocommit = False - - with connection.cursor() as cursor: - try: - # Create tables for demonstration - cursor.execute("CREATE TABLE IF NOT EXISTS accounts (id int, balance int)") - cursor.execute( - "CREATE TABLE IF NOT EXISTS transfers (from_id int, to_id int, amount int)" - ) - connection.commit() - - # Start a new transaction - transfer money between accounts - cursor.execute("INSERT INTO accounts VALUES (1, 1000), (2, 500)") - cursor.execute("UPDATE accounts SET balance = balance - 100 WHERE id = 1") - cursor.execute("UPDATE accounts SET balance = balance + 100 WHERE id = 2") - cursor.execute("INSERT INTO transfers VALUES (1, 2, 100)") - - # Commit the transaction - all changes succeed together - connection.commit() - print("Transaction committed successfully") - - # Verify the results - cursor.execute("SELECT * FROM accounts ORDER BY id") - print("Accounts:", cursor.fetchall()) - - cursor.execute("SELECT * FROM transfers") - print("Transfers:", cursor.fetchall()) - - except Exception as e: - # Roll back on error - all changes are discarded - connection.rollback() - print(f"Transaction rolled back due to error: {e}") - raise - - finally: - # Restore autocommit to default state - connection.autocommit = True diff --git a/poetry.lock b/poetry.lock index 193efa109..1a8074c2a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. 
[[package]] name = "astroid" @@ -1348,38 +1348,6 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] -[[package]] -name = "pybreaker" -version = "1.2.0" -description = "Python implementation of the Circuit Breaker pattern" -optional = false -python-versions = ">=3.8" -groups = ["main"] -markers = "python_version < \"3.10\"" -files = [ - {file = "pybreaker-1.2.0-py3-none-any.whl", hash = "sha256:c3e7683e29ecb3d4421265aaea55504f1186a2fdc1f17b6b091d80d1e1eb5ede"}, - {file = "pybreaker-1.2.0.tar.gz", hash = "sha256:18707776316f93a30c1be0e4fec1f8aa5ed19d7e395a218eb2f050c8524fb2dc"}, -] - -[package.extras] -test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] - -[[package]] -name = "pybreaker" -version = "1.4.1" -description = "Python implementation of the Circuit Breaker pattern" -optional = false -python-versions = ">=3.9" -groups = ["main"] -markers = "python_version >= \"3.10\"" -files = [ - {file = "pybreaker-1.4.1-py3-none-any.whl", hash = "sha256:b4dab4a05195b7f2a64a6c1a6c4ba7a96534ef56ea7210e6bcb59f28897160e0"}, - {file = "pybreaker-1.4.1.tar.gz", hash = "sha256:8df2d245c73ba40c8242c56ffb4f12138fbadc23e296224740c2028ea9dc1178"}, -] - -[package.extras] -test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] - [[package]] name = "pycparser" version = "2.22" @@ -1890,4 +1858,4 @@ pyarrow = ["pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "56b62e3543644c91cc316b11d89025423a66daba5f36609c45bcb3eeb3ce3f54" +content-hash = "0a3f611ef8747376f018c1df0a1ea7873368851873cc4bd3a4d51bba0bba847c" diff --git a/pyproject.toml b/pyproject.toml index 61c248e98..c0eb8244d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.2.1" +version = "4.1.4" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" @@ -26,7 +26,6 @@ pyarrow = [ { version = ">=18.0.0", python = ">=3.13", optional=true } ] pyjwt = "^2.0.0" -pybreaker = "^1.0.0" requests-kerberos = {version = "^0.15.0", optional = true} diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index cd37e6ce1..403a4d130 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -8,9 +8,6 @@ paramstyle = "named" -# Transaction isolation level constants (extension to PEP 249) -TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" - import re from typing import TYPE_CHECKING @@ -71,7 +68,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.2.1" +__version__ = "4.1.4" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index a764b036d..3e0be0d2b 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -51,7 +51,6 @@ def __init__( pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, user_agent: Optional[str] = None, - telemetry_circuit_breaker_enabled: Optional[bool] = None, ): self.hostname = hostname self.access_token = access_token @@ -84,7 +83,6 @@ def __init__( self.pool_connections = pool_connections or 10 self.pool_maxsize = pool_maxsize or 20 self.user_agent = user_agent - self.telemetry_circuit_breaker_enabled = bool(telemetry_circuit_breaker_enabled) def get_effective_azure_login_app_id(hostname) -> str: diff --git 
a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 9c5c63033..5f10f2df4 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -157,7 +157,7 @@ def __init__( "_use_arrow_native_complex_types", True ) - self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", False) + self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True) self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) # Extract warehouse ID from http_path diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a7f802dcd..5bb191ca2 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -9,7 +9,6 @@ import json import os import decimal -from urllib.parse import urlparse from uuid import UUID from databricks.sql import __version__ @@ -21,8 +20,6 @@ InterfaceError, NotSupportedError, ProgrammingError, - TransactionError, - DatabaseError, ) from databricks.sql.thrift_api.TCLIService import ttypes @@ -89,9 +86,6 @@ NO_NATIVE_PARAMS: List = [] -# Transaction isolation level constants (extension to PEP 249) -TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" - class Connection: def __init__( @@ -104,7 +98,6 @@ def __init__( catalog: Optional[str] = None, schema: Optional[str] = None, _use_arrow_native_complex_types: Optional[bool] = True, - ignore_transactions: bool = True, **kwargs, ) -> None: """ @@ -213,17 +206,6 @@ def read(self) -> Optional[OAuthToken]: This allows 1. cursor.tables() to return METRIC_VIEW table type 2. cursor.columns() to return "measure" column type - :param fetch_autocommit_from_server: `bool`, optional (default is False) - When True, the connection.autocommit property queries the server for current state - using SET AUTOCOMMIT instead of returning cached value. - Set to True if autocommit might be changed by external means (e.g., external SQL commands). - When False (default), uses cached state for better performance. - :param ignore_transactions: `bool`, optional (default is True) - When True, transaction-related operations behave as follows: - - commit(): no-op (does nothing) - - rollback(): raises NotSupportedError - - autocommit setter: no-op (does nothing) - When False, transaction operations execute normally. 
""" # Internal arguments in **kwargs: @@ -322,10 +304,6 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) - self._fetch_autocommit_from_server = kwargs.get( - "fetch_autocommit_from_server", False - ) - self.ignore_transactions = ignore_transactions self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False) self.enable_telemetry = kwargs.get("enable_telemetry", False) @@ -344,20 +322,6 @@ def read(self) -> Optional[OAuthToken]: session_id_hex=self.get_session_id_hex() ) - # Determine proxy usage - use_proxy = self.http_client.using_proxy() - proxy_host_info = None - if ( - use_proxy - and self.http_client.proxy_uri - and isinstance(self.http_client.proxy_uri, str) - ): - parsed = urlparse(self.http_client.proxy_uri) - proxy_host_info = HostDetails( - host_url=parsed.hostname or self.http_client.proxy_uri, - port=parsed.port or 8080, - ) - driver_connection_params = DriverConnectionParameters( http_path=http_path, mode=DatabricksClientType.SEA @@ -367,31 +331,13 @@ def read(self) -> Optional[OAuthToken]: auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider), auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider), socket_timeout=kwargs.get("_socket_timeout", None), - azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id", None), - azure_tenant_id=kwargs.get("azure_tenant_id", None), - use_proxy=use_proxy, - use_system_proxy=use_proxy, - proxy_host_info=proxy_host_info, - use_cf_proxy=False, # CloudFlare proxy not yet supported in Python - cf_proxy_host_info=None, # CloudFlare proxy not yet supported in Python - non_proxy_hosts=None, - allow_self_signed_support=kwargs.get("_tls_no_verify", False), - use_system_trust_store=True, # Python uses system SSL by default - enable_arrow=pyarrow is not None, - enable_direct_results=True, # Always enabled in Python - enable_sea_hybrid_results=kwargs.get("use_hybrid_disposition", False), - http_connection_pool_size=kwargs.get("pool_maxsize", None), - rows_fetched_per_block=DEFAULT_ARRAY_SIZE, - async_poll_interval_millis=2000, # Default polling interval - support_many_parameters=True, # Native parameters supported - enable_complex_datatype_support=_use_arrow_native_complex_types, - allowed_volume_ingestion_paths=self.staging_allowed_local_path, ) self._telemetry_client.export_initial_telemetry_log( driver_connection_params=driver_connection_params, user_agent=self.session.useragent_header, ) + self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) def _set_use_inline_params_with_warning(self, value: Union[bool, str]): """Valid values are True, False, and "silent" @@ -527,286 +473,15 @@ def _close(self, close_cursors=True) -> None: if self.http_client: self.http_client.close() - @property - def autocommit(self) -> bool: - """ - Get auto-commit mode for this connection. - - Extension to PEP 249. Returns cached value by default. - If fetch_autocommit_from_server=True was set during connection, - queries server for current state. 
- - Returns: - bool: True if auto-commit is enabled, False otherwise - - Raises: - InterfaceError: If connection is closed - TransactionError: If fetch_autocommit_from_server=True and query fails - """ - if not self.open: - raise InterfaceError( - "Cannot get autocommit on closed connection", - session_id_hex=self.get_session_id_hex(), - ) - - if self._fetch_autocommit_from_server: - return self._fetch_autocommit_state_from_server() - - return self.session.get_autocommit() - - @autocommit.setter - def autocommit(self, value: bool) -> None: - """ - Set auto-commit mode for this connection. - - Extension to PEP 249. Executes SET AUTOCOMMIT command on server. - - Args: - value: True to enable auto-commit, False to disable - - When ignore_transactions is True: - - This method is a no-op (does nothing) - - Raises: - InterfaceError: If connection is closed - TransactionError: If server rejects the change - """ - # No-op when ignore_transactions is True - if self.ignore_transactions: - return - - if not self.open: - raise InterfaceError( - "Cannot set autocommit on closed connection", - session_id_hex=self.get_session_id_hex(), - ) - - # Create internal cursor for transaction control - cursor = None - try: - cursor = self.cursor() - sql = f"SET AUTOCOMMIT = {'TRUE' if value else 'FALSE'}" - cursor.execute(sql) - - # Update cached state on success - self.session.set_autocommit(value) - - except DatabaseError as e: - # Wrap in TransactionError with context - raise TransactionError( - f"Failed to set autocommit to {value}: {e.message}", - context={ - **e.context, - "operation": "set_autocommit", - "autocommit_value": value, - }, - session_id_hex=self.get_session_id_hex(), - ) from e - finally: - if cursor: - cursor.close() - - def _fetch_autocommit_state_from_server(self) -> bool: - """ - Query server for current autocommit state using SET AUTOCOMMIT. - - Returns: - bool: Server's autocommit state - - Raises: - TransactionError: If query fails - """ - cursor = None - try: - cursor = self.cursor() - cursor.execute("SET AUTOCOMMIT") - - # Fetch result: should return row with value column - result = cursor.fetchone() - if result is None: - raise TransactionError( - "No result returned from SET AUTOCOMMIT query", - context={"operation": "fetch_autocommit"}, - session_id_hex=self.get_session_id_hex(), - ) - - # Parse value (first column should be "true" or "false") - value_str = str(result[0]).lower() - autocommit_state = value_str == "true" - - # Update cache - self.session.set_autocommit(autocommit_state) - - return autocommit_state - - except TransactionError: - # Re-raise TransactionError as-is - raise - except DatabaseError as e: - # Wrap other DatabaseErrors - raise TransactionError( - f"Failed to fetch autocommit state from server: {e.message}", - context={**e.context, "operation": "fetch_autocommit"}, - session_id_hex=self.get_session_id_hex(), - ) from e - finally: - if cursor: - cursor.close() - - def commit(self) -> None: - """ - Commit the current transaction. - - Per PEP 249. Should be called only when autocommit is disabled. 
- - When autocommit is False: - - Commits the current transaction - - Server automatically starts new transaction - - When autocommit is True: - - Server may throw error if no active transaction - - When ignore_transactions is True: - - This method is a no-op (does nothing) - - Raises: - InterfaceError: If connection is closed - TransactionError: If commit fails (e.g., no active transaction) - """ - # No-op when ignore_transactions is True - if self.ignore_transactions: - return - - if not self.open: - raise InterfaceError( - "Cannot commit on closed connection", - session_id_hex=self.get_session_id_hex(), - ) - - cursor = None - try: - cursor = self.cursor() - cursor.execute("COMMIT") - - except DatabaseError as e: - raise TransactionError( - f"Failed to commit transaction: {e.message}", - context={**e.context, "operation": "commit"}, - session_id_hex=self.get_session_id_hex(), - ) from e - finally: - if cursor: - cursor.close() - - def rollback(self) -> None: - """ - Rollback the current transaction. - - Per PEP 249. Should be called only when autocommit is disabled. - - When autocommit is False: - - Rolls back the current transaction - - Server automatically starts new transaction - - When autocommit is True: - - ROLLBACK is forgiving (no-op, doesn't throw exception) - - When ignore_transactions is True: - - Raises NotSupportedError - - Note: ROLLBACK is safe to call even without active transaction. - - Raises: - InterfaceError: If connection is closed - NotSupportedError: If ignore_transactions is True - TransactionError: If rollback fails - """ - # Raise NotSupportedError when ignore_transactions is True - if self.ignore_transactions: - raise NotSupportedError( - "Transactions are not supported on Databricks", - session_id_hex=self.get_session_id_hex(), - ) - - if not self.open: - raise InterfaceError( - "Cannot rollback on closed connection", - session_id_hex=self.get_session_id_hex(), - ) - - cursor = None - try: - cursor = self.cursor() - cursor.execute("ROLLBACK") - - except DatabaseError as e: - raise TransactionError( - f"Failed to rollback transaction: {e.message}", - context={**e.context, "operation": "rollback"}, - session_id_hex=self.get_session_id_hex(), - ) from e - finally: - if cursor: - cursor.close() - - def get_transaction_isolation(self) -> str: - """ - Get the transaction isolation level. - - Extension to PEP 249. - - Databricks supports REPEATABLE_READ isolation level (Snapshot Isolation), - which is the default and only supported level. - - Returns: - str: "REPEATABLE_READ" - the transaction isolation level constant - - Raises: - InterfaceError: If connection is closed - """ - if not self.open: - raise InterfaceError( - "Cannot get transaction isolation on closed connection", - session_id_hex=self.get_session_id_hex(), - ) - - return TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ - - def set_transaction_isolation(self, level: str) -> None: - """ - Set transaction isolation level. - - Extension to PEP 249. - - Databricks supports only REPEATABLE_READ isolation level (Snapshot Isolation). - This method validates that the requested level is supported but does not - execute any SQL, as REPEATABLE_READ is the default server behavior. - - Args: - level: Isolation level. 
Must be "REPEATABLE_READ" or "REPEATABLE READ" - (case-insensitive, underscores and spaces are interchangeable) - - Raises: - InterfaceError: If connection is closed - NotSupportedError: If isolation level not supported - """ - if not self.open: - raise InterfaceError( - "Cannot set transaction isolation on closed connection", - session_id_hex=self.get_session_id_hex(), - ) - - # Normalize and validate isolation level - normalized_level = level.upper().replace("_", " ") + def commit(self): + """No-op because Databricks does not support transactions""" + pass - if normalized_level != TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ.replace( - "_", " " - ): - raise NotSupportedError( - f"Setting transaction isolation level '{level}' is not supported. " - f"Only {TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ} is supported.", - session_id_hex=self.get_session_id_hex(), - ) + def rollback(self): + raise NotSupportedError( + "Transactions are not supported on Databricks", + session_id_hex=self.get_session_id_hex(), + ) class Cursor: diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 032701f63..8a1cf5bd5 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -165,9 +165,8 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: cls._initialize() assert cls._executor is not None - # Cache at HOST level - share feature flags across connections to same host - # Feature flags are per-host, not per-session - key = connection.session.host + # Use the unique session ID as the key + key = connection.get_session_id_hex() if key not in cls._context_map: cls._context_map[key] = FeatureFlagsContext( connection, cls._executor, connection.session.http_client @@ -178,8 +177,7 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: def remove_instance(cls, connection: "Connection"): """Removes the context for a given connection and shuts down the executor if no clients remain.""" with cls._lock: - # Use host as key to match get_instance - key = connection.session.host + key = connection.get_session_id_hex() if key in cls._context_map: cls._context_map.pop(key, None) diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index d5f7d3c8d..7ccd69c54 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -28,42 +28,6 @@ logger = logging.getLogger(__name__) -def _extract_http_status_from_max_retry_error(e: MaxRetryError) -> Optional[int]: - """ - Extract HTTP status code from MaxRetryError if available. 
- - urllib3 structures MaxRetryError in different ways depending on the failure scenario: - - e.reason.response.status: Most common case when retries are exhausted - - e.response.status: Alternate structure in some scenarios - - Args: - e: MaxRetryError exception from urllib3 - - Returns: - HTTP status code as int if found, None otherwise - """ - # Try primary structure: e.reason.response.status - if ( - hasattr(e, "reason") - and e.reason is not None - and hasattr(e.reason, "response") - and e.reason.response is not None - ): - http_code = getattr(e.reason.response, "status", None) - if http_code is not None: - return http_code - - # Try alternate structure: e.response.status - if ( - hasattr(e, "response") - and e.response is not None - and hasattr(e.response, "status") - ): - return e.response.status - - return None - - class UnifiedHttpClient: """ Unified HTTP client for all Databricks SQL connector HTTP operations. @@ -300,16 +264,7 @@ def request_context( yield response except MaxRetryError as e: logger.error("HTTP request failed after retries: %s", e) - - # Extract HTTP status code from MaxRetryError if available - http_code = _extract_http_status_from_max_retry_error(e) - - context = {} - if http_code is not None: - context["http-code"] = http_code - logger.error("HTTP request failed with status code: %d", http_code) - - raise RequestError(f"HTTP request failed: {e}", context=context) + raise RequestError(f"HTTP request failed: {e}") except Exception as e: logger.error("HTTP request error: %s", e) raise RequestError(f"HTTP request error: {e}") @@ -346,11 +301,6 @@ def using_proxy(self) -> bool: """Check if proxy support is available (not whether it's being used for a specific request).""" return self._proxy_pool_manager is not None - @property - def proxy_uri(self) -> Optional[str]: - """Get the configured proxy URI, if any.""" - return self._proxy_uri - def close(self): """Close the underlying connection pools.""" if self._direct_pool_manager: diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 24844d573..4a772c49b 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -70,23 +70,6 @@ class NotSupportedError(DatabaseError): pass -class TransactionError(DatabaseError): - """ - Exception raised for transaction-specific errors. - - This exception is used when transaction control operations fail, such as: - - Setting autocommit mode (AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION) - - Committing a transaction (MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION) - - Rolling back a transaction - - Setting transaction isolation level - - The exception includes context about which transaction operation failed - and preserves the underlying cause via exception chaining. - """ - - pass - - ### Custom error classes ### class InvalidServerResponseError(OperationalError): """Thrown if the server does not set the initial namespace correctly""" @@ -143,24 +126,3 @@ class SessionAlreadyClosedError(RequestError): class CursorAlreadyClosedError(RequestError): """Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected.""" - - -class TelemetryRateLimitError(Exception): - """Raised when telemetry endpoint returns 429 or 503, indicating rate limiting or service unavailable. - This exception is used exclusively by the circuit breaker to track telemetry rate limiting events.""" - - -class TelemetryNonRateLimitError(Exception): - """Wrapper for telemetry errors that should NOT trigger circuit breaker. 
- - This exception wraps non-rate-limiting errors (network errors, timeouts, server errors, etc.) - and is excluded from circuit breaker failure counting. Only TelemetryRateLimitError should - open the circuit breaker. - - Attributes: - original_exception: The actual exception that occurred - """ - - def __init__(self, original_exception: Exception): - self.original_exception = original_exception - super().__init__(f"Non-rate-limit telemetry error: {original_exception}") diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 0f723d144..d8ba5d125 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -45,9 +45,6 @@ def __init__( self.schema = schema self.http_path = http_path - # Initialize autocommit state (JDBC default is True) - self._autocommit = True - user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: user_agent_entry = kwargs.get("_user_agent_entry") @@ -171,24 +168,6 @@ def guid_hex(self) -> str: """Get the session ID in hex format""" return self._session_id.hex_guid - def get_autocommit(self) -> bool: - """ - Get the cached autocommit state for this session. - - Returns: - bool: True if autocommit is enabled, False otherwise - """ - return self._autocommit - - def set_autocommit(self, value: bool) -> None: - """ - Update the cached autocommit state for this session. - - Args: - value: True to cache autocommit as enabled, False as disabled - """ - self._autocommit = value - def close(self) -> None: """Close the underlying session.""" logger.info("Closing session %s", self.guid_hex) diff --git a/src/databricks/sql/telemetry/circuit_breaker_manager.py b/src/databricks/sql/telemetry/circuit_breaker_manager.py deleted file mode 100644 index 852f0d916..000000000 --- a/src/databricks/sql/telemetry/circuit_breaker_manager.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -Circuit breaker implementation for telemetry requests. - -This module provides circuit breaker functionality to prevent telemetry failures -from impacting the main SQL operations. It uses pybreaker library to implement -the circuit breaker pattern. 
-""" - -import logging -import threading -from typing import Dict - -import pybreaker -from pybreaker import CircuitBreaker, CircuitBreakerError, CircuitBreakerListener - -from databricks.sql.exc import TelemetryNonRateLimitError - -logger = logging.getLogger(__name__) - -# Circuit Breaker Constants -MINIMUM_CALLS = 20 # Number of failures before circuit opens -RESET_TIMEOUT = 30 # Seconds to wait before trying to close circuit -NAME_PREFIX = "telemetry-circuit-breaker" - -# Circuit Breaker State Constants (used in logging) -CIRCUIT_BREAKER_STATE_OPEN = "open" -CIRCUIT_BREAKER_STATE_CLOSED = "closed" -CIRCUIT_BREAKER_STATE_HALF_OPEN = "half-open" - -# Logging Message Constants -LOG_CIRCUIT_BREAKER_STATE_CHANGED = "Circuit breaker state changed from %s to %s for %s" -LOG_CIRCUIT_BREAKER_OPENED = ( - "Circuit breaker opened for %s - telemetry requests will be blocked" -) -LOG_CIRCUIT_BREAKER_CLOSED = ( - "Circuit breaker closed for %s - telemetry requests will be allowed" -) -LOG_CIRCUIT_BREAKER_HALF_OPEN = ( - "Circuit breaker half-open for %s - testing telemetry requests" -) - - -class CircuitBreakerStateListener(CircuitBreakerListener): - """Listener for circuit breaker state changes.""" - - def before_call(self, cb: CircuitBreaker, func, *args, **kwargs) -> None: - """Called before the circuit breaker calls a function.""" - pass - - def failure(self, cb: CircuitBreaker, exc: BaseException) -> None: - """Called when a function called by the circuit breaker fails.""" - pass - - def success(self, cb: CircuitBreaker) -> None: - """Called when a function called by the circuit breaker succeeds.""" - pass - - def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None: - """Called when the circuit breaker state changes.""" - old_state_name = old_state.name if old_state else "None" - new_state_name = new_state.name if new_state else "None" - - logger.info( - LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name - ) - - if new_state_name == CIRCUIT_BREAKER_STATE_OPEN: - logger.warning(LOG_CIRCUIT_BREAKER_OPENED, cb.name) - elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED: - logger.info(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) - elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN: - logger.info(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) - - -class CircuitBreakerManager: - """ - Manages circuit breaker instances for telemetry requests. - - Creates and caches circuit breaker instances per host to ensure telemetry - failures don't impact main SQL operations. - """ - - _instances: Dict[str, CircuitBreaker] = {} - _lock = threading.RLock() - - @classmethod - def get_circuit_breaker(cls, host: str) -> CircuitBreaker: - """ - Get or create a circuit breaker instance for the specified host. 
- - Args: - host: The hostname for which to get the circuit breaker - - Returns: - CircuitBreaker instance for the host - """ - with cls._lock: - if host not in cls._instances: - breaker = CircuitBreaker( - fail_max=MINIMUM_CALLS, - reset_timeout=RESET_TIMEOUT, - name=f"{NAME_PREFIX}-{host}", - exclude=[ - TelemetryNonRateLimitError - ], # Don't count these as failures - ) - # Add state change listener for logging - breaker.add_listener(CircuitBreakerStateListener()) - cls._instances[host] = breaker - logger.debug("Created circuit breaker for host: %s", host) - - return cls._instances[host] diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 36ebee2b8..12cacd851 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -1,6 +1,6 @@ import time import functools -from typing import Optional, Dict, Any +from typing import Optional import logging from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.telemetry.models.event import ( @@ -11,141 +11,127 @@ logger = logging.getLogger(__name__) -def _extract_cursor_data(cursor) -> Dict[str, Any]: +class TelemetryExtractor: """ - Extract telemetry data directly from a Cursor object. + Base class for extracting telemetry information from various object types. - OPTIMIZATION: Uses direct attribute access instead of wrapper objects. - This eliminates object creation overhead and method call indirection. + This class serves as a proxy that delegates attribute access to the wrapped object + while providing a common interface for extracting telemetry-related data. + """ - Args: - cursor: The Cursor object to extract data from + def __init__(self, obj): + self._obj = obj - Returns: - Dict with telemetry data (values may be None if extraction fails) + def __getattr__(self, name): + return getattr(self._obj, name) + + def get_session_id_hex(self): + pass + + def get_statement_id(self): + pass + + def get_is_compressed(self): + pass + + def get_execution_result_format(self): + pass + + def get_retry_count(self): + pass + + def get_chunk_id(self): + pass + + +class CursorExtractor(TelemetryExtractor): """ - data = {} - - # Extract statement_id (query_id) - direct attribute access - try: - data["statement_id"] = cursor.query_id - except (AttributeError, Exception): - data["statement_id"] = None - - # Extract session_id_hex - direct method call - try: - data["session_id_hex"] = cursor.connection.get_session_id_hex() - except (AttributeError, Exception): - data["session_id_hex"] = None - - # Extract is_compressed - direct attribute access - try: - data["is_compressed"] = cursor.connection.lz4_compression - except (AttributeError, Exception): - data["is_compressed"] = False - - # Extract execution_result_format - inline logic - try: - if cursor.active_result_set is None: - data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED - else: - from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue - - results = cursor.active_result_set.results - if isinstance(results, ColumnQueue): - data["execution_result"] = ExecutionResultFormat.COLUMNAR_INLINE - elif isinstance(results, CloudFetchQueue): - data["execution_result"] = ExecutionResultFormat.EXTERNAL_LINKS - elif isinstance(results, ArrowQueue): - data["execution_result"] = ExecutionResultFormat.INLINE_ARROW - else: - data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED - except (AttributeError, Exception): - 
data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED - - # Extract retry_count - direct attribute access - try: - if hasattr(cursor.backend, "retry_policy") and cursor.backend.retry_policy: - data["retry_count"] = len(cursor.backend.retry_policy.history) - else: - data["retry_count"] = 0 - except (AttributeError, Exception): - data["retry_count"] = 0 - - # chunk_id is always None for Cursor - data["chunk_id"] = None - - return data - - -def _extract_result_set_handler_data(handler) -> Dict[str, Any]: + Telemetry extractor specialized for Cursor objects. + + Extracts telemetry information from database cursor objects, including + statement IDs, session information, compression settings, and result formats. """ - Extract telemetry data directly from a ResultSetDownloadHandler object. - OPTIMIZATION: Uses direct attribute access instead of wrapper objects. + def get_statement_id(self) -> Optional[str]: + return self.query_id - Args: - handler: The ResultSetDownloadHandler object to extract data from + def get_session_id_hex(self) -> Optional[str]: + return self.connection.get_session_id_hex() - Returns: - Dict with telemetry data (values may be None if extraction fails) - """ - data = {} + def get_is_compressed(self) -> bool: + return self.connection.lz4_compression + + def get_execution_result_format(self) -> ExecutionResultFormat: + if self.active_result_set is None: + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue + + if isinstance(self.active_result_set.results, ColumnQueue): + return ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(self.active_result_set.results, CloudFetchQueue): + return ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(self.active_result_set.results, ArrowQueue): + return ExecutionResultFormat.INLINE_ARROW + return ExecutionResultFormat.FORMAT_UNSPECIFIED + + def get_retry_count(self) -> int: + if hasattr(self.backend, "retry_policy") and self.backend.retry_policy: + return len(self.backend.retry_policy.history) + return 0 + + def get_chunk_id(self): + return None - # Extract session_id_hex - direct attribute access - try: - data["session_id_hex"] = handler.session_id_hex - except (AttributeError, Exception): - data["session_id_hex"] = None - # Extract statement_id - direct attribute access - try: - data["statement_id"] = handler.statement_id - except (AttributeError, Exception): - data["statement_id"] = None +class ResultSetDownloadHandlerExtractor(TelemetryExtractor): + """ + Telemetry extractor specialized for ResultSetDownloadHandler objects. 
+ """ - # Extract is_compressed - direct attribute access - try: - data["is_compressed"] = handler.settings.is_lz4_compressed - except (AttributeError, Exception): - data["is_compressed"] = False + def get_session_id_hex(self) -> Optional[str]: + return self._obj.session_id_hex - # execution_result is always EXTERNAL_LINKS for result set handlers - data["execution_result"] = ExecutionResultFormat.EXTERNAL_LINKS + def get_statement_id(self) -> Optional[str]: + return self._obj.statement_id - # retry_count is not available for result set handlers - data["retry_count"] = None + def get_is_compressed(self) -> bool: + return self._obj.settings.is_lz4_compressed - # Extract chunk_id - direct attribute access - try: - data["chunk_id"] = handler.chunk_id - except (AttributeError, Exception): - data["chunk_id"] = None + def get_execution_result_format(self) -> ExecutionResultFormat: + return ExecutionResultFormat.EXTERNAL_LINKS + + def get_retry_count(self) -> Optional[int]: + # standard requests and urllib3 libraries don't expose retry count + return None - return data + def get_chunk_id(self) -> Optional[int]: + return self._obj.chunk_id -def _extract_telemetry_data(obj) -> Optional[Dict[str, Any]]: +def get_extractor(obj): """ - Extract telemetry data from an object based on its type. + Factory function to create the appropriate telemetry extractor for an object. - OPTIMIZATION: Returns a simple dict instead of creating wrapper objects. - This dict will be used to create the SqlExecutionEvent in the background thread. + Determines the object type and returns the corresponding specialized extractor + that can extract telemetry information from that object type. Args: - obj: The object to extract data from (Cursor, ResultSetDownloadHandler, etc.) + obj: The object to create an extractor for. Can be a Cursor, + ResultSetDownloadHandler, or any other object. Returns: - Dict with telemetry data, or None if object type is not supported + TelemetryExtractor: A specialized extractor instance: + - CursorExtractor for Cursor objects + - ResultSetDownloadHandlerExtractor for ResultSetDownloadHandler objects + - None for all other objects """ - obj_type = obj.__class__.__name__ - - if obj_type == "Cursor": - return _extract_cursor_data(obj) - elif obj_type == "ResultSetDownloadHandler": - return _extract_result_set_handler_data(obj) + if obj.__class__.__name__ == "Cursor": + return CursorExtractor(obj) + elif obj.__class__.__name__ == "ResultSetDownloadHandler": + return ResultSetDownloadHandlerExtractor(obj) else: - logger.debug("No telemetry extraction available for %s", obj_type) + logger.debug("No extractor found for %s", obj.__class__.__name__) return None @@ -157,6 +143,12 @@ def log_latency(statement_type: StatementType = StatementType.NONE): data about the operation, including latency, statement information, and execution context. + The decorator automatically: + - Measures execution time using high-precision performance counters + - Extracts telemetry information from the method's object (self) + - Creates a SqlExecutionEvent with execution details + - Sends the telemetry data asynchronously via TelemetryClient + Args: statement_type (StatementType): The type of SQL statement being executed. @@ -170,49 +162,54 @@ def execute(self, query): function: A decorator that wraps methods to add latency logging. Note: - The wrapped method's object (self) must be a Cursor or - ResultSetDownloadHandler for telemetry data extraction. 
+ The wrapped method's object (self) must be compatible with the + telemetry extractor system (e.g., Cursor or ResultSet objects). """ def decorator(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): - start_time = time.monotonic() + start_time = time.perf_counter() + result = None try: - return func(self, *args, **kwargs) + result = func(self, *args, **kwargs) + return result finally: - duration_ms = int((time.monotonic() - start_time) * 1000) - - # Always log for debugging - logger.debug("%s completed in %dms", func.__name__, duration_ms) - - # Fast check: use cached telemetry_enabled flag from connection - # Avoids dictionary lookup + instance check on every operation - connection = getattr(self, "connection", None) - if connection and getattr(connection, "telemetry_enabled", False): - session_id_hex = connection.get_session_id_hex() - if session_id_hex: - # Telemetry enabled - extract and send - telemetry_data = _extract_telemetry_data(self) - if telemetry_data: - sql_exec_event = SqlExecutionEvent( - statement_type=statement_type, - is_compressed=telemetry_data.get("is_compressed"), - execution_result=telemetry_data.get("execution_result"), - retry_count=telemetry_data.get("retry_count"), - chunk_id=telemetry_data.get("chunk_id"), - ) - - telemetry_client = ( - TelemetryClientFactory.get_telemetry_client( - session_id_hex - ) - ) - telemetry_client.export_latency_log( - latency_ms=duration_ms, - sql_execution_event=sql_exec_event, - sql_statement_id=telemetry_data.get("statement_id"), - ) + + def _safe_call(func_to_call): + """Calls a function and returns a default value on any exception.""" + try: + return func_to_call() + except Exception: + return None + + end_time = time.perf_counter() + duration_ms = int((end_time - start_time) * 1000) + + extractor = get_extractor(self) + + if extractor is not None: + session_id_hex = _safe_call(extractor.get_session_id_hex) + statement_id = _safe_call(extractor.get_statement_id) + + sql_exec_event = SqlExecutionEvent( + statement_type=statement_type, + is_compressed=_safe_call(extractor.get_is_compressed), + execution_result=_safe_call( + extractor.get_execution_result_format + ), + retry_count=_safe_call(extractor.get_retry_count), + chunk_id=_safe_call(extractor.get_chunk_id), + ) + + telemetry_client = TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + telemetry_client.export_latency_log( + latency_ms=duration_ms, + sql_execution_event=sql_exec_event, + sql_statement_id=statement_id, + ) return wrapper diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index 2e6f63a6f..c7f9d9d17 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -38,25 +38,6 @@ class DriverConnectionParameters(JsonSerializableMixin): auth_mech (AuthMech): The authentication mechanism used auth_flow (AuthFlow): The authentication flow type socket_timeout (int): Connection timeout in milliseconds - azure_workspace_resource_id (str): Azure workspace resource ID - azure_tenant_id (str): Azure tenant ID - use_proxy (bool): Whether proxy is being used - use_system_proxy (bool): Whether system proxy is being used - proxy_host_info (HostDetails): Proxy host details if configured - use_cf_proxy (bool): Whether CloudFlare proxy is being used - cf_proxy_host_info (HostDetails): CloudFlare proxy host details if configured - non_proxy_hosts (list): List of hosts that bypass proxy - allow_self_signed_support (bool): Whether self-signed 
certificates are allowed - use_system_trust_store (bool): Whether system trust store is used - enable_arrow (bool): Whether Arrow format is enabled - enable_direct_results (bool): Whether direct results are enabled - enable_sea_hybrid_results (bool): Whether SEA hybrid results are enabled - http_connection_pool_size (int): HTTP connection pool size - rows_fetched_per_block (int): Number of rows fetched per block - async_poll_interval_millis (int): Async polling interval in milliseconds - support_many_parameters (bool): Whether many parameters are supported - enable_complex_datatype_support (bool): Whether complex datatypes are supported - allowed_volume_ingestion_paths (str): Allowed paths for volume ingestion """ http_path: str @@ -65,25 +46,6 @@ class DriverConnectionParameters(JsonSerializableMixin): auth_mech: Optional[AuthMech] = None auth_flow: Optional[AuthFlow] = None socket_timeout: Optional[int] = None - azure_workspace_resource_id: Optional[str] = None - azure_tenant_id: Optional[str] = None - use_proxy: Optional[bool] = None - use_system_proxy: Optional[bool] = None - proxy_host_info: Optional[HostDetails] = None - use_cf_proxy: Optional[bool] = None - cf_proxy_host_info: Optional[HostDetails] = None - non_proxy_hosts: Optional[list] = None - allow_self_signed_support: Optional[bool] = None - use_system_trust_store: Optional[bool] = None - enable_arrow: Optional[bool] = None - enable_direct_results: Optional[bool] = None - enable_sea_hybrid_results: Optional[bool] = None - http_connection_pool_size: Optional[int] = None - rows_fetched_per_block: Optional[int] = None - async_poll_interval_millis: Optional[int] = None - support_many_parameters: Optional[bool] = None - enable_complex_datatype_support: Optional[bool] = None - allowed_volume_ingestion_paths: Optional[str] = None @dataclass @@ -149,69 +111,6 @@ class DriverErrorInfo(JsonSerializableMixin): stack_trace: str -@dataclass -class ChunkDetails(JsonSerializableMixin): - """ - Contains detailed metrics about chunk downloads during result fetching. - - These metrics are accumulated across all chunk downloads for a single statement. - - Attributes: - initial_chunk_latency_millis (int): Latency of the first chunk download - slowest_chunk_latency_millis (int): Latency of the slowest chunk download - total_chunks_present (int): Total number of chunks available - total_chunks_iterated (int): Number of chunks actually downloaded - sum_chunks_download_time_millis (int): Total time spent downloading all chunks - """ - - initial_chunk_latency_millis: Optional[int] = None - slowest_chunk_latency_millis: Optional[int] = None - total_chunks_present: Optional[int] = None - total_chunks_iterated: Optional[int] = None - sum_chunks_download_time_millis: Optional[int] = None - - -@dataclass -class ResultLatency(JsonSerializableMixin): - """ - Contains latency metrics for different phases of query execution. - - This tracks two distinct phases: - 1. result_set_ready_latency_millis: Time from query submission until results are available (execute phase) - - Set when execute() completes - 2. 
result_set_consumption_latency_millis: Time spent iterating/fetching results (fetch phase) - - Measured from first fetch call until no more rows available - - In Java: tracked via markResultSetConsumption(hasNext) method - - Records start time on first fetch, calculates total on last fetch - - Attributes: - result_set_ready_latency_millis (int): Time until query results are ready (execution phase) - result_set_consumption_latency_millis (int): Time spent fetching/consuming results (fetch phase) - - """ - - result_set_ready_latency_millis: Optional[int] = None - result_set_consumption_latency_millis: Optional[int] = None - - -@dataclass -class OperationDetail(JsonSerializableMixin): - """ - Contains detailed information about the operation being performed. - - Attributes: - n_operation_status_calls (int): Number of status polling calls made - operation_status_latency_millis (int): Total latency of all status calls - operation_type (str): Specific operation type (e.g., EXECUTE_STATEMENT, LIST_TABLES, CANCEL_STATEMENT) - is_internal_call (bool): Whether this is an internal driver operation - """ - - n_operation_status_calls: Optional[int] = None - operation_status_latency_millis: Optional[int] = None - operation_type: Optional[str] = None - is_internal_call: Optional[bool] = None - - @dataclass class SqlExecutionEvent(JsonSerializableMixin): """ @@ -223,10 +122,7 @@ class SqlExecutionEvent(JsonSerializableMixin): is_compressed (bool): Whether the result is compressed execution_result (ExecutionResultFormat): Format of the execution result retry_count (int): Number of retry attempts made - chunk_id (int): ID of the chunk if applicable (used for error tracking) - chunk_details (ChunkDetails): Aggregated chunk download metrics - result_latency (ResultLatency): Latency breakdown by execution phase - operation_detail (OperationDetail): Detailed operation information + chunk_id (int): ID of the chunk if applicable """ statement_type: StatementType @@ -234,9 +130,6 @@ class SqlExecutionEvent(JsonSerializableMixin): execution_result: ExecutionResultFormat retry_count: Optional[int] chunk_id: Optional[int] - chunk_details: Optional[ChunkDetails] = None - result_latency: Optional[ResultLatency] = None - operation_detail: Optional[OperationDetail] = None @dataclass diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 892485a4a..c7c4289ec 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -3,9 +3,6 @@ import logging import json from concurrent.futures import ThreadPoolExecutor, wait -from queue import Queue, Full -from concurrent.futures import ThreadPoolExecutor -from concurrent.futures import Future from datetime import datetime, timezone from typing import List, Dict, Any, Optional, TYPE_CHECKING from databricks.sql.telemetry.models.event import ( @@ -43,11 +40,6 @@ from databricks.sql.common.feature_flag import FeatureFlagsContextFactory from databricks.sql.common.unified_http_client import UnifiedHttpClient from databricks.sql.common.http import HttpMethod -from databricks.sql.telemetry.telemetry_push_client import ( - ITelemetryPushClient, - TelemetryPushClient, - CircuitBreakerTelemetryPushClient, -) if TYPE_CHECKING: from databricks.sql.client import Connection @@ -116,21 +108,18 @@ def get_auth_flow(auth_provider): @staticmethod def is_telemetry_enabled(connection: "Connection") -> bool: - # Fast path: force enabled - skip feature flag fetch entirely if 
connection.force_enable_telemetry: return True - # Fast path: disabled - no need to check feature flag - if not connection.enable_telemetry: + if connection.enable_telemetry: + context = FeatureFlagsContextFactory.get_instance(connection) + flag_value = context.get_flag_value( + TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False + ) + return str(flag_value).lower() == "true" + else: return False - # Only fetch feature flags when enable_telemetry=True and not forced - context = FeatureFlagsContextFactory.get_instance(connection) - flag_value = context.get_flag_value( - TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False - ) - return str(flag_value).lower() == "true" - class NoopTelemetryClient(BaseTelemetryClient): """ @@ -176,26 +165,23 @@ class TelemetryClient(BaseTelemetryClient): def __init__( self, - telemetry_enabled: bool, - session_id_hex: str, + telemetry_enabled, + session_id_hex, auth_provider, - host_url: str, + host_url, executor, - batch_size: int, + batch_size, client_context, - ) -> None: + ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled self._batch_size = batch_size self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None + self._events_batch = [] + self._lock = threading.RLock() self._pending_futures = set() - - # OPTIMIZATION: Use lock-free Queue instead of list + lock - # Queue is thread-safe internally and has better performance under concurrency - self._events_queue: Queue[TelemetryFrontendLog] = Queue(maxsize=batch_size * 2) - self._driver_connection_params = None self._host_url = host_url self._executor = executor @@ -203,41 +189,12 @@ def __init__( # Create own HTTP client from client context self._http_client = UnifiedHttpClient(client_context) - # Create telemetry push client based on circuit breaker enabled flag - if client_context.telemetry_circuit_breaker_enabled: - # Create circuit breaker telemetry push client - # (circuit breakers created on-demand) - self._telemetry_push_client: ITelemetryPushClient = ( - CircuitBreakerTelemetryPushClient( - TelemetryPushClient(self._http_client), - host_url, - ) - ) - else: - # Circuit breaker disabled - use direct telemetry push client - self._telemetry_push_client = TelemetryPushClient(self._http_client) - def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) - - # OPTIMIZATION: Use non-blocking put with queue - # No explicit lock needed - Queue is thread-safe internally - try: - self._events_queue.put_nowait(event) - except Full: - # Queue is full, trigger immediate flush - logger.debug("Event queue full, triggering flush") - self._flush() - # Try again after flush - try: - self._events_queue.put_nowait(event) - except Full: - # Still full, drop event (acceptable for telemetry) - logger.debug("Dropped telemetry event - queue still full") - - # Check if we should flush based on queue size - if self._events_queue.qsize() >= self._batch_size: + with self._lock: + self._events_batch.append(event) + if len(self._events_batch) >= self._batch_size: logger.debug( "Batch size limit reached (%s), flushing events", self._batch_size ) @@ -245,16 +202,9 @@ def _export_event(self, event): def _flush(self): """Flush the current batch of events to the server""" - # OPTIMIZATION: Drain queue without locks - # Collect all events currently in the queue - events_to_flush = [] - 
while not self._events_queue.empty(): - try: - event = self._events_queue.get_nowait() - events_to_flush.append(event) - except: - # Queue is empty - break + with self._lock: + events_to_flush = self._events_batch.copy() + self._events_batch = [] if events_to_flush: logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) @@ -307,7 +257,7 @@ def _send_telemetry(self, events): def _send_with_unified_client(self, url, data, headers, timeout=900): """Helper method to send telemetry using the unified HTTP client.""" try: - response = self._telemetry_push_client.request( + response = self._http_client.request( HttpMethod.POST, url, body=data, headers=headers, timeout=timeout ) return response @@ -443,9 +393,9 @@ class TelemetryClientFactory: It uses a thread pool to handle asynchronous operations and a single flush thread for all clients. """ - _clients: Dict[str, BaseTelemetryClient] = ( - {} - ) # Map of session_id_hex -> BaseTelemetryClient + _clients: Dict[ + str, BaseTelemetryClient + ] = {} # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.RLock() # Thread safety for factory operations @@ -456,7 +406,7 @@ class TelemetryClientFactory: # Shared flush thread for all clients _flush_thread = None _flush_event = threading.Event() - _flush_interval_seconds = 300 # 5 minutes + _flush_interval_seconds = 90 DEFAULT_BATCH_SIZE = 100 @@ -546,21 +496,21 @@ def initialize_telemetry_client( session_id_hex, ) if telemetry_enabled: - TelemetryClientFactory._clients[session_id_hex] = ( - TelemetryClient( - telemetry_enabled=telemetry_enabled, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - executor=TelemetryClientFactory._executor, - batch_size=batch_size, - client_context=client_context, - ) + TelemetryClientFactory._clients[ + session_id_hex + ] = TelemetryClient( + telemetry_enabled=telemetry_enabled, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + batch_size=batch_size, + client_context=client_context, ) else: - TelemetryClientFactory._clients[session_id_hex] = ( - NoopTelemetryClient() - ) + TelemetryClientFactory._clients[ + session_id_hex + ] = NoopTelemetryClient() except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail diff --git a/src/databricks/sql/telemetry/telemetry_push_client.py b/src/databricks/sql/telemetry/telemetry_push_client.py deleted file mode 100644 index 461a57738..000000000 --- a/src/databricks/sql/telemetry/telemetry_push_client.py +++ /dev/null @@ -1,201 +0,0 @@ -""" -Telemetry push client interface and implementations. - -This module provides an interface for telemetry push clients with two implementations: -1. TelemetryPushClient - Direct HTTP client implementation -2. 
CircuitBreakerTelemetryPushClient - Circuit breaker wrapper implementation -""" - -import logging -from abc import ABC, abstractmethod -from typing import Dict, Any, Optional - -try: - from urllib3 import BaseHTTPResponse -except ImportError: - from urllib3 import HTTPResponse as BaseHTTPResponse -from pybreaker import CircuitBreakerError - -from databricks.sql.common.unified_http_client import UnifiedHttpClient -from databricks.sql.common.http import HttpMethod -from databricks.sql.exc import ( - TelemetryRateLimitError, - TelemetryNonRateLimitError, - RequestError, -) -from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager - -logger = logging.getLogger(__name__) - - -class ITelemetryPushClient(ABC): - """Interface for telemetry push clients.""" - - @abstractmethod - def request( - self, - method: HttpMethod, - url: str, - headers: Optional[Dict[str, str]] = None, - **kwargs, - ) -> BaseHTTPResponse: - """Make an HTTP request.""" - pass - - -class TelemetryPushClient(ITelemetryPushClient): - """Direct HTTP client implementation for telemetry requests.""" - - def __init__(self, http_client: UnifiedHttpClient): - """ - Initialize the telemetry push client. - - Args: - http_client: The underlying HTTP client - """ - self._http_client = http_client - logger.debug("TelemetryPushClient initialized") - - def request( - self, - method: HttpMethod, - url: str, - headers: Optional[Dict[str, str]] = None, - **kwargs, - ) -> BaseHTTPResponse: - """Make an HTTP request using the underlying HTTP client.""" - return self._http_client.request(method, url, headers, **kwargs) - - -class CircuitBreakerTelemetryPushClient(ITelemetryPushClient): - """Circuit breaker wrapper implementation for telemetry requests.""" - - def __init__(self, delegate: ITelemetryPushClient, host: str): - """ - Initialize the circuit breaker telemetry push client. - - Args: - delegate: The underlying telemetry push client to wrap - host: The hostname for circuit breaker identification - """ - self._delegate = delegate - self._host = host - - # Get circuit breaker for this host (creates if doesn't exist) - self._circuit_breaker = CircuitBreakerManager.get_circuit_breaker(host) - - logger.debug( - "CircuitBreakerTelemetryPushClient initialized for host %s", - host, - ) - - def _make_request_and_check_status( - self, - method: HttpMethod, - url: str, - headers: Optional[Dict[str, str]], - **kwargs, - ) -> BaseHTTPResponse: - """ - Make the request and check response status. - - Raises TelemetryRateLimitError for 429/503 (circuit breaker counts these). - Wraps other errors in TelemetryNonRateLimitError (circuit breaker excludes these). 
- - Args: - method: HTTP method - url: Request URL - headers: Request headers - **kwargs: Additional request parameters - - Returns: - HTTP response - - Raises: - TelemetryRateLimitError: For 429/503 status codes (circuit breaker counts) - TelemetryNonRateLimitError: For other errors (circuit breaker excludes) - """ - try: - response = self._delegate.request(method, url, headers, **kwargs) - - # Check for rate limiting or service unavailable - if response.status in [429, 503]: - logger.warning( - "Telemetry endpoint returned %d for host %s, triggering circuit breaker", - response.status, - self._host, - ) - raise TelemetryRateLimitError( - f"Telemetry endpoint rate limited or unavailable: {response.status}" - ) - - return response - - except Exception as e: - # Don't catch TelemetryRateLimitError - let it propagate to circuit breaker - if isinstance(e, TelemetryRateLimitError): - raise - - # Check if it's a RequestError with rate limiting status code (exhausted retries) - if isinstance(e, RequestError): - http_code = ( - e.context.get("http-code") - if hasattr(e, "context") and e.context - else None - ) - - if http_code in [429, 503]: - logger.debug( - "Telemetry retries exhausted with status %d for host %s, triggering circuit breaker", - http_code, - self._host, - ) - raise TelemetryRateLimitError( - f"Telemetry rate limited after retries: {http_code}" - ) - - # NOT rate limiting (500 errors, network errors, timeouts, etc.) - # Wrap in TelemetryNonRateLimitError so circuit breaker excludes it - logger.debug( - "Non-rate-limit telemetry error for host %s: %s, wrapping to exclude from circuit breaker", - self._host, - e, - ) - raise TelemetryNonRateLimitError(e) from e - - def request( - self, - method: HttpMethod, - url: str, - headers: Optional[Dict[str, str]] = None, - **kwargs, - ) -> BaseHTTPResponse: - """ - Make an HTTP request with circuit breaker protection. - - Circuit breaker only opens for TelemetryRateLimitError (429/503 responses). - Other errors are wrapped in TelemetryNonRateLimitError and excluded from circuit breaker. - All exceptions propagate to caller (TelemetryClient callback handles them). 
- """ - try: - # Use circuit breaker to protect the request - # TelemetryRateLimitError will trigger circuit breaker - # TelemetryNonRateLimitError is excluded from circuit breaker - return self._circuit_breaker.call( - self._make_request_and_check_status, - method, - url, - headers, - **kwargs, - ) - - except TelemetryNonRateLimitError as e: - # Unwrap and re-raise original exception - # Circuit breaker didn't count this, but caller should handle it - logger.debug( - "Non-rate-limit telemetry error for host %s, re-raising original: %s", - self._host, - e.original_exception, - ) - raise e.original_exception from e - # All other exceptions (TelemetryRateLimitError, CircuitBreakerError) propagate as-is diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index b46784b10..9f96e8743 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -922,7 +922,4 @@ def build_client_context(server_hostname: str, version: str, **kwargs): proxy_auth_method=kwargs.get("_proxy_auth_method"), pool_connections=kwargs.get("_pool_connections"), pool_maxsize=kwargs.get("_pool_maxsize"), - telemetry_circuit_breaker_enabled=kwargs.get( - "_telemetry_circuit_breaker_enabled" - ), ) diff --git a/tests/e2e/test_circuit_breaker.py b/tests/e2e/test_circuit_breaker.py deleted file mode 100644 index 45c494d19..000000000 --- a/tests/e2e/test_circuit_breaker.py +++ /dev/null @@ -1,232 +0,0 @@ -""" -E2E tests for circuit breaker functionality in telemetry. - -This test suite verifies: -1. Circuit breaker opens after rate limit failures (429/503) -2. Circuit breaker blocks subsequent calls while open -3. Circuit breaker does not trigger for non-rate-limit errors -4. Circuit breaker can be disabled via configuration flag -5. Circuit breaker closes after reset timeout - -Run with: - pytest tests/e2e/test_circuit_breaker.py -v -s -""" - -import time -from unittest.mock import patch, MagicMock - -import pytest -from pybreaker import STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN -from urllib3 import HTTPResponse - -import databricks.sql as sql -from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager - - -@pytest.fixture(autouse=True) -def aggressive_circuit_breaker_config(): - """ - Configure circuit breaker to be aggressive for faster testing. - Opens after 2 failures instead of 20, with 5 second timeout. 
- """ - from databricks.sql.telemetry import circuit_breaker_manager - - original_minimum_calls = circuit_breaker_manager.MINIMUM_CALLS - original_reset_timeout = circuit_breaker_manager.RESET_TIMEOUT - - circuit_breaker_manager.MINIMUM_CALLS = 2 - circuit_breaker_manager.RESET_TIMEOUT = 5 - - CircuitBreakerManager._instances.clear() - - yield - - circuit_breaker_manager.MINIMUM_CALLS = original_minimum_calls - circuit_breaker_manager.RESET_TIMEOUT = original_reset_timeout - CircuitBreakerManager._instances.clear() - - -class TestCircuitBreakerTelemetry: - """Tests for circuit breaker functionality with telemetry""" - - @pytest.fixture(autouse=True) - def get_details(self, connection_details): - """Get connection details from pytest fixture""" - self.arguments = connection_details.copy() - - def create_mock_response(self, status_code): - """Helper to create mock HTTP response.""" - response = MagicMock(spec=HTTPResponse) - response.status = status_code - response.data = { - 429: b"Too Many Requests", - 503: b"Service Unavailable", - 500: b"Internal Server Error", - }.get(status_code, b"Response") - return response - - @pytest.mark.parametrize("status_code,should_trigger", [ - (429, True), - (503, True), - (500, False), - ]) - def test_circuit_breaker_triggers_for_rate_limit_codes(self, status_code, should_trigger): - """ - Verify circuit breaker opens for rate-limit codes (429/503) but not others (500). - """ - request_count = {"count": 0} - - def mock_request(*args, **kwargs): - request_count["count"] += 1 - return self.create_mock_response(status_code) - - with patch( - "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", - side_effect=mock_request, - ): - with sql.connect( - server_hostname=self.arguments["host"], - http_path=self.arguments["http_path"], - access_token=self.arguments.get("access_token"), - force_enable_telemetry=True, - telemetry_batch_size=1, - _telemetry_circuit_breaker_enabled=True, - ) as conn: - circuit_breaker = CircuitBreakerManager.get_circuit_breaker( - self.arguments["host"] - ) - - assert circuit_breaker.current_state == STATE_CLOSED - - cursor = conn.cursor() - - # Execute queries to trigger telemetry - for i in range(1, 6): - cursor.execute(f"SELECT {i}") - cursor.fetchone() - time.sleep(0.5) - - if should_trigger: - # Circuit should be OPEN after 2 rate-limit failures - assert circuit_breaker.current_state == STATE_OPEN - assert circuit_breaker.fail_counter == 2 - - # Track requests before another query - requests_before = request_count["count"] - cursor.execute("SELECT 99") - cursor.fetchone() - time.sleep(1) - - # No new telemetry requests (circuit is open) - assert request_count["count"] == requests_before - else: - # Circuit should remain CLOSED for non-rate-limit errors - assert circuit_breaker.current_state == STATE_CLOSED - assert circuit_breaker.fail_counter == 0 - assert request_count["count"] >= 5 - - def test_circuit_breaker_disabled_allows_all_calls(self): - """ - Verify that when circuit breaker is disabled, all calls go through - even with rate limit errors. 
- """ - request_count = {"count": 0} - - def mock_rate_limited_request(*args, **kwargs): - request_count["count"] += 1 - return self.create_mock_response(429) - - with patch( - "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", - side_effect=mock_rate_limited_request, - ): - with sql.connect( - server_hostname=self.arguments["host"], - http_path=self.arguments["http_path"], - access_token=self.arguments.get("access_token"), - force_enable_telemetry=True, - telemetry_batch_size=1, - _telemetry_circuit_breaker_enabled=False, # Disabled - ) as conn: - cursor = conn.cursor() - - for i in range(5): - cursor.execute(f"SELECT {i}") - cursor.fetchone() - time.sleep(0.3) - - assert request_count["count"] >= 5 - - def test_circuit_breaker_recovers_after_reset_timeout(self): - """ - Verify circuit breaker transitions to HALF_OPEN after reset timeout - and eventually CLOSES if requests succeed. - """ - request_count = {"count": 0} - fail_requests = {"enabled": True} - - def mock_conditional_request(*args, **kwargs): - request_count["count"] += 1 - status = 429 if fail_requests["enabled"] else 200 - return self.create_mock_response(status) - - with patch( - "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", - side_effect=mock_conditional_request, - ): - with sql.connect( - server_hostname=self.arguments["host"], - http_path=self.arguments["http_path"], - access_token=self.arguments.get("access_token"), - force_enable_telemetry=True, - telemetry_batch_size=1, - _telemetry_circuit_breaker_enabled=True, - ) as conn: - circuit_breaker = CircuitBreakerManager.get_circuit_breaker( - self.arguments["host"] - ) - - cursor = conn.cursor() - - # Trigger failures to open circuit - cursor.execute("SELECT 1") - cursor.fetchone() - time.sleep(1) - - cursor.execute("SELECT 2") - cursor.fetchone() - time.sleep(2) - - assert circuit_breaker.current_state == STATE_OPEN - - # Wait for reset timeout (5 seconds in test) - time.sleep(6) - - # Now make requests succeed - fail_requests["enabled"] = False - - # Execute query to trigger HALF_OPEN state - cursor.execute("SELECT 3") - cursor.fetchone() - time.sleep(1) - - # Circuit should be recovering - assert circuit_breaker.current_state in [ - STATE_HALF_OPEN, - STATE_CLOSED, - ], f"Circuit should be recovering, but is {circuit_breaker.current_state}" - - # Execute more queries to fully recover - cursor.execute("SELECT 4") - cursor.fetchone() - time.sleep(1) - - current_state = circuit_breaker.current_state - assert current_state in [ - STATE_CLOSED, - STATE_HALF_OPEN, - ], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}" - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) diff --git a/tests/e2e/test_telemetry_e2e.py b/tests/e2e/test_telemetry_e2e.py deleted file mode 100644 index 917c8e5eb..000000000 --- a/tests/e2e/test_telemetry_e2e.py +++ /dev/null @@ -1,343 +0,0 @@ -""" -E2E test for telemetry - verifies telemetry behavior with different scenarios -""" -import time -import threading -import logging -from contextlib import contextmanager -from unittest.mock import patch -import pytest -from concurrent.futures import wait - -import databricks.sql as sql -from databricks.sql.telemetry.telemetry_client import ( - TelemetryClient, - TelemetryClientFactory, -) - -log = logging.getLogger(__name__) - - -class TelemetryTestBase: - """Simplified test base class for telemetry e2e tests""" - - @pytest.fixture(autouse=True) - def get_details(self, connection_details): - self.arguments = 
connection_details.copy() - - def connection_params(self): - return { - "server_hostname": self.arguments["host"], - "http_path": self.arguments["http_path"], - "access_token": self.arguments.get("access_token"), - } - - @contextmanager - def connection(self, extra_params=()): - connection_params = dict(self.connection_params(), **dict(extra_params)) - log.info("Connecting with args: {}".format(connection_params)) - conn = sql.connect(**connection_params) - try: - yield conn - finally: - conn.close() - - -class TestTelemetryE2E(TelemetryTestBase): - """E2E tests for telemetry scenarios""" - - @pytest.fixture(autouse=True) - def telemetry_setup_teardown(self): - """Clean up telemetry client state before and after each test""" - try: - yield - finally: - if TelemetryClientFactory._executor: - TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryClientFactory._executor = None - TelemetryClientFactory._stop_flush_thread() - TelemetryClientFactory._initialized = False - - @pytest.fixture - def telemetry_interceptors(self): - """Setup reusable telemetry interceptors as a fixture""" - capture_lock = threading.Lock() - captured_events = [] - captured_futures = [] - - original_export = TelemetryClient._export_event - original_callback = TelemetryClient._telemetry_request_callback - - def export_wrapper(self_client, event): - with capture_lock: - captured_events.append(event) - return original_export(self_client, event) - - def callback_wrapper(self_client, future, sent_count): - with capture_lock: - captured_futures.append(future) - original_callback(self_client, future, sent_count) - - return captured_events, captured_futures, export_wrapper, callback_wrapper - - # ==================== ASSERTION HELPERS ==================== - - def assert_system_config(self, event): - """Assert system configuration fields""" - sys_config = event.entry.sql_driver_log.system_configuration - assert sys_config is not None - - # Check all required fields are non-empty - for field in ['driver_name', 'driver_version', 'os_name', 'os_version', - 'os_arch', 'runtime_name', 'runtime_version', 'runtime_vendor', - 'locale_name', 'char_set_encoding']: - value = getattr(sys_config, field) - assert value and len(value) > 0, f"{field} should not be None or empty" - - assert sys_config.driver_name == "Databricks SQL Python Connector" - - def assert_connection_params(self, event, expected_http_path=None): - """Assert connection parameters""" - conn_params = event.entry.sql_driver_log.driver_connection_params - assert conn_params is not None - assert conn_params.http_path - assert conn_params.host_info is not None - assert conn_params.auth_mech is not None - - if expected_http_path: - assert conn_params.http_path == expected_http_path - - if conn_params.socket_timeout is not None: - assert conn_params.socket_timeout > 0 - - def assert_statement_execution(self, event): - """Assert statement execution details""" - sql_op = event.entry.sql_driver_log.sql_operation - assert sql_op is not None - assert sql_op.statement_type is not None - assert sql_op.execution_result is not None - assert hasattr(sql_op, "retry_count") - - if sql_op.retry_count is not None: - assert sql_op.retry_count >= 0 - - latency = event.entry.sql_driver_log.operation_latency_ms - assert latency is not None and latency >= 0 - - def assert_error_info(self, event, expected_error_name=None): - """Assert error information""" - error_info = event.entry.sql_driver_log.error_info - assert error_info is not None - assert error_info.error_name and 
len(error_info.error_name) > 0 - assert error_info.stack_trace and len(error_info.stack_trace) > 0 - - if expected_error_name: - assert error_info.error_name == expected_error_name - - def verify_events(self, captured_events, captured_futures, expected_count): - """Common verification for event count and HTTP responses""" - if expected_count == 0: - assert len(captured_events) == 0, f"Expected 0 events, got {len(captured_events)}" - assert len(captured_futures) == 0, f"Expected 0 responses, got {len(captured_futures)}" - else: - assert len(captured_events) == expected_count, \ - f"Expected {expected_count} events, got {len(captured_events)}" - - time.sleep(2) - done, _ = wait(captured_futures, timeout=10) - assert len(done) == expected_count, \ - f"Expected {expected_count} responses, got {len(done)}" - - for future in done: - response = future.result() - assert 200 <= response.status < 300 - - # Assert common fields for all events - for event in captured_events: - self.assert_system_config(event) - self.assert_connection_params(event, self.arguments["http_path"]) - - # ==================== PARAMETERIZED TESTS ==================== - - @pytest.mark.parametrize("enable_telemetry,force_enable,expected_count,test_id", [ - (True, False, 2, "enable_on_force_off"), - (False, True, 2, "enable_off_force_on"), - (False, False, 0, "both_off"), - (None, None, 0, "default_behavior"), - ]) - def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry, - force_enable, expected_count, test_id): - """Test telemetry behavior with different flag combinations""" - captured_events, captured_futures, export_wrapper, callback_wrapper = \ - telemetry_interceptors - - with patch.object(TelemetryClient, "_export_event", export_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): - - extra_params = {"telemetry_batch_size": 1} - if enable_telemetry is not None: - extra_params["enable_telemetry"] = enable_telemetry - if force_enable is not None: - extra_params["force_enable_telemetry"] = force_enable - - with self.connection(extra_params=extra_params) as conn: - with conn.cursor() as cursor: - cursor.execute("SELECT 1") - cursor.fetchone() - - self.verify_events(captured_events, captured_futures, expected_count) - - # Assert statement execution on latency event (if events exist) - if expected_count > 0: - self.assert_statement_execution(captured_events[-1]) - - @pytest.mark.parametrize("query,expected_error", [ - ("SELECT * FROM WHERE INVALID SYNTAX 12345", "ServerOperationError"), - ("SELECT * FROM non_existent_table_xyz_12345", None), - ]) - def test_sql_errors(self, telemetry_interceptors, query, expected_error): - """Test telemetry captures error information for different SQL errors""" - captured_events, captured_futures, export_wrapper, callback_wrapper = \ - telemetry_interceptors - - with patch.object(TelemetryClient, "_export_event", export_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): - - with self.connection(extra_params={ - "force_enable_telemetry": True, - "telemetry_batch_size": 1, - }) as conn: - with conn.cursor() as cursor: - with pytest.raises(Exception): - cursor.execute(query) - cursor.fetchone() - - time.sleep(2) - wait(captured_futures, timeout=10) - - assert len(captured_events) >= 1 - - # Find event with error_info - error_event = next((e for e in captured_events - if e.entry.sql_driver_log.error_info), None) - assert error_event is not None - - self.assert_system_config(error_event) - 
self.assert_connection_params(error_event, self.arguments["http_path"]) - self.assert_error_info(error_event, expected_error) - - def test_metadata_operation(self, telemetry_interceptors): - """Test telemetry for metadata operations (getCatalogs)""" - captured_events, captured_futures, export_wrapper, callback_wrapper = \ - telemetry_interceptors - - with patch.object(TelemetryClient, "_export_event", export_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): - - with self.connection(extra_params={ - "force_enable_telemetry": True, - "telemetry_batch_size": 1, - }) as conn: - with conn.cursor() as cursor: - catalogs = cursor.catalogs() - catalogs.fetchall() - - time.sleep(2) - wait(captured_futures, timeout=10) - - assert len(captured_events) >= 1 - for event in captured_events: - self.assert_system_config(event) - self.assert_connection_params(event, self.arguments["http_path"]) - - def test_direct_results(self, telemetry_interceptors): - """Test telemetry with direct results (use_cloud_fetch=False)""" - captured_events, captured_futures, export_wrapper, callback_wrapper = \ - telemetry_interceptors - - with patch.object(TelemetryClient, "_export_event", export_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): - - with self.connection(extra_params={ - "force_enable_telemetry": True, - "telemetry_batch_size": 1, - "use_cloud_fetch": False, - }) as conn: - with conn.cursor() as cursor: - cursor.execute("SELECT 100") - result = cursor.fetchall() - assert len(result) == 1 and result[0][0] == 100 - - time.sleep(2) - wait(captured_futures, timeout=10) - - assert len(captured_events) >= 2 - for event in captured_events: - self.assert_system_config(event) - self.assert_connection_params(event, self.arguments["http_path"]) - - self.assert_statement_execution(captured_events[-1]) - - @pytest.mark.parametrize("close_type", [ - "context_manager", - "explicit_cursor", - "explicit_connection", - "implicit_fetchall", - ]) - def test_cloudfetch_with_different_close_patterns(self, telemetry_interceptors, - close_type): - """Test telemetry with cloud fetch using different resource closing patterns""" - captured_events, captured_futures, export_wrapper, callback_wrapper = \ - telemetry_interceptors - - with patch.object(TelemetryClient, "_export_event", export_wrapper), \ - patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): - - if close_type == "explicit_connection": - # Test explicit connection close - conn = sql.connect( - **self.connection_params(), - force_enable_telemetry=True, - telemetry_batch_size=1, - use_cloud_fetch=True, - ) - cursor = conn.cursor() - cursor.execute("SELECT * FROM range(1000)") - result = cursor.fetchall() - assert len(result) == 1000 - conn.close() - else: - # Other patterns use connection context manager - with self.connection(extra_params={ - "force_enable_telemetry": True, - "telemetry_batch_size": 1, - "use_cloud_fetch": True, - }) as conn: - if close_type == "context_manager": - with conn.cursor() as cursor: - cursor.execute("SELECT * FROM range(1000)") - result = cursor.fetchall() - assert len(result) == 1000 - - elif close_type == "explicit_cursor": - cursor = conn.cursor() - cursor.execute("SELECT * FROM range(1000)") - result = cursor.fetchall() - assert len(result) == 1000 - cursor.close() - - elif close_type == "implicit_fetchall": - cursor = conn.cursor() - cursor.execute("SELECT * FROM range(1000)") - result = cursor.fetchall() - assert len(result) == 
1000 - - time.sleep(2) - wait(captured_futures, timeout=10) - - assert len(captured_events) >= 2 - for event in captured_events: - self.assert_system_config(event) - self.assert_connection_params(event, self.arguments["http_path"]) - - self.assert_statement_execution(captured_events[-1]) diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py deleted file mode 100644 index d4f6a790a..000000000 --- a/tests/e2e/test_transactions.py +++ /dev/null @@ -1,598 +0,0 @@ -""" -End-to-end integration tests for Multi-Statement Transaction (MST) APIs. - -These tests verify: -- autocommit property (getter/setter) -- commit() and rollback() methods -- get_transaction_isolation() and set_transaction_isolation() methods -- Transaction error handling - -Requirements: -- DBSQL warehouse that supports Multi-Statement Transactions (MST) -- Test environment configured via test.env file or environment variables - -Setup: -Set the following environment variables: -- DATABRICKS_SERVER_HOSTNAME -- DATABRICKS_HTTP_PATH -- DATABRICKS_ACCESS_TOKEN (or use OAuth) - -Usage: - pytest tests/e2e/test_transactions.py -v -""" - -import logging -import os -import pytest -from typing import Any, Dict - -import databricks.sql as sql -from databricks.sql import TransactionError, NotSupportedError, InterfaceError - -logger = logging.getLogger(__name__) - - -@pytest.mark.skip( - reason="Test environment does not yet support multi-statement transactions" -) -class TestTransactions: - """E2E tests for transaction control methods (MST support).""" - - # Test table name - TEST_TABLE_NAME = "transaction_test_table" - - @pytest.fixture(autouse=True) - def setup_and_teardown(self, connection_details): - """Setup test environment before each test and cleanup after.""" - self.connection_params = { - "server_hostname": connection_details["host"], - "http_path": connection_details["http_path"], - "access_token": connection_details.get("access_token"), - "ignore_transactions": False, # Enable actual transaction functionality for these tests - } - - # Get catalog and schema from environment or use defaults - self.catalog = os.getenv("DATABRICKS_CATALOG", "main") - self.schema = os.getenv("DATABRICKS_SCHEMA", "default") - - # Create connection for setup - self.connection = sql.connect(**self.connection_params) - - # Setup: Create test table - self._create_test_table() - - yield - - # Teardown: Cleanup - self._cleanup() - - def _get_fully_qualified_table_name(self) -> str: - """Get the fully qualified table name.""" - return f"{self.catalog}.{self.schema}.{self.TEST_TABLE_NAME}" - - def _create_test_table(self): - """Create the test table with Delta format and MST support.""" - fq_table_name = self._get_fully_qualified_table_name() - cursor = self.connection.cursor() - - try: - # Drop if exists - cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") - - # Create table with Delta and catalog-owned feature for MST compatibility - cursor.execute( - f""" - CREATE TABLE IF NOT EXISTS {fq_table_name} - (id INT, value STRING) - USING DELTA - TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') - """ - ) - - logger.info(f"Created test table: {fq_table_name}") - finally: - cursor.close() - - def _cleanup(self): - """Cleanup after test: rollback pending transactions, drop table, close connection.""" - try: - # Try to rollback any pending transaction - if ( - self.connection - and self.connection.open - and not self.connection.autocommit - ): - try: - self.connection.rollback() - except Exception as e: - logger.debug( - 
f"Rollback during cleanup failed (may be expected): {e}" - ) - - # Reset to autocommit mode - try: - self.connection.autocommit = True - except Exception as e: - logger.debug(f"Reset autocommit during cleanup failed: {e}") - - # Drop test table - if self.connection and self.connection.open: - fq_table_name = self._get_fully_qualified_table_name() - cursor = self.connection.cursor() - try: - cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") - logger.info(f"Dropped test table: {fq_table_name}") - except Exception as e: - logger.warning(f"Failed to drop test table: {e}") - finally: - cursor.close() - - finally: - # Close connection - if self.connection: - self.connection.close() - - # ==================== BASIC AUTOCOMMIT TESTS ==================== - - def test_default_autocommit_is_true(self): - """Test that new connection defaults to autocommit=true.""" - assert ( - self.connection.autocommit is True - ), "New connection should have autocommit=true by default" - - def test_set_autocommit_to_false(self): - """Test successfully setting autocommit to false.""" - self.connection.autocommit = False - assert ( - self.connection.autocommit is False - ), "autocommit should be false after setting to false" - - def test_set_autocommit_to_true(self): - """Test successfully setting autocommit back to true.""" - # First disable - self.connection.autocommit = False - assert self.connection.autocommit is False - - # Then enable - self.connection.autocommit = True - assert ( - self.connection.autocommit is True - ), "autocommit should be true after setting to true" - - # ==================== COMMIT TESTS ==================== - - def test_commit_single_insert(self): - """Test successfully committing a transaction with single INSERT.""" - fq_table_name = self._get_fully_qualified_table_name() - - # Start transaction - self.connection.autocommit = False - - # Insert data - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'test_value')" - ) - cursor.close() - - # Commit - self.connection.commit() - - # Verify data is persisted using a new connection - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - verify_cursor.close() - - assert result is not None, "Should find inserted row after commit" - assert result[0] == "test_value", "Value should match inserted value" - finally: - verify_conn.close() - - def test_commit_multiple_inserts(self): - """Test successfully committing a transaction with multiple INSERTs.""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False - - # Insert multiple rows - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'value1')") - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'value2')") - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'value3')") - cursor.close() - - self.connection.commit() - - # Verify all rows persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name}") - result = verify_cursor.fetchone() - verify_cursor.close() - - assert result[0] == 3, "Should have 3 rows after commit" - finally: - verify_conn.close() - - # ==================== ROLLBACK TESTS ==================== - - def 
test_rollback_single_insert(self): - """Test successfully rolling back a transaction.""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False - - # Insert data - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table_name} (id, value) VALUES (100, 'rollback_test')" - ) - cursor.close() - - # Rollback - self.connection.rollback() - - # Verify data is NOT persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 100" - ) - result = verify_cursor.fetchone() - verify_cursor.close() - - assert result[0] == 0, "Rolled back data should not be persisted" - finally: - verify_conn.close() - - # ==================== SEQUENTIAL TRANSACTION TESTS ==================== - - def test_multiple_sequential_transactions(self): - """Test executing multiple sequential transactions (commit, commit, rollback).""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False - - # First transaction - commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'txn1')") - cursor.close() - self.connection.commit() - - # Second transaction - commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'txn2')") - cursor.close() - self.connection.commit() - - # Third transaction - rollback - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'txn3')") - cursor.close() - self.connection.rollback() - - # Verify only first two transactions persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table_name} WHERE id IN (1, 2)" - ) - result = verify_cursor.fetchone() - assert result[0] == 2, "Should have 2 committed rows" - - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 3") - result = verify_cursor.fetchone() - assert result[0] == 0, "Rolled back row should not exist" - verify_cursor.close() - finally: - verify_conn.close() - - def test_auto_start_transaction_after_commit(self): - """Test that new transaction automatically starts after commit.""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False - - # First transaction - commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") - cursor.close() - self.connection.commit() - - # New transaction should start automatically - insert and rollback - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") - cursor.close() - self.connection.rollback() - - # Verify: first committed, second rolled back - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - assert result[0] == 1, "First insert should be committed" - - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") - result = verify_cursor.fetchone() - assert result[0] == 0, "Second insert should be rolled back" - verify_cursor.close() - finally: - verify_conn.close() - - def test_auto_start_transaction_after_rollback(self): - """Test that new transaction automatically 
starts after rollback.""" - fq_table_name = self._get_fully_qualified_table_name() - - self.connection.autocommit = False - - # First transaction - rollback - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") - cursor.close() - self.connection.rollback() - - # New transaction should start automatically - insert and commit - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") - cursor.close() - self.connection.commit() - - # Verify: first rolled back, second committed - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - assert result[0] == 0, "First insert should be rolled back" - - verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") - result = verify_cursor.fetchone() - assert result[0] == 1, "Second insert should be committed" - verify_cursor.close() - finally: - verify_conn.close() - - # ==================== UPDATE/DELETE OPERATION TESTS ==================== - - def test_update_in_transaction(self): - """Test UPDATE operation in transaction.""" - fq_table_name = self._get_fully_qualified_table_name() - - # First insert a row with autocommit - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'original')" - ) - cursor.close() - - # Start transaction and update - self.connection.autocommit = False - cursor = self.connection.cursor() - cursor.execute(f"UPDATE {fq_table_name} SET value = 'updated' WHERE id = 1") - cursor.close() - self.connection.commit() - - # Verify update persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") - result = verify_cursor.fetchone() - assert result[0] == "updated", "Value should be updated after commit" - verify_cursor.close() - finally: - verify_conn.close() - - # ==================== MULTI-TABLE TRANSACTION TESTS ==================== - - def test_multi_table_transaction_commit(self): - """Test atomic commit across multiple tables.""" - fq_table1_name = self._get_fully_qualified_table_name() - table2_name = self.TEST_TABLE_NAME + "_2" - fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" - - # Create second table - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.execute( - f""" - CREATE TABLE IF NOT EXISTS {fq_table2_name} - (id INT, category STRING) - USING DELTA - TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') - """ - ) - cursor.close() - - try: - # Start transaction and insert into both tables - self.connection.autocommit = False - - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table1_name} (id, value) VALUES (10, 'table1_data')" - ) - cursor.execute( - f"INSERT INTO {fq_table2_name} (id, category) VALUES (10, 'table2_data')" - ) - cursor.close() - - # Commit both atomically - self.connection.commit() - - # Verify both inserts persisted - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 10" - ) - result = verify_cursor.fetchone() - assert result[0] == 1, "Table1 insert should be committed" - - verify_cursor.execute( - f"SELECT 
COUNT(*) FROM {fq_table2_name} WHERE id = 10" - ) - result = verify_cursor.fetchone() - assert result[0] == 1, "Table2 insert should be committed" - - verify_cursor.close() - finally: - verify_conn.close() - - finally: - # Cleanup second table - self.connection.autocommit = True - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.close() - - def test_multi_table_transaction_rollback(self): - """Test atomic rollback across multiple tables.""" - fq_table1_name = self._get_fully_qualified_table_name() - table2_name = self.TEST_TABLE_NAME + "_2" - fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" - - # Create second table - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.execute( - f""" - CREATE TABLE IF NOT EXISTS {fq_table2_name} - (id INT, category STRING) - USING DELTA - TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') - """ - ) - cursor.close() - - try: - # Start transaction and insert into both tables - self.connection.autocommit = False - - cursor = self.connection.cursor() - cursor.execute( - f"INSERT INTO {fq_table1_name} (id, value) VALUES (20, 'rollback1')" - ) - cursor.execute( - f"INSERT INTO {fq_table2_name} (id, category) VALUES (20, 'rollback2')" - ) - cursor.close() - - # Rollback both atomically - self.connection.rollback() - - # Verify both inserts were rolled back - verify_conn = sql.connect(**self.connection_params) - try: - verify_cursor = verify_conn.cursor() - - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 20" - ) - result = verify_cursor.fetchone() - assert result[0] == 0, "Table1 insert should be rolled back" - - verify_cursor.execute( - f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 20" - ) - result = verify_cursor.fetchone() - assert result[0] == 0, "Table2 insert should be rolled back" - - verify_cursor.close() - finally: - verify_conn.close() - - finally: - # Cleanup second table - self.connection.autocommit = True - cursor = self.connection.cursor() - cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") - cursor.close() - - # ==================== ERROR HANDLING TESTS ==================== - - def test_set_autocommit_during_active_transaction(self): - """Test that setting autocommit during an active transaction throws error.""" - fq_table_name = self._get_fully_qualified_table_name() - - # Start transaction - self.connection.autocommit = False - cursor = self.connection.cursor() - cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (99, 'test')") - cursor.close() - - # Try to set autocommit=True during active transaction - with pytest.raises(TransactionError) as exc_info: - self.connection.autocommit = True - - # Verify error message mentions autocommit or active transaction - error_msg = str(exc_info.value).lower() - assert ( - "autocommit" in error_msg or "active transaction" in error_msg - ), "Error should mention autocommit or active transaction" - - # Cleanup - rollback the transaction - self.connection.rollback() - - def test_commit_without_active_transaction_throws_error(self): - """Test that commit() throws error when autocommit=true (no active transaction).""" - # Ensure autocommit is true (default) - assert self.connection.autocommit is True - - # Attempt commit without active transaction should throw - with pytest.raises(TransactionError) as exc_info: - self.connection.commit() - - # Verify error message indicates no active transaction - error_message = 
str(exc_info.value) - assert ( - "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION" in error_message - or "no active transaction" in error_message.lower() - ), "Error should indicate no active transaction" - - def test_rollback_without_active_transaction_is_safe(self): - """Test that rollback() without active transaction is a safe no-op.""" - # With autocommit=true (no active transaction) - assert self.connection.autocommit is True - - # ROLLBACK should be safe (no exception) - self.connection.rollback() - - # Verify connection is still usable - assert self.connection.autocommit is True - assert self.connection.open is True - - # ==================== TRANSACTION ISOLATION TESTS ==================== - - def test_get_transaction_isolation_returns_repeatable_read(self): - """Test that get_transaction_isolation() returns REPEATABLE_READ.""" - isolation_level = self.connection.get_transaction_isolation() - assert ( - isolation_level == "REPEATABLE_READ" - ), "Databricks MST should use REPEATABLE_READ (Snapshot Isolation)" - - def test_set_transaction_isolation_accepts_repeatable_read(self): - """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" - # Should not raise - these are all valid formats - self.connection.set_transaction_isolation("REPEATABLE_READ") - self.connection.set_transaction_isolation("REPEATABLE READ") - self.connection.set_transaction_isolation("repeatable_read") - self.connection.set_transaction_isolation("repeatable read") - - def test_set_transaction_isolation_rejects_unsupported_level(self): - """Test that set_transaction_isolation() rejects unsupported levels.""" - with pytest.raises(NotSupportedError) as exc_info: - self.connection.set_transaction_isolation("READ_COMMITTED") - - error_message = str(exc_info.value) - assert "not supported" in error_message.lower() - assert "READ_COMMITTED" in error_message diff --git a/tests/unit/test_circuit_breaker_http_client.py b/tests/unit/test_circuit_breaker_http_client.py deleted file mode 100644 index 432ca1be3..000000000 --- a/tests/unit/test_circuit_breaker_http_client.py +++ /dev/null @@ -1,208 +0,0 @@ -""" -Unit tests for telemetry push client functionality. 
-""" - -import pytest -from unittest.mock import Mock, patch, MagicMock - -from databricks.sql.telemetry.telemetry_push_client import ( - ITelemetryPushClient, - TelemetryPushClient, - CircuitBreakerTelemetryPushClient, -) -from databricks.sql.common.http import HttpMethod -from pybreaker import CircuitBreakerError - - -class TestTelemetryPushClient: - """Test cases for TelemetryPushClient.""" - - def setup_method(self): - """Set up test fixtures.""" - self.mock_http_client = Mock() - self.client = TelemetryPushClient(self.mock_http_client) - - def test_initialization(self): - """Test client initialization.""" - assert self.client._http_client == self.mock_http_client - - def test_request_delegates_to_http_client(self): - """Test that request delegates to underlying HTTP client.""" - mock_response = Mock() - self.mock_http_client.request.return_value = mock_response - - response = self.client.request(HttpMethod.POST, "https://test.com", {}) - - assert response == mock_response - self.mock_http_client.request.assert_called_once() - - def test_direct_client_has_no_circuit_breaker(self): - """Test that direct client does not have circuit breaker functionality.""" - # Direct client should work without circuit breaker - assert isinstance(self.client, TelemetryPushClient) - - -class TestCircuitBreakerTelemetryPushClient: - """Test cases for CircuitBreakerTelemetryPushClient.""" - - def setup_method(self): - """Set up test fixtures.""" - self.mock_delegate = Mock(spec=ITelemetryPushClient) - self.host = "test-host.example.com" - self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) - - def test_initialization(self): - """Test client initialization.""" - assert self.client._delegate == self.mock_delegate - assert self.client._host == self.host - assert self.client._circuit_breaker is not None - - def test_request_enabled_success(self): - """Test successful request when circuit breaker is enabled.""" - mock_response = Mock() - self.mock_delegate.request.return_value = mock_response - - response = self.client.request(HttpMethod.POST, "https://test.com", {}) - - assert response == mock_response - self.mock_delegate.request.assert_called_once() - - def test_request_enabled_circuit_breaker_error(self): - """Test request when circuit breaker is open - should raise CircuitBreakerError.""" - # Mock circuit breaker to raise CircuitBreakerError - with patch.object( - self.client._circuit_breaker, - "call", - side_effect=CircuitBreakerError("Circuit is open"), - ): - # Circuit breaker open should raise (caller handles it) - with pytest.raises(CircuitBreakerError): - self.client.request(HttpMethod.POST, "https://test.com", {}) - - def test_request_enabled_other_error(self): - """Test request when other error occurs - should raise original exception.""" - # Mock delegate to raise a different error (not rate limiting) - self.mock_delegate.request.side_effect = ValueError("Network error") - - # Non-rate-limit errors are unwrapped and raised - with pytest.raises(ValueError, match="Network error"): - self.client.request(HttpMethod.POST, "https://test.com", {}) - - def test_is_circuit_breaker_enabled(self): - """Test checking if circuit breaker is enabled.""" - assert self.client._circuit_breaker is not None - - def test_circuit_breaker_state_logging(self): - """Test that circuit breaker errors are raised (no longer silent).""" - with patch.object( - self.client._circuit_breaker, - "call", - side_effect=CircuitBreakerError("Circuit is open"), - ): - # Should raise CircuitBreakerError 
(caller handles it) - with pytest.raises(CircuitBreakerError): - self.client.request(HttpMethod.POST, "https://test.com", {}) - - def test_other_error_logging(self): - """Test that other errors are wrapped, logged, then unwrapped and raised.""" - with patch( - "databricks.sql.telemetry.telemetry_push_client.logger" - ) as mock_logger: - self.mock_delegate.request.side_effect = ValueError("Network error") - - # Should raise the original ValueError - with pytest.raises(ValueError, match="Network error"): - self.client.request(HttpMethod.POST, "https://test.com", {}) - - # Check that debug was logged (for wrapping and/or unwrapping) - assert mock_logger.debug.call_count >= 1 - - -class TestCircuitBreakerTelemetryPushClientIntegration: - """Integration tests for CircuitBreakerTelemetryPushClient.""" - - def setup_method(self): - """Set up test fixtures.""" - self.mock_delegate = Mock() - self.host = "test-host.example.com" - - def test_circuit_breaker_opens_after_failures(self): - """Test that circuit breaker opens after repeated failures (429/503 errors).""" - from databricks.sql.telemetry.circuit_breaker_manager import ( - CircuitBreakerManager, - MINIMUM_CALLS, - ) - from databricks.sql.exc import TelemetryRateLimitError - - # Clear any existing state - CircuitBreakerManager._instances.clear() - - client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) - - # Simulate rate limit failures (429) - mock_response = Mock() - mock_response.status = 429 - self.mock_delegate.request.return_value = mock_response - - # All calls should raise TelemetryRateLimitError - # After MINIMUM_CALLS failures, circuit breaker opens - rate_limit_error_count = 0 - circuit_breaker_error_count = 0 - - for i in range(MINIMUM_CALLS + 5): - try: - client.request(HttpMethod.POST, "https://test.com", {}) - except TelemetryRateLimitError: - rate_limit_error_count += 1 - except CircuitBreakerError: - circuit_breaker_error_count += 1 - - # Should have some rate limit errors before circuit opens, then circuit breaker errors - assert rate_limit_error_count >= MINIMUM_CALLS - 1 - assert circuit_breaker_error_count > 0 - - def test_circuit_breaker_recovers_after_success(self): - """Test that circuit breaker recovers after successful calls.""" - from databricks.sql.telemetry.circuit_breaker_manager import ( - CircuitBreakerManager, - MINIMUM_CALLS, - RESET_TIMEOUT, - ) - import time - - # Clear any existing state - CircuitBreakerManager._instances.clear() - - client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) - - # Simulate rate limit failures first (429) - from databricks.sql.exc import TelemetryRateLimitError - from pybreaker import CircuitBreakerError - - mock_rate_limit_response = Mock() - mock_rate_limit_response.status = 429 - self.mock_delegate.request.return_value = mock_rate_limit_response - - # Trigger enough rate limit failures to open circuit - for i in range(MINIMUM_CALLS + 5): - try: - client.request(HttpMethod.POST, "https://test.com", {}) - except (TelemetryRateLimitError, CircuitBreakerError): - pass # Expected - circuit breaker opens after MINIMUM_CALLS failures - - # Circuit should be open now - raises CircuitBreakerError - with pytest.raises(CircuitBreakerError): - client.request(HttpMethod.POST, "https://test.com", {}) - - # Wait for reset timeout - time.sleep(RESET_TIMEOUT + 1.0) - - # Simulate successful calls (200 response) - mock_success_response = Mock() - mock_success_response.status = 200 - self.mock_delegate.request.return_value = mock_success_response - - # 
Should work again with actual success response - response = client.request(HttpMethod.POST, "https://test.com", {}) - assert response is not None - assert response.status == 200 diff --git a/tests/unit/test_circuit_breaker_manager.py b/tests/unit/test_circuit_breaker_manager.py deleted file mode 100644 index e8ed4e809..000000000 --- a/tests/unit/test_circuit_breaker_manager.py +++ /dev/null @@ -1,160 +0,0 @@ -""" -Unit tests for circuit breaker manager functionality. -""" - -import pytest -import threading -import time -from unittest.mock import Mock, patch - -from databricks.sql.telemetry.circuit_breaker_manager import ( - CircuitBreakerManager, - MINIMUM_CALLS, - RESET_TIMEOUT, - NAME_PREFIX as CIRCUIT_BREAKER_NAME, -) -from pybreaker import CircuitBreakerError - - -class TestCircuitBreakerManager: - """Test cases for CircuitBreakerManager.""" - - def setup_method(self): - """Set up test fixtures.""" - CircuitBreakerManager._instances.clear() - - def teardown_method(self): - """Clean up after tests.""" - CircuitBreakerManager._instances.clear() - - def test_get_circuit_breaker_creates_instance(self): - """Test getting circuit breaker creates instance with correct config.""" - breaker = CircuitBreakerManager.get_circuit_breaker("test-host") - - assert breaker.name == "telemetry-circuit-breaker-test-host" - assert breaker.fail_max == MINIMUM_CALLS - - def test_get_circuit_breaker_same_host_returns_same_instance(self): - """Test that same host returns same circuit breaker instance.""" - breaker1 = CircuitBreakerManager.get_circuit_breaker("test-host") - breaker2 = CircuitBreakerManager.get_circuit_breaker("test-host") - - assert breaker1 is breaker2 - - def test_get_circuit_breaker_different_hosts_return_different_instances(self): - """Test that different hosts return different circuit breaker instances.""" - breaker1 = CircuitBreakerManager.get_circuit_breaker("host1") - breaker2 = CircuitBreakerManager.get_circuit_breaker("host2") - - assert breaker1 is not breaker2 - assert breaker1.name != breaker2.name - - def test_thread_safety(self): - """Test thread safety of circuit breaker manager.""" - results = [] - - def get_breaker(host): - breaker = CircuitBreakerManager.get_circuit_breaker(host) - results.append(breaker) - - threads = [] - for i in range(10): - thread = threading.Thread(target=get_breaker, args=(f"host{i % 3}",)) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - assert len(results) == 10 - - # All breakers for same host should be same instance - host0_breakers = [b for b in results if b.name.endswith("host0")] - assert all(b is host0_breakers[0] for b in host0_breakers) - - -class TestCircuitBreakerIntegration: - """Integration tests for circuit breaker functionality.""" - - def setup_method(self): - """Set up test fixtures.""" - CircuitBreakerManager._instances.clear() - - def teardown_method(self): - """Clean up after tests.""" - CircuitBreakerManager._instances.clear() - - def test_circuit_breaker_state_transitions(self): - """Test circuit breaker state transitions from closed to open.""" - breaker = CircuitBreakerManager.get_circuit_breaker("test-host") - - assert breaker.current_state == "closed" - - def failing_func(): - raise Exception("Simulated failure") - - # Trigger failures up to the threshold (MINIMUM_CALLS = 20) - for _ in range(MINIMUM_CALLS): - with pytest.raises(Exception): - breaker.call(failing_func) - - # Next call should fail with CircuitBreakerError (circuit is now open) - with pytest.raises(CircuitBreakerError): - 
breaker.call(failing_func) - - assert breaker.current_state == "open" - - def test_circuit_breaker_recovery(self): - """Test circuit breaker recovery after failures.""" - breaker = CircuitBreakerManager.get_circuit_breaker("test-host") - - def failing_func(): - raise Exception("Simulated failure") - - # Trigger failures up to the threshold - for _ in range(MINIMUM_CALLS): - with pytest.raises(Exception): - breaker.call(failing_func) - - assert breaker.current_state == "open" - - # Wait for reset timeout - time.sleep(RESET_TIMEOUT + 1.0) - - # Try successful call to close circuit breaker - def successful_func(): - return "success" - - try: - result = breaker.call(successful_func) - assert result == "success" - except CircuitBreakerError: - pass # Circuit might still be open, acceptable - - assert breaker.current_state in ["closed", "half-open", "open"] - - @pytest.mark.parametrize("old_state,new_state", [ - ("closed", "open"), - ("open", "half-open"), - ("half-open", "closed"), - ("closed", "half-open"), - ]) - def test_circuit_breaker_state_listener_transitions(self, old_state, new_state): - """Test circuit breaker state listener logs all state transitions.""" - from databricks.sql.telemetry.circuit_breaker_manager import ( - CircuitBreakerStateListener, - ) - - listener = CircuitBreakerStateListener() - mock_cb = Mock() - mock_cb.name = "test-breaker" - - mock_old_state = Mock() - mock_old_state.name = old_state - - mock_new_state = Mock() - mock_new_state.name = new_state - - with patch("databricks.sql.telemetry.circuit_breaker_manager.logger") as mock_logger: - listener.state_change(mock_cb, mock_old_state, mock_new_state) - mock_logger.info.assert_called() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index b515756e8..19375cde3 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -22,13 +22,7 @@ import databricks.sql import databricks.sql.client as client -from databricks.sql import ( - InterfaceError, - DatabaseError, - Error, - NotSupportedError, - TransactionError, -) +from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState @@ -445,6 +439,11 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set(self): "last operation", ) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_commit_a_noop(self, mock_thrift_backend_class): + c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + c.commit() + def test_setinputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setinputsizes(1) @@ -453,6 +452,12 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_rollback_not_supported(self, mock_thrift_backend_class): + c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + with self.assertRaises(NotSupportedError): + c.rollback() + @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_row_number_respected(self, mock_thrift_backend_class): @@ -634,469 +639,11 @@ def mock_close_normal(): ) -class TransactionTestSuite(unittest.TestCase): - """ - Unit tests for transaction control methods (MST support). 
- """ - - PACKAGE_NAME = "databricks.sql" - DUMMY_CONNECTION_ARGS = { - "server_hostname": "foo", - "http_path": "dummy_path", - "access_token": "tok", - } - - def _create_mock_connection(self, mock_session_class): - """Helper to create a mocked connection for transaction tests.""" - # Mock session - mock_session = Mock() - mock_session.is_open = True - mock_session.guid_hex = "test-session-id" - mock_session.get_autocommit.return_value = True - mock_session_class.return_value = mock_session - - # Create connection with ignore_transactions=False to test actual transaction functionality - conn = client.Connection( - ignore_transactions=False, **self.DUMMY_CONNECTION_ARGS - ) - return conn - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_autocommit_getter_returns_cached_value(self, mock_session_class): - """Test that autocommit property returns cached session value by default.""" - conn = self._create_mock_connection(mock_session_class) - - # Get autocommit (should use cached value) - result = conn.autocommit - - conn.session.get_autocommit.assert_called_once() - self.assertTrue(result) - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_autocommit_setter_executes_sql(self, mock_session_class): - """Test that setting autocommit executes SET AUTOCOMMIT command.""" - conn = self._create_mock_connection(mock_session_class) - - mock_cursor = Mock() - with patch.object(conn, "cursor", return_value=mock_cursor): - conn.autocommit = False - - # Verify SQL was executed - mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT = FALSE") - mock_cursor.close.assert_called_once() - - conn.session.set_autocommit.assert_called_once_with(False) - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_autocommit_setter_with_true_value(self, mock_session_class): - """Test setting autocommit to True.""" - conn = self._create_mock_connection(mock_session_class) - - mock_cursor = Mock() - with patch.object(conn, "cursor", return_value=mock_cursor): - conn.autocommit = True - - mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT = TRUE") - conn.session.set_autocommit.assert_called_once_with(True) - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_autocommit_setter_wraps_database_error(self, mock_session_class): - """Test that autocommit setter wraps DatabaseError in TransactionError.""" - conn = self._create_mock_connection(mock_session_class) - - mock_cursor = Mock() - server_error = DatabaseError( - "AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION", - context={"sql_state": "25000"}, - session_id_hex="test-session-id", - ) - mock_cursor.execute.side_effect = server_error - - with patch.object(conn, "cursor", return_value=mock_cursor): - with self.assertRaises(TransactionError) as ctx: - conn.autocommit = False - - self.assertIn("Failed to set autocommit", str(ctx.exception)) - self.assertEqual(ctx.exception.context["operation"], "set_autocommit") - self.assertEqual(ctx.exception.context["autocommit_value"], False) - - mock_cursor.close.assert_called_once() - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): - """Test that exception chaining is preserved.""" - conn = self._create_mock_connection(mock_session_class) - - mock_cursor = Mock() - original_error = DatabaseError( - "Original error", session_id_hex="test-session-id" - ) - mock_cursor.execute.side_effect = original_error - - with patch.object(conn, "cursor", return_value=mock_cursor): 
- with self.assertRaises(TransactionError) as ctx: - conn.autocommit = False - - self.assertEqual(ctx.exception.__cause__, original_error) - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_commit_executes_sql(self, mock_session_class): - """Test that commit() executes COMMIT command.""" - conn = self._create_mock_connection(mock_session_class) - - mock_cursor = Mock() - with patch.object(conn, "cursor", return_value=mock_cursor): - conn.commit() - - mock_cursor.execute.assert_called_once_with("COMMIT") - mock_cursor.close.assert_called_once() - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_commit_wraps_database_error(self, mock_session_class): - """Test that commit() wraps DatabaseError in TransactionError.""" - conn = self._create_mock_connection(mock_session_class) - - mock_cursor = Mock() - server_error = DatabaseError( - "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION", - context={"sql_state": "25000"}, - session_id_hex="test-session-id", - ) - mock_cursor.execute.side_effect = server_error - - with patch.object(conn, "cursor", return_value=mock_cursor): - with self.assertRaises(TransactionError) as ctx: - conn.commit() - - self.assertIn("Failed to commit", str(ctx.exception)) - self.assertEqual(ctx.exception.context["operation"], "commit") - mock_cursor.close.assert_called_once() - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_commit_on_closed_connection_raises_interface_error( - self, mock_session_class - ): - """Test that commit() on closed connection raises InterfaceError.""" - conn = self._create_mock_connection(mock_session_class) - conn.session.is_open = False - - with self.assertRaises(InterfaceError) as ctx: - conn.commit() - - self.assertIn("Cannot commit on closed connection", str(ctx.exception)) - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_rollback_executes_sql(self, mock_session_class): - """Test that rollback() executes ROLLBACK command.""" - conn = self._create_mock_connection(mock_session_class) - - mock_cursor = Mock() - with patch.object(conn, "cursor", return_value=mock_cursor): - conn.rollback() - - mock_cursor.execute.assert_called_once_with("ROLLBACK") - mock_cursor.close.assert_called_once() - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_rollback_wraps_database_error(self, mock_session_class): - """Test that rollback() wraps DatabaseError in TransactionError.""" - conn = self._create_mock_connection(mock_session_class) - - mock_cursor = Mock() - server_error = DatabaseError( - "Unexpected rollback error", - context={"sql_state": "HY000"}, - session_id_hex="test-session-id", - ) - mock_cursor.execute.side_effect = server_error - - with patch.object(conn, "cursor", return_value=mock_cursor): - with self.assertRaises(TransactionError) as ctx: - conn.rollback() - - self.assertIn("Failed to rollback", str(ctx.exception)) - self.assertEqual(ctx.exception.context["operation"], "rollback") - mock_cursor.close.assert_called_once() - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_rollback_on_closed_connection_raises_interface_error( - self, mock_session_class - ): - """Test that rollback() on closed connection raises InterfaceError.""" - conn = self._create_mock_connection(mock_session_class) - conn.session.is_open = False - - with self.assertRaises(InterfaceError) as ctx: - conn.rollback() - - self.assertIn("Cannot rollback on closed connection", str(ctx.exception)) - - @patch("%s.client.Session" % PACKAGE_NAME) - def 
test_get_transaction_isolation_returns_repeatable_read( - self, mock_session_class - ): - """Test that get_transaction_isolation() returns REPEATABLE_READ.""" - conn = self._create_mock_connection(mock_session_class) - - result = conn.get_transaction_isolation() - - self.assertEqual(result, "REPEATABLE_READ") - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_get_transaction_isolation_on_closed_connection_raises_interface_error( - self, mock_session_class - ): - """Test that get_transaction_isolation() on closed connection raises InterfaceError.""" - conn = self._create_mock_connection(mock_session_class) - conn.session.is_open = False - - with self.assertRaises(InterfaceError) as ctx: - conn.get_transaction_isolation() - - self.assertIn( - "Cannot get transaction isolation on closed connection", str(ctx.exception) - ) - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_set_transaction_isolation_accepts_repeatable_read( - self, mock_session_class - ): - """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" - conn = self._create_mock_connection(mock_session_class) - - # Should not raise - conn.set_transaction_isolation("REPEATABLE_READ") - conn.set_transaction_isolation("REPEATABLE READ") # With space - conn.set_transaction_isolation("repeatable_read") # Lowercase with underscore - conn.set_transaction_isolation("repeatable read") # Lowercase with space - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_set_transaction_isolation_rejects_other_levels(self, mock_session_class): - """Test that set_transaction_isolation() rejects non-REPEATABLE_READ levels.""" - conn = self._create_mock_connection(mock_session_class) - - with self.assertRaises(NotSupportedError) as ctx: - conn.set_transaction_isolation("READ_COMMITTED") - - self.assertIn("not supported", str(ctx.exception)) - self.assertIn("READ_COMMITTED", str(ctx.exception)) - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_set_transaction_isolation_on_closed_connection_raises_interface_error( - self, mock_session_class - ): - """Test that set_transaction_isolation() on closed connection raises InterfaceError.""" - conn = self._create_mock_connection(mock_session_class) - conn.session.is_open = False - - with self.assertRaises(InterfaceError) as ctx: - conn.set_transaction_isolation("REPEATABLE_READ") - - self.assertIn( - "Cannot set transaction isolation on closed connection", str(ctx.exception) - ) - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_fetch_autocommit_from_server_queries_server(self, mock_session_class): - """Test that fetch_autocommit_from_server=True queries server.""" - # Create connection with fetch_autocommit_from_server=True - mock_session = Mock() - mock_session.is_open = True - mock_session.guid_hex = "test-session-id" - mock_session_class.return_value = mock_session - - conn = client.Connection( - fetch_autocommit_from_server=True, - ignore_transactions=False, - **self.DUMMY_CONNECTION_ARGS, - ) - - mock_cursor = Mock() - mock_row = Mock() - mock_row.__getitem__ = Mock(return_value="true") - mock_cursor.fetchone.return_value = mock_row - - with patch.object(conn, "cursor", return_value=mock_cursor): - result = conn.autocommit - - mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT") - mock_cursor.fetchone.assert_called_once() - mock_cursor.close.assert_called_once() - - conn.session.set_autocommit.assert_called_once_with(True) - - self.assertTrue(result) - - conn.close() - - @patch("%s.client.Session" % 
PACKAGE_NAME) - def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_class): - """Test that fetch_autocommit_from_server correctly parses false value.""" - mock_session = Mock() - mock_session.is_open = True - mock_session.guid_hex = "test-session-id" - mock_session_class.return_value = mock_session - - conn = client.Connection( - fetch_autocommit_from_server=True, - ignore_transactions=False, - **self.DUMMY_CONNECTION_ARGS, - ) - - mock_cursor = Mock() - mock_row = Mock() - mock_row.__getitem__ = Mock(return_value="false") - mock_cursor.fetchone.return_value = mock_row - - with patch.object(conn, "cursor", return_value=mock_cursor): - result = conn.autocommit - - conn.session.set_autocommit.assert_called_once_with(False) - self.assertFalse(result) - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_class): - """Test that fetch_autocommit_from_server raises error when no result.""" - mock_session = Mock() - mock_session.is_open = True - mock_session.guid_hex = "test-session-id" - mock_session_class.return_value = mock_session - - conn = client.Connection( - fetch_autocommit_from_server=True, - ignore_transactions=False, - **self.DUMMY_CONNECTION_ARGS, - ) - - mock_cursor = Mock() - mock_cursor.fetchone.return_value = None - - with patch.object(conn, "cursor", return_value=mock_cursor): - with self.assertRaises(TransactionError) as ctx: - _ = conn.autocommit - - self.assertIn("No result returned", str(ctx.exception)) - mock_cursor.close.assert_called_once() - - conn.close() - - # ==================== IGNORE_TRANSACTIONS TESTS ==================== - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class): - """Test that commit() is a no-op when ignore_transactions=True.""" - - mock_session = Mock() - mock_session.is_open = True - mock_session.guid_hex = "test-session-id" - mock_session_class.return_value = mock_session - - # Create connection with ignore_transactions=True (default) - conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) - - # Verify ignore_transactions is True by default - self.assertTrue(conn.ignore_transactions) - - mock_cursor = Mock() - with patch.object(conn, "cursor", return_value=mock_cursor): - # Call commit - should be no-op - conn.commit() - - # Verify that execute was NOT called (no-op) - mock_cursor.execute.assert_not_called() - mock_cursor.close.assert_not_called() - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_rollback_raises_not_supported_when_ignore_transactions_true( - self, mock_session_class - ): - """Test that rollback() raises NotSupportedError when ignore_transactions=True.""" - - mock_session = Mock() - mock_session.is_open = True - mock_session.guid_hex = "test-session-id" - mock_session_class.return_value = mock_session - - # Create connection with ignore_transactions=True (default) - conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) - - # Verify ignore_transactions is True by default - self.assertTrue(conn.ignore_transactions) - - # Call rollback - should raise NotSupportedError - with self.assertRaises(NotSupportedError) as ctx: - conn.rollback() - - self.assertIn("Transactions are not supported", str(ctx.exception)) - - conn.close() - - @patch("%s.client.Session" % PACKAGE_NAME) - def test_autocommit_setter_is_noop_when_ignore_transactions_true( - self, mock_session_class - ): - """Test that autocommit setter is a no-op when 
ignore_transactions=True.""" - - mock_session = Mock() - mock_session.is_open = True - mock_session.guid_hex = "test-session-id" - mock_session_class.return_value = mock_session - - # Create connection with ignore_transactions=True (default) - conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) - - # Verify ignore_transactions is True by default - self.assertTrue(conn.ignore_transactions) - - mock_cursor = Mock() - with patch.object(conn, "cursor", return_value=mock_cursor): - # Set autocommit - should be no-op - conn.autocommit = False - - # Verify that execute was NOT called (no-op) - mock_cursor.execute.assert_not_called() - mock_cursor.close.assert_not_called() - - # Session set_autocommit should also not be called - conn.session.set_autocommit.assert_not_called() - - conn.close() - - if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) loader = unittest.TestLoader() test_classes = [ ClientTestSuite, - TransactionTestSuite, FetchTests, ThriftBackendTestSuite, ArrowQueueSuite, diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 96a2f87d8..2ff82cee5 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -2,7 +2,6 @@ import pytest from unittest.mock import patch, MagicMock import json -from dataclasses import asdict from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, @@ -10,20 +9,7 @@ TelemetryClientFactory, TelemetryHelper, ) -from databricks.sql.common.feature_flag import ( - FeatureFlagsContextFactory, - FeatureFlagsContext, -) -from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow, DatabricksClientType -from databricks.sql.telemetry.models.event import ( - TelemetryEvent, - DriverConnectionParameters, - DriverSystemConfiguration, - SqlExecutionEvent, - DriverErrorInfo, - DriverVolumeOperation, - HostDetails, -) +from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow from databricks.sql.auth.authenticators import ( AccessTokenAuthProvider, DatabricksOAuthProvider, @@ -41,9 +27,7 @@ def mock_telemetry_client(): client_context = MagicMock() # Patch the _setup_pool_manager method to avoid SSL file loading - with patch( - "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" - ): + with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -86,12 +70,12 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): client._export_event("event1") client._export_event("event2") mock_send.assert_not_called() - assert client._events_queue.qsize() == 2 + assert len(client._events_batch) == 2 # Third event should trigger flush client._export_event("event3") mock_send.assert_called_once() - assert client._events_queue.qsize() == 0 # Queue cleared after flush + assert len(client._events_batch) == 0 # Batch cleared after flush @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_network_request_flow(self, mock_http_request, mock_telemetry_client): @@ -101,7 +85,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client): mock_response.status = 200 mock_response.status_code = 200 mock_http_request.return_value = mock_response - + client = mock_telemetry_client # Create mock events @@ -237,9 +221,7 @@ def test_client_lifecycle_flow(self): client_context = MagicMock() # Initialize enabled client - with patch( - 
"databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" - ): + with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, @@ -307,9 +289,7 @@ def test_factory_shutdown_flow(self): client_context = MagicMock() # Initialize multiple clients - with patch( - "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" - ): + with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, @@ -392,10 +372,8 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = ( - False # Connection starts closed for test cleanup - ) - + mock_session_instance.is_open = False # Connection starts closed for test cleanup + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -422,10 +400,8 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = ( - False # Connection starts closed for test cleanup - ) - + mock_session_instance.is_open = False # Connection starts closed for test cleanup + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -452,10 +428,8 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = ( - False # Connection starts closed for test cleanup - ) - + mock_session_instance.is_open = False # Connection starts closed for test cleanup + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -472,416 +446,3 @@ def test_telemetry_disabled_when_flag_request_fails( mock_http_request.assert_called_once() client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail") assert isinstance(client, NoopTelemetryClient) - - -class TestTelemetryEventModels: - """Tests for telemetry event model data structures and JSON serialization.""" - - def test_host_details_serialization(self): - """Test HostDetails model serialization.""" - host = HostDetails(host_url="test-host.com", port=443) - - # Test JSON string generation - json_str = host.to_json() - assert isinstance(json_str, str) - parsed = json.loads(json_str) - assert parsed["host_url"] == "test-host.com" - assert parsed["port"] == 443 - - def test_driver_connection_parameters_all_fields(self): - """Test DriverConnectionParameters with all fields populated.""" - host_info = HostDetails(host_url="workspace.databricks.com", port=443) - proxy_info = HostDetails(host_url="proxy.company.com", port=8080) - cf_proxy_info = HostDetails(host_url="cf-proxy.company.com", port=8080) - - params = DriverConnectionParameters( - http_path="/sql/1.0/warehouses/abc123", - 
mode=DatabricksClientType.SEA, - host_info=host_info, - auth_mech=AuthMech.OAUTH, - auth_flow=AuthFlow.BROWSER_BASED_AUTHENTICATION, - socket_timeout=30000, - azure_workspace_resource_id="/subscriptions/test/resourceGroups/test", - azure_tenant_id="tenant-123", - use_proxy=True, - use_system_proxy=True, - proxy_host_info=proxy_info, - use_cf_proxy=False, - cf_proxy_host_info=cf_proxy_info, - non_proxy_hosts=["localhost", "127.0.0.1"], - allow_self_signed_support=False, - use_system_trust_store=True, - enable_arrow=True, - enable_direct_results=True, - enable_sea_hybrid_results=True, - http_connection_pool_size=100, - rows_fetched_per_block=100000, - async_poll_interval_millis=2000, - support_many_parameters=True, - enable_complex_datatype_support=True, - allowed_volume_ingestion_paths="/Volumes/catalog/schema/volume", - ) - - # Serialize to JSON and parse back - json_str = params.to_json() - json_dict = json.loads(json_str) - - # Verify all new fields are in JSON - assert json_dict["http_path"] == "/sql/1.0/warehouses/abc123" - assert json_dict["mode"] == "SEA" - assert json_dict["host_info"]["host_url"] == "workspace.databricks.com" - assert json_dict["auth_mech"] == "OAUTH" - assert json_dict["auth_flow"] == "BROWSER_BASED_AUTHENTICATION" - assert json_dict["socket_timeout"] == 30000 - assert json_dict["azure_workspace_resource_id"] == "/subscriptions/test/resourceGroups/test" - assert json_dict["azure_tenant_id"] == "tenant-123" - assert json_dict["use_proxy"] is True - assert json_dict["use_system_proxy"] is True - assert json_dict["proxy_host_info"]["host_url"] == "proxy.company.com" - assert json_dict["use_cf_proxy"] is False - assert json_dict["cf_proxy_host_info"]["host_url"] == "cf-proxy.company.com" - assert json_dict["non_proxy_hosts"] == ["localhost", "127.0.0.1"] - assert json_dict["allow_self_signed_support"] is False - assert json_dict["use_system_trust_store"] is True - assert json_dict["enable_arrow"] is True - assert json_dict["enable_direct_results"] is True - assert json_dict["enable_sea_hybrid_results"] is True - assert json_dict["http_connection_pool_size"] == 100 - assert json_dict["rows_fetched_per_block"] == 100000 - assert json_dict["async_poll_interval_millis"] == 2000 - assert json_dict["support_many_parameters"] is True - assert json_dict["enable_complex_datatype_support"] is True - assert json_dict["allowed_volume_ingestion_paths"] == "/Volumes/catalog/schema/volume" - - def test_driver_connection_parameters_minimal_fields(self): - """Test DriverConnectionParameters with only required fields.""" - host_info = HostDetails(host_url="workspace.databricks.com", port=443) - - params = DriverConnectionParameters( - http_path="/sql/1.0/warehouses/abc123", - mode=DatabricksClientType.THRIFT, - host_info=host_info, - ) - - # Note: to_json() filters out None values, so we need to check asdict for complete structure - json_str = params.to_json() - json_dict = json.loads(json_str) - - # Required fields should be present - assert json_dict["http_path"] == "/sql/1.0/warehouses/abc123" - assert json_dict["mode"] == "THRIFT" - assert json_dict["host_info"]["host_url"] == "workspace.databricks.com" - - # Optional fields with None are filtered out by to_json() - # This is expected behavior - None values are excluded from JSON output - - def test_driver_system_configuration_serialization(self): - """Test DriverSystemConfiguration model serialization.""" - sys_config = DriverSystemConfiguration( - driver_name="Databricks SQL Connector for Python", - driver_version="3.0.0", - 
runtime_name="CPython", - runtime_version="3.11.0", - runtime_vendor="Python Software Foundation", - os_name="Darwin", - os_version="23.0.0", - os_arch="arm64", - char_set_encoding="utf-8", - locale_name="en_US", - client_app_name="MyApp", - ) - - json_str = sys_config.to_json() - json_dict = json.loads(json_str) - - assert json_dict["driver_name"] == "Databricks SQL Connector for Python" - assert json_dict["driver_version"] == "3.0.0" - assert json_dict["runtime_name"] == "CPython" - assert json_dict["runtime_version"] == "3.11.0" - assert json_dict["runtime_vendor"] == "Python Software Foundation" - assert json_dict["os_name"] == "Darwin" - assert json_dict["os_version"] == "23.0.0" - assert json_dict["os_arch"] == "arm64" - assert json_dict["locale_name"] == "en_US" - assert json_dict["char_set_encoding"] == "utf-8" - assert json_dict["client_app_name"] == "MyApp" - - def test_telemetry_event_complete_serialization(self): - """Test complete TelemetryEvent serialization with all nested objects.""" - host_info = HostDetails(host_url="workspace.databricks.com", port=443) - proxy_info = HostDetails(host_url="proxy.company.com", port=8080) - - connection_params = DriverConnectionParameters( - http_path="/sql/1.0/warehouses/abc123", - mode=DatabricksClientType.SEA, - host_info=host_info, - auth_mech=AuthMech.OAUTH, - use_proxy=True, - proxy_host_info=proxy_info, - enable_arrow=True, - rows_fetched_per_block=100000, - ) - - sys_config = DriverSystemConfiguration( - driver_name="Databricks SQL Connector for Python", - driver_version="3.0.0", - runtime_name="CPython", - runtime_version="3.11.0", - runtime_vendor="Python Software Foundation", - os_name="Darwin", - os_version="23.0.0", - os_arch="arm64", - char_set_encoding="utf-8", - ) - - error_info = DriverErrorInfo( - error_name="ConnectionError", - stack_trace="Traceback...", - ) - - event = TelemetryEvent( - session_id="test-session-123", - sql_statement_id="test-stmt-456", - operation_latency_ms=1500, - auth_type="OAUTH", - system_configuration=sys_config, - driver_connection_params=connection_params, - error_info=error_info, - ) - - # Test JSON serialization - json_str = event.to_json() - assert isinstance(json_str, str) - - # Parse and verify structure - parsed = json.loads(json_str) - assert parsed["session_id"] == "test-session-123" - assert parsed["sql_statement_id"] == "test-stmt-456" - assert parsed["operation_latency_ms"] == 1500 - assert parsed["auth_type"] == "OAUTH" - - # Verify nested objects - assert parsed["system_configuration"]["driver_name"] == "Databricks SQL Connector for Python" - assert parsed["driver_connection_params"]["http_path"] == "/sql/1.0/warehouses/abc123" - assert parsed["driver_connection_params"]["use_proxy"] is True - assert parsed["driver_connection_params"]["proxy_host_info"]["host_url"] == "proxy.company.com" - assert parsed["error_info"]["error_name"] == "ConnectionError" - - def test_json_serialization_excludes_none_values(self): - """Test that JSON serialization properly excludes None values.""" - host_info = HostDetails(host_url="workspace.databricks.com", port=443) - - params = DriverConnectionParameters( - http_path="/sql/1.0/warehouses/abc123", - mode=DatabricksClientType.SEA, - host_info=host_info, - # All optional fields left as None - ) - - json_str = params.to_json() - parsed = json.loads(json_str) - - # Required fields present - assert parsed["http_path"] == "/sql/1.0/warehouses/abc123" - - # None values should be EXCLUDED from JSON (not included as null) - # This is the behavior of 
JsonSerializableMixin - assert "auth_mech" not in parsed - assert "azure_tenant_id" not in parsed - assert "proxy_host_info" not in parsed - - -@patch("databricks.sql.client.Session") -@patch("databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers") -class TestConnectionParameterTelemetry: - """Tests for connection parameter population in telemetry.""" - - def test_connection_with_proxy_populates_telemetry(self, mock_setup_pools, mock_session): - """Test that proxy configuration is captured in telemetry.""" - mock_session_instance = MagicMock() - mock_session_instance.guid_hex = "test-session-proxy" - mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False - mock_session_instance.use_sea = True - mock_session_instance.port = 443 - mock_session_instance.host = "workspace.databricks.com" - mock_session.return_value = mock_session_instance - - with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: - conn = sql.connect( - server_hostname="workspace.databricks.com", - http_path="/sql/1.0/warehouses/test", - access_token="test-token", - enable_telemetry=True, - force_enable_telemetry=True, - ) - - # Verify export was called - mock_export.assert_called_once() - call_args = mock_export.call_args - - # Extract driver_connection_params - driver_params = call_args.kwargs.get("driver_connection_params") - assert driver_params is not None - assert isinstance(driver_params, DriverConnectionParameters) - - # Verify fields are populated - assert driver_params.http_path == "/sql/1.0/warehouses/test" - assert driver_params.mode == DatabricksClientType.SEA - assert driver_params.host_info.host_url == "workspace.databricks.com" - assert driver_params.host_info.port == 443 - - def test_connection_with_azure_params_populates_telemetry(self, mock_setup_pools, mock_session): - """Test that Azure-specific parameters are captured in telemetry.""" - mock_session_instance = MagicMock() - mock_session_instance.guid_hex = "test-session-azure" - mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False - mock_session_instance.use_sea = False - mock_session_instance.port = 443 - mock_session_instance.host = "workspace.azuredatabricks.net" - mock_session.return_value = mock_session_instance - - with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: - conn = sql.connect( - server_hostname="workspace.azuredatabricks.net", - http_path="/sql/1.0/warehouses/test", - access_token="test-token", - azure_workspace_resource_id="/subscriptions/test/resourceGroups/test", - azure_tenant_id="tenant-123", - enable_telemetry=True, - force_enable_telemetry=True, - ) - - mock_export.assert_called_once() - driver_params = mock_export.call_args.kwargs.get("driver_connection_params") - - # Verify Azure fields - assert driver_params.azure_workspace_resource_id == "/subscriptions/test/resourceGroups/test" - assert driver_params.azure_tenant_id == "tenant-123" - - def test_connection_populates_arrow_and_performance_params(self, mock_setup_pools, mock_session): - """Test that Arrow and performance parameters are captured in telemetry.""" - mock_session_instance = MagicMock() - mock_session_instance.guid_hex = "test-session-perf" - mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False - mock_session_instance.use_sea = True - 
mock_session_instance.port = 443 - mock_session_instance.host = "workspace.databricks.com" - mock_session.return_value = mock_session_instance - - with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: - # Import pyarrow availability check - try: - import pyarrow - arrow_available = True - except ImportError: - arrow_available = False - - conn = sql.connect( - server_hostname="workspace.databricks.com", - http_path="/sql/1.0/warehouses/test", - access_token="test-token", - pool_maxsize=200, - enable_telemetry=True, - force_enable_telemetry=True, - ) - - mock_export.assert_called_once() - driver_params = mock_export.call_args.kwargs.get("driver_connection_params") - - # Verify performance fields - assert driver_params.enable_arrow == arrow_available - assert driver_params.enable_direct_results is True - assert driver_params.http_connection_pool_size == 200 - assert driver_params.rows_fetched_per_block == 100000 # DEFAULT_ARRAY_SIZE - assert driver_params.async_poll_interval_millis == 2000 - assert driver_params.support_many_parameters is True - - def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_session): - """Test that CloudFlare proxy fields default to False/None (not yet supported).""" - mock_session_instance = MagicMock() - mock_session_instance.guid_hex = "test-session-cfproxy" - mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False - mock_session_instance.use_sea = True - mock_session_instance.port = 443 - mock_session_instance.host = "workspace.databricks.com" - mock_session.return_value = mock_session_instance - - with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: - conn = sql.connect( - server_hostname="workspace.databricks.com", - http_path="/sql/1.0/warehouses/test", - access_token="test-token", - enable_telemetry=True, - force_enable_telemetry=True, - ) - - mock_export.assert_called_once() - driver_params = mock_export.call_args.kwargs.get("driver_connection_params") - - # CF proxy not yet supported - should be False/None - assert driver_params.use_cf_proxy is False - assert driver_params.cf_proxy_host_info is None - - -class TestFeatureFlagsContextFactory: - """Tests for FeatureFlagsContextFactory host-level caching.""" - - @pytest.fixture(autouse=True) - def reset_factory(self): - """Reset factory state before/after each test.""" - FeatureFlagsContextFactory._context_map.clear() - if FeatureFlagsContextFactory._executor: - FeatureFlagsContextFactory._executor.shutdown(wait=False) - FeatureFlagsContextFactory._executor = None - yield - FeatureFlagsContextFactory._context_map.clear() - if FeatureFlagsContextFactory._executor: - FeatureFlagsContextFactory._executor.shutdown(wait=False) - FeatureFlagsContextFactory._executor = None - - @pytest.mark.parametrize( - "hosts,expected_contexts", - [ - (["host1.com", "host1.com"], 1), # Same host shares context - (["host1.com", "host2.com"], 2), # Different hosts get separate contexts - (["host1.com", "host1.com", "host2.com"], 2), # Mixed scenario - ], - ) - def test_host_level_caching(self, hosts, expected_contexts): - """Test that contexts are cached by host correctly.""" - contexts = [] - for host in hosts: - conn = MagicMock() - conn.session.host = host - conn.session.http_client = MagicMock() - contexts.append(FeatureFlagsContextFactory.get_instance(conn)) - - assert len(FeatureFlagsContextFactory._context_map) == 
expected_contexts - if expected_contexts == 1: - assert all(ctx is contexts[0] for ctx in contexts) - - def test_remove_instance_and_executor_cleanup(self): - """Test removal uses host key and cleans up executor when empty.""" - conn1 = MagicMock() - conn1.session.host = "host1.com" - conn1.session.http_client = MagicMock() - - conn2 = MagicMock() - conn2.session.host = "host2.com" - conn2.session.http_client = MagicMock() - - FeatureFlagsContextFactory.get_instance(conn1) - FeatureFlagsContextFactory.get_instance(conn2) - assert FeatureFlagsContextFactory._executor is not None - - FeatureFlagsContextFactory.remove_instance(conn1) - assert len(FeatureFlagsContextFactory._context_map) == 1 - assert FeatureFlagsContextFactory._executor is not None - - FeatureFlagsContextFactory.remove_instance(conn2) - assert len(FeatureFlagsContextFactory._context_map) == 0 - assert FeatureFlagsContextFactory._executor is None diff --git a/tests/unit/test_telemetry_push_client.py b/tests/unit/test_telemetry_push_client.py deleted file mode 100644 index 0e9455e1f..000000000 --- a/tests/unit/test_telemetry_push_client.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Unit tests for telemetry push client functionality. -""" - -import pytest -from unittest.mock import Mock, patch - -from databricks.sql.telemetry.telemetry_push_client import ( - ITelemetryPushClient, - TelemetryPushClient, - CircuitBreakerTelemetryPushClient, -) -from databricks.sql.common.http import HttpMethod -from databricks.sql.exc import TelemetryRateLimitError -from pybreaker import CircuitBreakerError - - -class TestTelemetryPushClient: - """Test cases for TelemetryPushClient.""" - - def setup_method(self): - """Set up test fixtures.""" - self.mock_http_client = Mock() - self.client = TelemetryPushClient(self.mock_http_client) - - def test_initialization(self): - """Test client initialization.""" - assert self.client._http_client == self.mock_http_client - - def test_request_delegates_to_http_client(self): - """Test that request delegates to underlying HTTP client.""" - mock_response = Mock() - self.mock_http_client.request.return_value = mock_response - - response = self.client.request(HttpMethod.POST, "https://test.com", {}) - - assert response == mock_response - self.mock_http_client.request.assert_called_once() - - -class TestCircuitBreakerTelemetryPushClient: - """Test cases for CircuitBreakerTelemetryPushClient.""" - - def setup_method(self): - """Set up test fixtures.""" - self.mock_delegate = Mock(spec=ITelemetryPushClient) - self.host = "test-host.example.com" - self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) - - def test_initialization(self): - """Test client initialization.""" - assert self.client._delegate == self.mock_delegate - assert self.client._host == self.host - assert self.client._circuit_breaker is not None - - def test_request_success(self): - """Test successful request when circuit breaker is enabled.""" - mock_response = Mock() - self.mock_delegate.request.return_value = mock_response - - response = self.client.request(HttpMethod.POST, "https://test.com", {}) - - assert response == mock_response - self.mock_delegate.request.assert_called_once() - - def test_request_circuit_breaker_open(self): - """Test request when circuit breaker is open raises CircuitBreakerError.""" - with patch.object( - self.client._circuit_breaker, - "call", - side_effect=CircuitBreakerError("Circuit is open"), - ): - with pytest.raises(CircuitBreakerError): - self.client.request(HttpMethod.POST, "https://test.com", {}) 
- - def test_request_other_error(self): - """Test request when other error occurs raises original exception.""" - self.mock_delegate.request.side_effect = ValueError("Network error") - - with pytest.raises(ValueError, match="Network error"): - self.client.request(HttpMethod.POST, "https://test.com", {}) - - @pytest.mark.parametrize("status_code,expected_error", [ - (429, TelemetryRateLimitError), - (503, TelemetryRateLimitError), - ]) - def test_request_rate_limit_codes(self, status_code, expected_error): - """Test that rate-limit status codes raise TelemetryRateLimitError.""" - mock_response = Mock() - mock_response.status = status_code - self.mock_delegate.request.return_value = mock_response - - with pytest.raises(expected_error): - self.client.request(HttpMethod.POST, "https://test.com", {}) - - def test_request_non_rate_limit_code(self): - """Test that non-rate-limit status codes return response.""" - mock_response = Mock() - mock_response.status = 500 - mock_response.data = b'Server error' - self.mock_delegate.request.return_value = mock_response - - response = self.client.request(HttpMethod.POST, "https://test.com", {}) - assert response is not None - assert response.status == 500 - - def test_rate_limit_error_logging(self): - """Test that rate limit errors are logged with circuit breaker context.""" - with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: - mock_response = Mock() - mock_response.status = 429 - self.mock_delegate.request.return_value = mock_response - - with pytest.raises(TelemetryRateLimitError): - self.client.request(HttpMethod.POST, "https://test.com", {}) - - mock_logger.warning.assert_called() - warning_args = mock_logger.warning.call_args[0] - assert "429" in str(warning_args) - assert "circuit breaker" in warning_args[0] - - def test_other_error_logging(self): - """Test that other errors are logged during wrapping/unwrapping.""" - with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: - self.mock_delegate.request.side_effect = ValueError("Network error") - - with pytest.raises(ValueError, match="Network error"): - self.client.request(HttpMethod.POST, "https://test.com", {}) - - assert mock_logger.debug.call_count >= 1 - - -class TestCircuitBreakerTelemetryPushClientIntegration: - """Integration tests for CircuitBreakerTelemetryPushClient.""" - - def setup_method(self): - """Set up test fixtures.""" - self.mock_delegate = Mock() - self.host = "test-host.example.com" - from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager - CircuitBreakerManager._instances.clear() - - def test_circuit_breaker_opens_after_failures(self): - """Test that circuit breaker opens after repeated failures (429/503 errors).""" - from databricks.sql.telemetry.circuit_breaker_manager import ( - CircuitBreakerManager, - MINIMUM_CALLS, - ) - - CircuitBreakerManager._instances.clear() - client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) - - mock_response = Mock() - mock_response.status = 429 - self.mock_delegate.request.return_value = mock_response - - rate_limit_error_count = 0 - circuit_breaker_error_count = 0 - - for _ in range(MINIMUM_CALLS + 5): - try: - client.request(HttpMethod.POST, "https://test.com", {}) - except TelemetryRateLimitError: - rate_limit_error_count += 1 - except CircuitBreakerError: - circuit_breaker_error_count += 1 - - assert rate_limit_error_count >= MINIMUM_CALLS - 1 - assert circuit_breaker_error_count > 0 - - def 
test_circuit_breaker_recovers_after_success(self): - """Test that circuit breaker recovers after successful calls.""" - import time - from databricks.sql.telemetry.circuit_breaker_manager import ( - CircuitBreakerManager, - MINIMUM_CALLS, - RESET_TIMEOUT, - ) - - CircuitBreakerManager._instances.clear() - client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) - - # Trigger failures - mock_rate_limit_response = Mock() - mock_rate_limit_response.status = 429 - self.mock_delegate.request.return_value = mock_rate_limit_response - - for _ in range(MINIMUM_CALLS + 5): - try: - client.request(HttpMethod.POST, "https://test.com", {}) - except (TelemetryRateLimitError, CircuitBreakerError): - pass - - # Circuit should be open - with pytest.raises(CircuitBreakerError): - client.request(HttpMethod.POST, "https://test.com", {}) - - # Wait for reset timeout - time.sleep(RESET_TIMEOUT + 1.0) - - # Simulate success - mock_success_response = Mock() - mock_success_response.status = 200 - self.mock_delegate.request.return_value = mock_success_response - - response = client.request(HttpMethod.POST, "https://test.com", {}) - assert response is not None - assert response.status == 200 - - def test_urllib3_import_fallback(self): - """Test that the urllib3 import fallback works correctly.""" - from databricks.sql.telemetry.telemetry_push_client import BaseHTTPResponse - assert BaseHTTPResponse is not None diff --git a/tests/unit/test_telemetry_request_error_handling.py b/tests/unit/test_telemetry_request_error_handling.py deleted file mode 100644 index aa31f6628..000000000 --- a/tests/unit/test_telemetry_request_error_handling.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Unit tests specifically for telemetry_push_client RequestError handling -with http-code context extraction for rate limiting detection. 
-""" - -import pytest -from unittest.mock import Mock - -from databricks.sql.telemetry.telemetry_push_client import ( - CircuitBreakerTelemetryPushClient, - TelemetryPushClient, -) -from databricks.sql.common.http import HttpMethod -from databricks.sql.exc import RequestError, TelemetryRateLimitError -from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager - - -class TestTelemetryPushClientRequestErrorHandling: - """Test RequestError handling and http-code context extraction.""" - - @pytest.fixture - def setup_circuit_breaker(self): - """Setup circuit breaker for testing.""" - CircuitBreakerManager._instances.clear() - yield - CircuitBreakerManager._instances.clear() - - @pytest.fixture - def mock_delegate(self): - """Create mock delegate client.""" - return Mock(spec=TelemetryPushClient) - - @pytest.fixture - def client(self, mock_delegate, setup_circuit_breaker): - """Create CircuitBreakerTelemetryPushClient instance.""" - return CircuitBreakerTelemetryPushClient(mock_delegate, "test-host.example.com") - - @pytest.mark.parametrize("status_code", [429, 503]) - def test_request_error_with_rate_limit_codes(self, client, mock_delegate, status_code): - """Test that RequestError with rate-limit codes raises TelemetryRateLimitError.""" - request_error = RequestError("HTTP request failed", context={"http-code": status_code}) - mock_delegate.request.side_effect = request_error - - with pytest.raises(TelemetryRateLimitError): - client.request(HttpMethod.POST, "https://test.com", {}) - - @pytest.mark.parametrize("status_code", [500, 400, 404]) - def test_request_error_with_non_rate_limit_codes(self, client, mock_delegate, status_code): - """Test that RequestError with non-rate-limit codes raises original RequestError.""" - request_error = RequestError("HTTP request failed", context={"http-code": status_code}) - mock_delegate.request.side_effect = request_error - - with pytest.raises(RequestError, match="HTTP request failed"): - client.request(HttpMethod.POST, "https://test.com", {}) - - @pytest.mark.parametrize("context", [{}, None, "429"]) - def test_request_error_with_invalid_context(self, client, mock_delegate, context): - """Test RequestError with invalid/missing context raises original error.""" - request_error = RequestError("HTTP request failed") - if context == "429": - # Edge case: http-code as string instead of int - request_error.context = {"http-code": context} - else: - request_error.context = context - mock_delegate.request.side_effect = request_error - - with pytest.raises(RequestError, match="HTTP request failed"): - client.request(HttpMethod.POST, "https://test.com", {}) - - def test_request_error_missing_context_attribute(self, client, mock_delegate): - """Test RequestError without context attribute raises original error.""" - request_error = RequestError("HTTP request failed") - if hasattr(request_error, "context"): - delattr(request_error, "context") - mock_delegate.request.side_effect = request_error - - with pytest.raises(RequestError, match="HTTP request failed"): - client.request(HttpMethod.POST, "https://test.com", {}) - - def test_http_code_extraction_prioritization(self, client, mock_delegate): - """Test that http-code from RequestError context is correctly extracted.""" - request_error = RequestError( - "HTTP request failed after retries", context={"http-code": 503} - ) - mock_delegate.request.side_effect = request_error - - with pytest.raises(TelemetryRateLimitError): - client.request(HttpMethod.POST, "https://test.com", {}) - - def 
test_non_request_error_exceptions_raised(self, client, mock_delegate): - """Test that non-RequestError exceptions are wrapped then unwrapped.""" - generic_error = ValueError("Network timeout") - mock_delegate.request.side_effect = generic_error - - with pytest.raises(ValueError, match="Network timeout"): - client.request(HttpMethod.POST, "https://test.com", {}) diff --git a/tests/unit/test_unified_http_client.py b/tests/unit/test_unified_http_client.py deleted file mode 100644 index 4e9ce1bbf..000000000 --- a/tests/unit/test_unified_http_client.py +++ /dev/null @@ -1,136 +0,0 @@ -""" -Unit tests for UnifiedHttpClient, specifically testing MaxRetryError handling -and HTTP status code extraction. -""" - -import pytest -from unittest.mock import Mock, patch -from urllib3.exceptions import MaxRetryError - -from databricks.sql.common.unified_http_client import UnifiedHttpClient -from databricks.sql.common.http import HttpMethod -from databricks.sql.exc import RequestError -from databricks.sql.auth.common import ClientContext -from databricks.sql.types import SSLOptions - - -class TestUnifiedHttpClientMaxRetryError: - """Test MaxRetryError handling and HTTP status code extraction.""" - - @pytest.fixture - def client_context(self): - """Create a minimal ClientContext for testing.""" - context = Mock(spec=ClientContext) - context.hostname = "https://test.databricks.com" - context.ssl_options = SSLOptions( - tls_verify=True, - tls_verify_hostname=True, - tls_trusted_ca_file=None, - tls_client_cert_file=None, - tls_client_cert_key_file=None, - tls_client_cert_key_password=None, - ) - context.socket_timeout = 30 - context.retry_stop_after_attempts_count = 3 - context.retry_delay_min = 1.0 - context.retry_delay_max = 10.0 - context.retry_stop_after_attempts_duration = 300.0 - context.retry_delay_default = 5.0 - context.retry_dangerous_codes = [] - context.proxy_auth_method = None - context.pool_connections = 10 - context.pool_maxsize = 20 - context.user_agent = "test-agent" - return context - - @pytest.fixture - def http_client(self, client_context): - """Create UnifiedHttpClient instance.""" - return UnifiedHttpClient(client_context) - - @pytest.mark.parametrize("status_code,path", [ - (429, "reason.response"), - (503, "reason.response"), - (500, "direct_response"), - ]) - def test_max_retry_error_with_status_codes(self, http_client, status_code, path): - """Test MaxRetryError with various status codes and response paths.""" - mock_pool = Mock() - max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") - - if path == "reason.response": - max_retry_error.reason = Mock() - max_retry_error.reason.response = Mock() - max_retry_error.reason.response.status = status_code - else: # direct_response - max_retry_error.response = Mock() - max_retry_error.response.status = status_code - - with patch.object( - http_client._direct_pool_manager, "request", side_effect=max_retry_error - ): - with pytest.raises(RequestError) as exc_info: - http_client.request( - HttpMethod.POST, "http://test.com", headers={"test": "header"} - ) - - error = exc_info.value - assert hasattr(error, "context") - assert "http-code" in error.context - assert error.context["http-code"] == status_code - - @pytest.mark.parametrize("setup_func", [ - lambda e: None, # No setup - error with no attributes - lambda e: setattr(e, "reason", None), # reason=None - lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)), # reason.response=None - lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", 
Mock(spec=[]))), # No status attr - ]) - def test_max_retry_error_missing_status(self, http_client, setup_func): - """Test MaxRetryError without status code (no crash, empty context).""" - mock_pool = Mock() - max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") - setup_func(max_retry_error) - - with patch.object( - http_client._direct_pool_manager, "request", side_effect=max_retry_error - ): - with pytest.raises(RequestError) as exc_info: - http_client.request(HttpMethod.GET, "http://test.com") - - error = exc_info.value - assert error.context == {} - - def test_max_retry_error_prefers_reason_response(self, http_client): - """Test that e.reason.response.status is preferred over e.response.status.""" - mock_pool = Mock() - max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") - - # Set both structures with different status codes - max_retry_error.reason = Mock() - max_retry_error.reason.response = Mock() - max_retry_error.reason.response.status = 429 # Should use this - - max_retry_error.response = Mock() - max_retry_error.response.status = 500 # Should be ignored - - with patch.object( - http_client._direct_pool_manager, "request", side_effect=max_retry_error - ): - with pytest.raises(RequestError) as exc_info: - http_client.request(HttpMethod.GET, "http://test.com") - - error = exc_info.value - assert error.context["http-code"] == 429 - - def test_generic_exception_no_crash(self, http_client): - """Test that generic exceptions don't crash when checking for status code.""" - generic_error = Exception("Network error") - - with patch.object( - http_client._direct_pool_manager, "request", side_effect=generic_error - ): - with pytest.raises(RequestError) as exc_info: - http_client.request(HttpMethod.POST, "http://test.com") - - error = exc_info.value - assert "HTTP request error" in str(error) From 6187d06a462b6e7412be64c0589c56573ec6d0ac Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Mon, 1 Dec 2025 01:16:45 +0530 Subject: [PATCH 32/35] Reapply "Merge branch 'main' into close-conn" This reverts commit 7b4f6cbe6d51acf793e51d2f1ae629ca0db699d9. 
--- .github/workflows/daily-telemetry-e2e.yml | 87 +++ .github/workflows/integration.yml | 8 +- CHANGELOG.md | 8 + README.md | 6 + TRANSACTIONS.md | 387 ++++++++++++ examples/README.md | 1 + examples/transactions.py | 47 ++ poetry.lock | 36 +- pyproject.toml | 3 +- src/databricks/sql/__init__.py | 5 +- src/databricks/sql/auth/common.py | 2 + src/databricks/sql/backend/sea/backend.py | 2 +- src/databricks/sql/client.py | 343 +++++++++- src/databricks/sql/common/feature_flag.py | 8 +- .../sql/common/unified_http_client.py | 52 +- src/databricks/sql/exc.py | 38 ++ src/databricks/sql/session.py | 21 + .../sql/telemetry/circuit_breaker_manager.py | 112 ++++ .../sql/telemetry/latency_logger.py | 289 ++++----- src/databricks/sql/telemetry/models/event.py | 109 +++- .../sql/telemetry/telemetry_client.py | 126 ++-- .../sql/telemetry/telemetry_push_client.py | 201 ++++++ src/databricks/sql/utils.py | 3 + tests/e2e/test_circuit_breaker.py | 232 +++++++ tests/e2e/test_telemetry_e2e.py | 343 ++++++++++ tests/e2e/test_transactions.py | 598 ++++++++++++++++++ .../unit/test_circuit_breaker_http_client.py | 208 ++++++ tests/unit/test_circuit_breaker_manager.py | 160 +++++ tests/unit/test_client.py | 477 +++++++++++++- tests/unit/test_telemetry.py | 465 +++++++++++++- tests/unit/test_telemetry_push_client.py | 213 +++++++ .../test_telemetry_request_error_handling.py | 96 +++ tests/unit/test_unified_http_client.py | 136 ++++ 33 files changed, 4595 insertions(+), 227 deletions(-) create mode 100644 .github/workflows/daily-telemetry-e2e.yml create mode 100644 TRANSACTIONS.md create mode 100644 examples/transactions.py create mode 100644 src/databricks/sql/telemetry/circuit_breaker_manager.py create mode 100644 src/databricks/sql/telemetry/telemetry_push_client.py create mode 100644 tests/e2e/test_circuit_breaker.py create mode 100644 tests/e2e/test_telemetry_e2e.py create mode 100644 tests/e2e/test_transactions.py create mode 100644 tests/unit/test_circuit_breaker_http_client.py create mode 100644 tests/unit/test_circuit_breaker_manager.py create mode 100644 tests/unit/test_telemetry_push_client.py create mode 100644 tests/unit/test_telemetry_request_error_handling.py create mode 100644 tests/unit/test_unified_http_client.py diff --git a/.github/workflows/daily-telemetry-e2e.yml b/.github/workflows/daily-telemetry-e2e.yml new file mode 100644 index 000000000..3d61cf177 --- /dev/null +++ b/.github/workflows/daily-telemetry-e2e.yml @@ -0,0 +1,87 @@ +name: Daily Telemetry E2E Tests + +on: + schedule: + - cron: '0 0 * * 0' # Run every Sunday at midnight UTC + + workflow_dispatch: # Allow manual triggering + inputs: + test_pattern: + description: 'Test pattern to run (default: tests/e2e/test_telemetry_e2e.py)' + required: false + default: 'tests/e2e/test_telemetry_e2e.py' + type: string + +jobs: + telemetry-e2e-tests: + runs-on: ubuntu-latest + environment: azure-prod + + env: + DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} + DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} + DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} + DATABRICKS_CATALOG: peco + DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} + + steps: + #---------------------------------------------- + # check-out repo and set-up python + #---------------------------------------------- + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up python + id: setup-python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + #---------------------------------------------- + # ----- install & 
configure poetry ----- + #---------------------------------------------- + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + #---------------------------------------------- + # load cached venv if cache exists + #---------------------------------------------- + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v4 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} + + #---------------------------------------------- + # install dependencies if cache does not exist + #---------------------------------------------- + - name: Install dependencies + run: poetry install --no-interaction --all-extras + + #---------------------------------------------- + # run telemetry E2E tests + #---------------------------------------------- + - name: Run telemetry E2E tests + run: | + TEST_PATTERN="${{ github.event.inputs.test_pattern || 'tests/e2e/test_telemetry_e2e.py' }}" + echo "Running tests: $TEST_PATTERN" + poetry run python -m pytest $TEST_PATTERN -v -s + + #---------------------------------------------- + # upload test results on failure + #---------------------------------------------- + - name: Upload test results on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: telemetry-test-results + path: | + .pytest_cache/ + tests-unsafe.log + retention-days: 7 + diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 9c9e30a24..ad5369997 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -54,5 +54,9 @@ jobs: #---------------------------------------------- # run test suite #---------------------------------------------- - - name: Run e2e tests - run: poetry run python -m pytest tests/e2e -n auto \ No newline at end of file + - name: Run e2e tests (excluding daily-only tests) + run: | + # Exclude telemetry E2E tests from PR runs (run daily instead) + poetry run python -m pytest tests/e2e \ + --ignore=tests/e2e/test_telemetry_e2e.py \ + -n auto \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fa6bfb66..5b902e976 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Release History +# 4.2.1 (2025-11-20) +- Ignore transactions by default (databricks/databricks-sql-python#711 by @jayantsing-db) + +# 4.2.0 (2025-11-14) +- Add multi-statement transaction support (databricks/databricks-sql-python#704 by @jayantsing-db) +- Add a workflow to parallelise the E2E tests (databricks/databricks-sql-python#697 by @msrathore-db) +- Bring Python telemetry event model consistent with JDBC (databricks/databricks-sql-python#701 by @nikhilsuri-db) + # 4.1.4 (2025-10-15) - Add support for Token Federation (databricks/databricks-sql-python#691 by @madhav-db) - Add metric view support (databricks/databricks-sql-python#688 by @shivam2680) diff --git a/README.md b/README.md index d57efda1f..ec82a3637 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,12 @@ or to a Databricks Runtime interactive cluster (e.g. /sql/protocolv1/o/123456789 > to authenticate the target Databricks user account and needs to open the browser for authentication. So it > can only run on the user's machine. +## Transaction Support + +The connector supports multi-statement transactions with manual commit/rollback control. 
Set `connection.autocommit = False` to disable autocommit mode, then use `connection.commit()` and `connection.rollback()` to control transactions. + +For detailed documentation, examples, and best practices, see **[TRANSACTIONS.md](TRANSACTIONS.md)**. + ## SQLAlchemy Starting from `databricks-sql-connector` version 4.0.0 SQLAlchemy support has been extracted to a new library `databricks-sqlalchemy`. diff --git a/TRANSACTIONS.md b/TRANSACTIONS.md new file mode 100644 index 000000000..590c298c0 --- /dev/null +++ b/TRANSACTIONS.md @@ -0,0 +1,387 @@ +# Transaction Support + +The Databricks SQL Connector for Python supports multi-statement transactions (MST). This allows you to group multiple SQL statements into atomic units that either succeed completely or fail completely. + +## Autocommit Behavior + +By default, every SQL statement executes in its own transaction and commits immediately (autocommit mode). This is the standard behavior for most database connectors. + +```python +from databricks import sql + +connection = sql.connect( + server_hostname="your-workspace.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc123" +) + +# Default: autocommit is True +print(connection.autocommit) # True + +# Each statement commits immediately +cursor = connection.cursor() +cursor.execute("INSERT INTO my_table VALUES (1, 'data')") +# Already committed - data is visible to other connections +``` + +To use explicit transactions, disable autocommit: + +```python +connection.autocommit = False + +# Now statements are grouped into a transaction +cursor = connection.cursor() +cursor.execute("INSERT INTO my_table VALUES (1, 'data')") +# Not committed yet - must call connection.commit() + +connection.commit() # Now it's visible +``` + +## Basic Transaction Operations + +### Committing Changes + +When autocommit is disabled, you must explicitly commit your changes: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO orders VALUES (1, 100.00)") + cursor.execute("INSERT INTO order_items VALUES (1, 'Widget', 2)") + connection.commit() # Both inserts succeed together +except Exception as e: + connection.rollback() # Neither insert is saved + raise +finally: + connection.autocommit = True # Restore default state +``` + +### Rolling Back Changes + +Use `rollback()` to discard all changes made in the current transaction: + +```python +connection.autocommit = False +cursor = connection.cursor() + +cursor.execute("INSERT INTO accounts VALUES (1, 1000)") +cursor.execute("UPDATE accounts SET balance = balance - 500 WHERE id = 1") + +# Changed your mind? +connection.rollback() # All changes discarded +``` + +Note: Calling `rollback()` when autocommit is enabled is safe (it's a no-op), but calling `commit()` will raise a `TransactionError`. + +### Sequential Transactions + +After a commit or rollback, a new transaction starts automatically: + +```python +connection.autocommit = False + +# First transaction +cursor.execute("INSERT INTO logs VALUES (1, 'event1')") +connection.commit() + +# Second transaction starts automatically +cursor.execute("INSERT INTO logs VALUES (2, 'event2')") +connection.rollback() # Only the second insert is discarded +``` + +## Multi-Table Transactions + +Transactions span multiple tables atomically. 
Either all changes are committed, or all are rolled back: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + # Insert into multiple tables + cursor.execute("INSERT INTO customers VALUES (1, 'Alice')") + cursor.execute("INSERT INTO orders VALUES (1, 1, 100.00)") + cursor.execute("INSERT INTO shipments VALUES (1, 1, 'pending')") + + connection.commit() # All three inserts succeed atomically +except Exception as e: + connection.rollback() # All three inserts are discarded + raise +finally: + connection.autocommit = True # Restore default state +``` + +This is particularly useful for maintaining data consistency across related tables. + +## Transaction Isolation + +Databricks uses **Snapshot Isolation** (mapped to `REPEATABLE_READ` in standard SQL terminology). This means: + +- **Repeatable reads**: Once you read data in a transaction, subsequent reads will see the same data (even if other transactions modify it) +- **Atomic commits**: Changes are visible to other connections only after commit +- **Write serializability within a single table**: Concurrent writes to the same table will cause conflicts +- **Snapshot isolation across tables**: Concurrent writes to different tables can succeed + +### Getting the Isolation Level + +```python +level = connection.get_transaction_isolation() +print(level) # Output: REPEATABLE_READ +``` + +### Setting the Isolation Level + +Currently, only `REPEATABLE_READ` is supported: + +```python +from databricks import sql + +# Using the constant +connection.set_transaction_isolation(sql.TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ) + +# Or using a string +connection.set_transaction_isolation("REPEATABLE_READ") + +# Other levels will raise NotSupportedError +connection.set_transaction_isolation("READ_COMMITTED") # Raises NotSupportedError +``` + +### What Repeatable Read Means in Practice + +Within a transaction, you'll always see a consistent snapshot of the data: + +```python +connection.autocommit = False +cursor = connection.cursor() + +# First read +cursor.execute("SELECT balance FROM accounts WHERE id = 1") +balance1 = cursor.fetchone()[0] # Returns 1000 + +# Another connection updates the balance +# (In a separate connection: UPDATE accounts SET balance = 500 WHERE id = 1) + +# Second read in the same transaction +cursor.execute("SELECT balance FROM accounts WHERE id = 1") +balance2 = cursor.fetchone()[0] # Still returns 1000 (repeatable read!) + +connection.commit() + +# After commit, new transactions will see the updated value (500) +``` + +## Error Handling + +### Setting Autocommit During a Transaction + +You cannot change autocommit mode while a transaction is active: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO logs VALUES (1, 'data')") + + # This will raise TransactionError + connection.autocommit = True # Error: transaction is active + +except sql.TransactionError as e: + print(f"Cannot change autocommit: {e}") + connection.rollback() # Clean up the transaction +finally: + connection.autocommit = True # Now it's safe to restore +``` + +### Committing Without an Active Transaction + +If autocommit is enabled, there's no active transaction, so calling `commit()` will fail: + +```python +connection.autocommit = True # Default + +try: + connection.commit() # Raises TransactionError +except sql.TransactionError as e: + print(f"No active transaction: {e}") +``` + +However, `rollback()` is safe in this case (it's a no-op). 
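+
+A minimal sketch of that rollback case (same `connection` as above; with autocommit
+enabled there is nothing to roll back, so the call simply returns):
+
+```python
+connection.autocommit = True  # Default state: no transaction is active
+
+connection.rollback()  # Safe: no-op, does not raise
+```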
+ +### Recovering from Query Failures + +If a statement fails during a transaction, roll back and start a new transaction: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO valid_table VALUES (1, 'data')") + cursor.execute("INSERT INTO nonexistent_table VALUES (2, 'data')") # Fails + connection.commit() +except Exception as e: + connection.rollback() # Discard the partial transaction + + # Log the error (with autocommit still disabled) + try: + cursor.execute("INSERT INTO error_log VALUES (1, 'Query failed')") + connection.commit() + except Exception: + connection.rollback() +finally: + connection.autocommit = True # Restore default state +``` + +## Querying Server State + +By default, the `autocommit` property returns a cached value for performance. If you need to query the server each time (for instance, when strong consistency is required): + +```python +connection = sql.connect( + server_hostname="your-workspace.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc123", + fetch_autocommit_from_server=True +) + +# Each access queries the server +state = connection.autocommit # Executes "SET AUTOCOMMIT" query +``` + +This is generally not needed for normal usage. + +## Write Conflicts + +### Within a Single Table + +Databricks enforces **write serializability** within a single table. If two transactions try to modify the same table concurrently, one will fail: + +```python +# Connection 1 +conn1.autocommit = False +cursor1 = conn1.cursor() +cursor1.execute("INSERT INTO accounts VALUES (1, 100)") + +# Connection 2 (concurrent) +conn2.autocommit = False +cursor2 = conn2.cursor() +cursor2.execute("INSERT INTO accounts VALUES (2, 200)") + +# First commit succeeds +conn1.commit() # OK + +# Second commit fails with concurrent write conflict +try: + conn2.commit() # Raises error about concurrent writes +except Exception as e: + conn2.rollback() + print(f"Concurrent write detected: {e}") +``` + +This happens even when the rows being modified are different. The conflict detection is at the table level. + +### Across Multiple Tables + +Concurrent writes to *different* tables can succeed. Each table tracks its own write conflicts independently: + +```python +# Connection 1: writes to table_a +conn1.autocommit = False +cursor1 = conn1.cursor() +cursor1.execute("INSERT INTO table_a VALUES (1, 'data')") + +# Connection 2: writes to table_b (different table) +conn2.autocommit = False +cursor2 = conn2.cursor() +cursor2.execute("INSERT INTO table_b VALUES (1, 'data')") + +# Both commits succeed (different tables) +conn1.commit() # OK +conn2.commit() # Also OK +``` + +## Best Practices + +1. **Keep transactions short**: Long-running transactions can cause conflicts with other connections. Commit as soon as your atomic unit of work is complete. + +2. **Always handle exceptions**: Wrap transaction code in try/except/finally and call `rollback()` on errors. + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO table1 VALUES (1, 'data')") + cursor.execute("UPDATE table2 SET status = 'updated'") + connection.commit() +except Exception as e: + connection.rollback() + logger.error(f"Transaction failed: {e}") + raise +finally: + connection.autocommit = True # Restore default state +``` + +3. 
**Use context managers**: If you're writing helper functions, consider using a context manager pattern: + +```python +from contextlib import contextmanager + +@contextmanager +def transaction(connection): + connection.autocommit = False + try: + yield connection + connection.commit() + except Exception: + connection.rollback() + raise + finally: + connection.autocommit = True + +# Usage +with transaction(connection): + cursor = connection.cursor() + cursor.execute("INSERT INTO logs VALUES (1, 'message')") + # Auto-commits on success, auto-rolls back on exception +``` + +4. **Reset autocommit when done**: Use a `finally` block to restore autocommit to `True`. This is especially important if the connection is reused or part of a connection pool: + +```python +connection.autocommit = False +try: + # ... transaction code ... + connection.commit() +except Exception: + connection.rollback() + raise +finally: + connection.autocommit = True # Restore to default state +``` + +5. **Be aware of isolation semantics**: Remember that repeatable read means you see a snapshot from the start of your transaction. If you need to see recent changes from other transactions, commit your current transaction and start a new one. + +## Requirements + +To use transactions, you need: +- A Databricks SQL warehouse that supports Multi-Statement Transactions (MST) +- Tables created with the `delta.feature.catalogOwned-preview` table property: + +```sql +CREATE TABLE my_table (id INT, value STRING) +USING DELTA +TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') +``` + +## Related APIs + +- `connection.autocommit` - Get or set autocommit mode (boolean) +- `connection.commit()` - Commit the current transaction +- `connection.rollback()` - Roll back the current transaction +- `connection.get_transaction_isolation()` - Get the isolation level (returns `"REPEATABLE_READ"`) +- `connection.set_transaction_isolation(level)` - Validate/set isolation level (only `"REPEATABLE_READ"` supported) +- `sql.TransactionError` - Exception raised for transaction-specific errors + +All of these are extensions to [PEP 249](https://www.python.org/dev/peps/pep-0249/) (Python Database API Specification v2.0). diff --git a/examples/README.md b/examples/README.md index d73c58a6b..f52dede1d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -31,6 +31,7 @@ To run all of these examples you can clone the entire repository to your disk. O - **`query_execute.py`** connects to the `samples` database of your default catalog, runs a small query, and prints the result to screen. - **`insert_data.py`** adds a tables called `squares` to your default catalog and inserts one hundred rows of example data. Then it fetches this data and prints it to the screen. +- **`transactions.py`** demonstrates multi-statement transaction support with explicit commit/rollback control. Shows how to group multiple SQL statements into an atomic unit that either succeeds completely or fails completely. - **`query_cancel.py`** shows how to cancel a query assuming that you can access the `Cursor` executing that query from a different thread. This is necessary because `databricks-sql-connector` does not yet implement an asynchronous API; calling `.execute()` blocks the current thread until execution completes. Therefore, the connector can't cancel queries from the same thread where they began. 
- **`interactive_oauth.py`** shows the simplest example of authenticating by OAuth (no need for a PAT generated in the DBSQL UI) while Bring Your Own IDP is in public preview. When you run the script it will open a browser window so you can authenticate. Afterward, the script fetches some sample data from Databricks and prints it to the screen. For this script, the OAuth token is not persisted which means you need to authenticate every time you run the script. - **`m2m_oauth.py`** shows the simplest example of authenticating by using OAuth M2M (machine-to-machine) for service principal. diff --git a/examples/transactions.py b/examples/transactions.py new file mode 100644 index 000000000..6f58dbd2d --- /dev/null +++ b/examples/transactions.py @@ -0,0 +1,47 @@ +from databricks import sql +import os + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + # Disable autocommit to use explicit transactions + connection.autocommit = False + + with connection.cursor() as cursor: + try: + # Create tables for demonstration + cursor.execute("CREATE TABLE IF NOT EXISTS accounts (id int, balance int)") + cursor.execute( + "CREATE TABLE IF NOT EXISTS transfers (from_id int, to_id int, amount int)" + ) + connection.commit() + + # Start a new transaction - transfer money between accounts + cursor.execute("INSERT INTO accounts VALUES (1, 1000), (2, 500)") + cursor.execute("UPDATE accounts SET balance = balance - 100 WHERE id = 1") + cursor.execute("UPDATE accounts SET balance = balance + 100 WHERE id = 2") + cursor.execute("INSERT INTO transfers VALUES (1, 2, 100)") + + # Commit the transaction - all changes succeed together + connection.commit() + print("Transaction committed successfully") + + # Verify the results + cursor.execute("SELECT * FROM accounts ORDER BY id") + print("Accounts:", cursor.fetchall()) + + cursor.execute("SELECT * FROM transfers") + print("Transfers:", cursor.fetchall()) + + except Exception as e: + # Roll back on error - all changes are discarded + connection.rollback() + print(f"Transaction rolled back due to error: {e}") + raise + + finally: + # Restore autocommit to default state + connection.autocommit = True diff --git a/poetry.lock b/poetry.lock index 1a8074c2a..193efa109 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. 
[[package]] name = "astroid" @@ -1348,6 +1348,38 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pybreaker" +version = "1.2.0" +description = "Python implementation of the Circuit Breaker pattern" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" +files = [ + {file = "pybreaker-1.2.0-py3-none-any.whl", hash = "sha256:c3e7683e29ecb3d4421265aaea55504f1186a2fdc1f17b6b091d80d1e1eb5ede"}, + {file = "pybreaker-1.2.0.tar.gz", hash = "sha256:18707776316f93a30c1be0e4fec1f8aa5ed19d7e395a218eb2f050c8524fb2dc"}, +] + +[package.extras] +test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] + +[[package]] +name = "pybreaker" +version = "1.4.1" +description = "Python implementation of the Circuit Breaker pattern" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "pybreaker-1.4.1-py3-none-any.whl", hash = "sha256:b4dab4a05195b7f2a64a6c1a6c4ba7a96534ef56ea7210e6bcb59f28897160e0"}, + {file = "pybreaker-1.4.1.tar.gz", hash = "sha256:8df2d245c73ba40c8242c56ffb4f12138fbadc23e296224740c2028ea9dc1178"}, +] + +[package.extras] +test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] + [[package]] name = "pycparser" version = "2.22" @@ -1858,4 +1890,4 @@ pyarrow = ["pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "0a3f611ef8747376f018c1df0a1ea7873368851873cc4bd3a4d51bba0bba847c" +content-hash = "56b62e3543644c91cc316b11d89025423a66daba5f36609c45bcb3eeb3ce3f54" diff --git a/pyproject.toml b/pyproject.toml index c0eb8244d..61c248e98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.1.4" +version = "4.2.1" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" @@ -26,6 +26,7 @@ pyarrow = [ { version = ">=18.0.0", python = ">=3.13", optional=true } ] pyjwt = "^2.0.0" +pybreaker = "^1.0.0" requests-kerberos = {version = "^0.15.0", optional = true} diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 403a4d130..cd37e6ce1 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -8,6 +8,9 @@ paramstyle = "named" +# Transaction isolation level constants (extension to PEP 249) +TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" + import re from typing import TYPE_CHECKING @@ -68,7 +71,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.1.4" +__version__ = "4.2.1" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 3e0be0d2b..a764b036d 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -51,6 +51,7 @@ def __init__( pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, user_agent: Optional[str] = None, + telemetry_circuit_breaker_enabled: Optional[bool] = None, ): self.hostname = hostname self.access_token = access_token @@ -83,6 +84,7 @@ def __init__( self.pool_connections = pool_connections or 10 self.pool_maxsize = pool_maxsize or 20 self.user_agent = user_agent + self.telemetry_circuit_breaker_enabled = bool(telemetry_circuit_breaker_enabled) def get_effective_azure_login_app_id(hostname) -> str: diff --git 
a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 5f10f2df4..9c5c63033 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -157,7 +157,7 @@ def __init__( "_use_arrow_native_complex_types", True ) - self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True) + self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", False) self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) # Extract warehouse ID from http_path diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 5bb191ca2..a7f802dcd 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -9,6 +9,7 @@ import json import os import decimal +from urllib.parse import urlparse from uuid import UUID from databricks.sql import __version__ @@ -20,6 +21,8 @@ InterfaceError, NotSupportedError, ProgrammingError, + TransactionError, + DatabaseError, ) from databricks.sql.thrift_api.TCLIService import ttypes @@ -86,6 +89,9 @@ NO_NATIVE_PARAMS: List = [] +# Transaction isolation level constants (extension to PEP 249) +TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" + class Connection: def __init__( @@ -98,6 +104,7 @@ def __init__( catalog: Optional[str] = None, schema: Optional[str] = None, _use_arrow_native_complex_types: Optional[bool] = True, + ignore_transactions: bool = True, **kwargs, ) -> None: """ @@ -206,6 +213,17 @@ def read(self) -> Optional[OAuthToken]: This allows 1. cursor.tables() to return METRIC_VIEW table type 2. cursor.columns() to return "measure" column type + :param fetch_autocommit_from_server: `bool`, optional (default is False) + When True, the connection.autocommit property queries the server for current state + using SET AUTOCOMMIT instead of returning cached value. + Set to True if autocommit might be changed by external means (e.g., external SQL commands). + When False (default), uses cached state for better performance. + :param ignore_transactions: `bool`, optional (default is True) + When True, transaction-related operations behave as follows: + - commit(): no-op (does nothing) + - rollback(): raises NotSupportedError + - autocommit setter: no-op (does nothing) + When False, transaction operations execute normally. 
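+
+            A minimal illustrative sketch (connection arguments elided; the table
+            name `t` is a placeholder) showing a connection opted in to explicit
+            transactions:
+
+                with sql.connect(..., ignore_transactions=False) as connection:
+                    connection.autocommit = False
+                    with connection.cursor() as cursor:
+                        cursor.execute("INSERT INTO t VALUES (1)")
+                    connection.commit()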
""" # Internal arguments in **kwargs: @@ -304,6 +322,10 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) + self._fetch_autocommit_from_server = kwargs.get( + "fetch_autocommit_from_server", False + ) + self.ignore_transactions = ignore_transactions self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False) self.enable_telemetry = kwargs.get("enable_telemetry", False) @@ -322,6 +344,20 @@ def read(self) -> Optional[OAuthToken]: session_id_hex=self.get_session_id_hex() ) + # Determine proxy usage + use_proxy = self.http_client.using_proxy() + proxy_host_info = None + if ( + use_proxy + and self.http_client.proxy_uri + and isinstance(self.http_client.proxy_uri, str) + ): + parsed = urlparse(self.http_client.proxy_uri) + proxy_host_info = HostDetails( + host_url=parsed.hostname or self.http_client.proxy_uri, + port=parsed.port or 8080, + ) + driver_connection_params = DriverConnectionParameters( http_path=http_path, mode=DatabricksClientType.SEA @@ -331,13 +367,31 @@ def read(self) -> Optional[OAuthToken]: auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider), auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider), socket_timeout=kwargs.get("_socket_timeout", None), + azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id", None), + azure_tenant_id=kwargs.get("azure_tenant_id", None), + use_proxy=use_proxy, + use_system_proxy=use_proxy, + proxy_host_info=proxy_host_info, + use_cf_proxy=False, # CloudFlare proxy not yet supported in Python + cf_proxy_host_info=None, # CloudFlare proxy not yet supported in Python + non_proxy_hosts=None, + allow_self_signed_support=kwargs.get("_tls_no_verify", False), + use_system_trust_store=True, # Python uses system SSL by default + enable_arrow=pyarrow is not None, + enable_direct_results=True, # Always enabled in Python + enable_sea_hybrid_results=kwargs.get("use_hybrid_disposition", False), + http_connection_pool_size=kwargs.get("pool_maxsize", None), + rows_fetched_per_block=DEFAULT_ARRAY_SIZE, + async_poll_interval_millis=2000, # Default polling interval + support_many_parameters=True, # Native parameters supported + enable_complex_datatype_support=_use_arrow_native_complex_types, + allowed_volume_ingestion_paths=self.staging_allowed_local_path, ) self._telemetry_client.export_initial_telemetry_log( driver_connection_params=driver_connection_params, user_agent=self.session.useragent_header, ) - self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) def _set_use_inline_params_with_warning(self, value: Union[bool, str]): """Valid values are True, False, and "silent" @@ -473,15 +527,286 @@ def _close(self, close_cursors=True) -> None: if self.http_client: self.http_client.close() - def commit(self): - """No-op because Databricks does not support transactions""" - pass + @property + def autocommit(self) -> bool: + """ + Get auto-commit mode for this connection. - def rollback(self): - raise NotSupportedError( - "Transactions are not supported on Databricks", - session_id_hex=self.get_session_id_hex(), - ) + Extension to PEP 249. Returns cached value by default. + If fetch_autocommit_from_server=True was set during connection, + queries server for current state. 
+ + Returns: + bool: True if auto-commit is enabled, False otherwise + + Raises: + InterfaceError: If connection is closed + TransactionError: If fetch_autocommit_from_server=True and query fails + """ + if not self.open: + raise InterfaceError( + "Cannot get autocommit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + if self._fetch_autocommit_from_server: + return self._fetch_autocommit_state_from_server() + + return self.session.get_autocommit() + + @autocommit.setter + def autocommit(self, value: bool) -> None: + """ + Set auto-commit mode for this connection. + + Extension to PEP 249. Executes SET AUTOCOMMIT command on server. + + Args: + value: True to enable auto-commit, False to disable + + When ignore_transactions is True: + - This method is a no-op (does nothing) + + Raises: + InterfaceError: If connection is closed + TransactionError: If server rejects the change + """ + # No-op when ignore_transactions is True + if self.ignore_transactions: + return + + if not self.open: + raise InterfaceError( + "Cannot set autocommit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + # Create internal cursor for transaction control + cursor = None + try: + cursor = self.cursor() + sql = f"SET AUTOCOMMIT = {'TRUE' if value else 'FALSE'}" + cursor.execute(sql) + + # Update cached state on success + self.session.set_autocommit(value) + + except DatabaseError as e: + # Wrap in TransactionError with context + raise TransactionError( + f"Failed to set autocommit to {value}: {e.message}", + context={ + **e.context, + "operation": "set_autocommit", + "autocommit_value": value, + }, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def _fetch_autocommit_state_from_server(self) -> bool: + """ + Query server for current autocommit state using SET AUTOCOMMIT. + + Returns: + bool: Server's autocommit state + + Raises: + TransactionError: If query fails + """ + cursor = None + try: + cursor = self.cursor() + cursor.execute("SET AUTOCOMMIT") + + # Fetch result: should return row with value column + result = cursor.fetchone() + if result is None: + raise TransactionError( + "No result returned from SET AUTOCOMMIT query", + context={"operation": "fetch_autocommit"}, + session_id_hex=self.get_session_id_hex(), + ) + + # Parse value (first column should be "true" or "false") + value_str = str(result[0]).lower() + autocommit_state = value_str == "true" + + # Update cache + self.session.set_autocommit(autocommit_state) + + return autocommit_state + + except TransactionError: + # Re-raise TransactionError as-is + raise + except DatabaseError as e: + # Wrap other DatabaseErrors + raise TransactionError( + f"Failed to fetch autocommit state from server: {e.message}", + context={**e.context, "operation": "fetch_autocommit"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def commit(self) -> None: + """ + Commit the current transaction. + + Per PEP 249. Should be called only when autocommit is disabled. 
+ + When autocommit is False: + - Commits the current transaction + - Server automatically starts new transaction + + When autocommit is True: + - Server may throw error if no active transaction + + When ignore_transactions is True: + - This method is a no-op (does nothing) + + Raises: + InterfaceError: If connection is closed + TransactionError: If commit fails (e.g., no active transaction) + """ + # No-op when ignore_transactions is True + if self.ignore_transactions: + return + + if not self.open: + raise InterfaceError( + "Cannot commit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + cursor = None + try: + cursor = self.cursor() + cursor.execute("COMMIT") + + except DatabaseError as e: + raise TransactionError( + f"Failed to commit transaction: {e.message}", + context={**e.context, "operation": "commit"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def rollback(self) -> None: + """ + Rollback the current transaction. + + Per PEP 249. Should be called only when autocommit is disabled. + + When autocommit is False: + - Rolls back the current transaction + - Server automatically starts new transaction + + When autocommit is True: + - ROLLBACK is forgiving (no-op, doesn't throw exception) + + When ignore_transactions is True: + - Raises NotSupportedError + + Note: ROLLBACK is safe to call even without active transaction. + + Raises: + InterfaceError: If connection is closed + NotSupportedError: If ignore_transactions is True + TransactionError: If rollback fails + """ + # Raise NotSupportedError when ignore_transactions is True + if self.ignore_transactions: + raise NotSupportedError( + "Transactions are not supported on Databricks", + session_id_hex=self.get_session_id_hex(), + ) + + if not self.open: + raise InterfaceError( + "Cannot rollback on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + cursor = None + try: + cursor = self.cursor() + cursor.execute("ROLLBACK") + + except DatabaseError as e: + raise TransactionError( + f"Failed to rollback transaction: {e.message}", + context={**e.context, "operation": "rollback"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def get_transaction_isolation(self) -> str: + """ + Get the transaction isolation level. + + Extension to PEP 249. + + Databricks supports REPEATABLE_READ isolation level (Snapshot Isolation), + which is the default and only supported level. + + Returns: + str: "REPEATABLE_READ" - the transaction isolation level constant + + Raises: + InterfaceError: If connection is closed + """ + if not self.open: + raise InterfaceError( + "Cannot get transaction isolation on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + return TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ + + def set_transaction_isolation(self, level: str) -> None: + """ + Set transaction isolation level. + + Extension to PEP 249. + + Databricks supports only REPEATABLE_READ isolation level (Snapshot Isolation). + This method validates that the requested level is supported but does not + execute any SQL, as REPEATABLE_READ is the default server behavior. + + Args: + level: Isolation level. 
Must be "REPEATABLE_READ" or "REPEATABLE READ" + (case-insensitive, underscores and spaces are interchangeable) + + Raises: + InterfaceError: If connection is closed + NotSupportedError: If isolation level not supported + """ + if not self.open: + raise InterfaceError( + "Cannot set transaction isolation on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + # Normalize and validate isolation level + normalized_level = level.upper().replace("_", " ") + + if normalized_level != TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ.replace( + "_", " " + ): + raise NotSupportedError( + f"Setting transaction isolation level '{level}' is not supported. " + f"Only {TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ} is supported.", + session_id_hex=self.get_session_id_hex(), + ) class Cursor: diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 8a1cf5bd5..032701f63 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -165,8 +165,9 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: cls._initialize() assert cls._executor is not None - # Use the unique session ID as the key - key = connection.get_session_id_hex() + # Cache at HOST level - share feature flags across connections to same host + # Feature flags are per-host, not per-session + key = connection.session.host if key not in cls._context_map: cls._context_map[key] = FeatureFlagsContext( connection, cls._executor, connection.session.http_client @@ -177,7 +178,8 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: def remove_instance(cls, connection: "Connection"): """Removes the context for a given connection and shuts down the executor if no clients remain.""" with cls._lock: - key = connection.get_session_id_hex() + # Use host as key to match get_instance + key = connection.session.host if key in cls._context_map: cls._context_map.pop(key, None) diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 7ccd69c54..d5f7d3c8d 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -28,6 +28,42 @@ logger = logging.getLogger(__name__) +def _extract_http_status_from_max_retry_error(e: MaxRetryError) -> Optional[int]: + """ + Extract HTTP status code from MaxRetryError if available. + + urllib3 structures MaxRetryError in different ways depending on the failure scenario: + - e.reason.response.status: Most common case when retries are exhausted + - e.response.status: Alternate structure in some scenarios + + Args: + e: MaxRetryError exception from urllib3 + + Returns: + HTTP status code as int if found, None otherwise + """ + # Try primary structure: e.reason.response.status + if ( + hasattr(e, "reason") + and e.reason is not None + and hasattr(e.reason, "response") + and e.reason.response is not None + ): + http_code = getattr(e.reason.response, "status", None) + if http_code is not None: + return http_code + + # Try alternate structure: e.response.status + if ( + hasattr(e, "response") + and e.response is not None + and hasattr(e.response, "status") + ): + return e.response.status + + return None + + class UnifiedHttpClient: """ Unified HTTP client for all Databricks SQL connector HTTP operations. 
@@ -264,7 +300,16 @@ def request_context( yield response except MaxRetryError as e: logger.error("HTTP request failed after retries: %s", e) - raise RequestError(f"HTTP request failed: {e}") + + # Extract HTTP status code from MaxRetryError if available + http_code = _extract_http_status_from_max_retry_error(e) + + context = {} + if http_code is not None: + context["http-code"] = http_code + logger.error("HTTP request failed with status code: %d", http_code) + + raise RequestError(f"HTTP request failed: {e}", context=context) except Exception as e: logger.error("HTTP request error: %s", e) raise RequestError(f"HTTP request error: {e}") @@ -301,6 +346,11 @@ def using_proxy(self) -> bool: """Check if proxy support is available (not whether it's being used for a specific request).""" return self._proxy_pool_manager is not None + @property + def proxy_uri(self) -> Optional[str]: + """Get the configured proxy URI, if any.""" + return self._proxy_uri + def close(self): """Close the underlying connection pools.""" if self._direct_pool_manager: diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 4a772c49b..24844d573 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -70,6 +70,23 @@ class NotSupportedError(DatabaseError): pass +class TransactionError(DatabaseError): + """ + Exception raised for transaction-specific errors. + + This exception is used when transaction control operations fail, such as: + - Setting autocommit mode (AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION) + - Committing a transaction (MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION) + - Rolling back a transaction + - Setting transaction isolation level + + The exception includes context about which transaction operation failed + and preserves the underlying cause via exception chaining. + """ + + pass + + ### Custom error classes ### class InvalidServerResponseError(OperationalError): """Thrown if the server does not set the initial namespace correctly""" @@ -126,3 +143,24 @@ class SessionAlreadyClosedError(RequestError): class CursorAlreadyClosedError(RequestError): """Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected.""" + + +class TelemetryRateLimitError(Exception): + """Raised when telemetry endpoint returns 429 or 503, indicating rate limiting or service unavailable. + This exception is used exclusively by the circuit breaker to track telemetry rate limiting events.""" + + +class TelemetryNonRateLimitError(Exception): + """Wrapper for telemetry errors that should NOT trigger circuit breaker. + + This exception wraps non-rate-limiting errors (network errors, timeouts, server errors, etc.) + and is excluded from circuit breaker failure counting. Only TelemetryRateLimitError should + open the circuit breaker. 
+ + Attributes: + original_exception: The actual exception that occurred + """ + + def __init__(self, original_exception: Exception): + self.original_exception = original_exception + super().__init__(f"Non-rate-limit telemetry error: {original_exception}") diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index d8ba5d125..0f723d144 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -45,6 +45,9 @@ def __init__( self.schema = schema self.http_path = http_path + # Initialize autocommit state (JDBC default is True) + self._autocommit = True + user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: user_agent_entry = kwargs.get("_user_agent_entry") @@ -168,6 +171,24 @@ def guid_hex(self) -> str: """Get the session ID in hex format""" return self._session_id.hex_guid + def get_autocommit(self) -> bool: + """ + Get the cached autocommit state for this session. + + Returns: + bool: True if autocommit is enabled, False otherwise + """ + return self._autocommit + + def set_autocommit(self, value: bool) -> None: + """ + Update the cached autocommit state for this session. + + Args: + value: True to cache autocommit as enabled, False as disabled + """ + self._autocommit = value + def close(self) -> None: """Close the underlying session.""" logger.info("Closing session %s", self.guid_hex) diff --git a/src/databricks/sql/telemetry/circuit_breaker_manager.py b/src/databricks/sql/telemetry/circuit_breaker_manager.py new file mode 100644 index 000000000..852f0d916 --- /dev/null +++ b/src/databricks/sql/telemetry/circuit_breaker_manager.py @@ -0,0 +1,112 @@ +""" +Circuit breaker implementation for telemetry requests. + +This module provides circuit breaker functionality to prevent telemetry failures +from impacting the main SQL operations. It uses pybreaker library to implement +the circuit breaker pattern. 
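+
+A rough usage sketch (`send_telemetry` and `payload` are placeholder names; the host
+string is only an example):
+
+    breaker = CircuitBreakerManager.get_circuit_breaker("my-workspace.cloud.databricks.com")
+    try:
+        breaker.call(send_telemetry, payload)
+    except pybreaker.CircuitBreakerError:
+        # Circuit is open: the telemetry call is skipped rather than retried
+        pass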
+""" + +import logging +import threading +from typing import Dict + +import pybreaker +from pybreaker import CircuitBreaker, CircuitBreakerError, CircuitBreakerListener + +from databricks.sql.exc import TelemetryNonRateLimitError + +logger = logging.getLogger(__name__) + +# Circuit Breaker Constants +MINIMUM_CALLS = 20 # Number of failures before circuit opens +RESET_TIMEOUT = 30 # Seconds to wait before trying to close circuit +NAME_PREFIX = "telemetry-circuit-breaker" + +# Circuit Breaker State Constants (used in logging) +CIRCUIT_BREAKER_STATE_OPEN = "open" +CIRCUIT_BREAKER_STATE_CLOSED = "closed" +CIRCUIT_BREAKER_STATE_HALF_OPEN = "half-open" + +# Logging Message Constants +LOG_CIRCUIT_BREAKER_STATE_CHANGED = "Circuit breaker state changed from %s to %s for %s" +LOG_CIRCUIT_BREAKER_OPENED = ( + "Circuit breaker opened for %s - telemetry requests will be blocked" +) +LOG_CIRCUIT_BREAKER_CLOSED = ( + "Circuit breaker closed for %s - telemetry requests will be allowed" +) +LOG_CIRCUIT_BREAKER_HALF_OPEN = ( + "Circuit breaker half-open for %s - testing telemetry requests" +) + + +class CircuitBreakerStateListener(CircuitBreakerListener): + """Listener for circuit breaker state changes.""" + + def before_call(self, cb: CircuitBreaker, func, *args, **kwargs) -> None: + """Called before the circuit breaker calls a function.""" + pass + + def failure(self, cb: CircuitBreaker, exc: BaseException) -> None: + """Called when a function called by the circuit breaker fails.""" + pass + + def success(self, cb: CircuitBreaker) -> None: + """Called when a function called by the circuit breaker succeeds.""" + pass + + def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None: + """Called when the circuit breaker state changes.""" + old_state_name = old_state.name if old_state else "None" + new_state_name = new_state.name if new_state else "None" + + logger.info( + LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name + ) + + if new_state_name == CIRCUIT_BREAKER_STATE_OPEN: + logger.warning(LOG_CIRCUIT_BREAKER_OPENED, cb.name) + elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED: + logger.info(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) + elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN: + logger.info(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) + + +class CircuitBreakerManager: + """ + Manages circuit breaker instances for telemetry requests. + + Creates and caches circuit breaker instances per host to ensure telemetry + failures don't impact main SQL operations. + """ + + _instances: Dict[str, CircuitBreaker] = {} + _lock = threading.RLock() + + @classmethod + def get_circuit_breaker(cls, host: str) -> CircuitBreaker: + """ + Get or create a circuit breaker instance for the specified host. 
+ + Args: + host: The hostname for which to get the circuit breaker + + Returns: + CircuitBreaker instance for the host + """ + with cls._lock: + if host not in cls._instances: + breaker = CircuitBreaker( + fail_max=MINIMUM_CALLS, + reset_timeout=RESET_TIMEOUT, + name=f"{NAME_PREFIX}-{host}", + exclude=[ + TelemetryNonRateLimitError + ], # Don't count these as failures + ) + # Add state change listener for logging + breaker.add_listener(CircuitBreakerStateListener()) + cls._instances[host] = breaker + logger.debug("Created circuit breaker for host: %s", host) + + return cls._instances[host] diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 12cacd851..36ebee2b8 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -1,6 +1,6 @@ import time import functools -from typing import Optional +from typing import Optional, Dict, Any import logging from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.telemetry.models.event import ( @@ -11,127 +11,141 @@ logger = logging.getLogger(__name__) -class TelemetryExtractor: +def _extract_cursor_data(cursor) -> Dict[str, Any]: """ - Base class for extracting telemetry information from various object types. + Extract telemetry data directly from a Cursor object. - This class serves as a proxy that delegates attribute access to the wrapped object - while providing a common interface for extracting telemetry-related data. - """ - - def __init__(self, obj): - self._obj = obj - - def __getattr__(self, name): - return getattr(self._obj, name) - - def get_session_id_hex(self): - pass - - def get_statement_id(self): - pass - - def get_is_compressed(self): - pass - - def get_execution_result_format(self): - pass - - def get_retry_count(self): - pass - - def get_chunk_id(self): - pass + OPTIMIZATION: Uses direct attribute access instead of wrapper objects. + This eliminates object creation overhead and method call indirection. + Args: + cursor: The Cursor object to extract data from -class CursorExtractor(TelemetryExtractor): + Returns: + Dict with telemetry data (values may be None if extraction fails) """ - Telemetry extractor specialized for Cursor objects. - - Extracts telemetry information from database cursor objects, including - statement IDs, session information, compression settings, and result formats. 
+ data = {} + + # Extract statement_id (query_id) - direct attribute access + try: + data["statement_id"] = cursor.query_id + except (AttributeError, Exception): + data["statement_id"] = None + + # Extract session_id_hex - direct method call + try: + data["session_id_hex"] = cursor.connection.get_session_id_hex() + except (AttributeError, Exception): + data["session_id_hex"] = None + + # Extract is_compressed - direct attribute access + try: + data["is_compressed"] = cursor.connection.lz4_compression + except (AttributeError, Exception): + data["is_compressed"] = False + + # Extract execution_result_format - inline logic + try: + if cursor.active_result_set is None: + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + else: + from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue + + results = cursor.active_result_set.results + if isinstance(results, ColumnQueue): + data["execution_result"] = ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(results, CloudFetchQueue): + data["execution_result"] = ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(results, ArrowQueue): + data["execution_result"] = ExecutionResultFormat.INLINE_ARROW + else: + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + except (AttributeError, Exception): + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + + # Extract retry_count - direct attribute access + try: + if hasattr(cursor.backend, "retry_policy") and cursor.backend.retry_policy: + data["retry_count"] = len(cursor.backend.retry_policy.history) + else: + data["retry_count"] = 0 + except (AttributeError, Exception): + data["retry_count"] = 0 + + # chunk_id is always None for Cursor + data["chunk_id"] = None + + return data + + +def _extract_result_set_handler_data(handler) -> Dict[str, Any]: """ + Extract telemetry data directly from a ResultSetDownloadHandler object. - def get_statement_id(self) -> Optional[str]: - return self.query_id - - def get_session_id_hex(self) -> Optional[str]: - return self.connection.get_session_id_hex() - - def get_is_compressed(self) -> bool: - return self.connection.lz4_compression - - def get_execution_result_format(self) -> ExecutionResultFormat: - if self.active_result_set is None: - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue - - if isinstance(self.active_result_set.results, ColumnQueue): - return ExecutionResultFormat.COLUMNAR_INLINE - elif isinstance(self.active_result_set.results, CloudFetchQueue): - return ExecutionResultFormat.EXTERNAL_LINKS - elif isinstance(self.active_result_set.results, ArrowQueue): - return ExecutionResultFormat.INLINE_ARROW - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - def get_retry_count(self) -> int: - if hasattr(self.backend, "retry_policy") and self.backend.retry_policy: - return len(self.backend.retry_policy.history) - return 0 - - def get_chunk_id(self): - return None + OPTIMIZATION: Uses direct attribute access instead of wrapper objects. + Args: + handler: The ResultSetDownloadHandler object to extract data from -class ResultSetDownloadHandlerExtractor(TelemetryExtractor): - """ - Telemetry extractor specialized for ResultSetDownloadHandler objects. 
+ Returns: + Dict with telemetry data (values may be None if extraction fails) """ + data = {} - def get_session_id_hex(self) -> Optional[str]: - return self._obj.session_id_hex + # Extract session_id_hex - direct attribute access + try: + data["session_id_hex"] = handler.session_id_hex + except (AttributeError, Exception): + data["session_id_hex"] = None - def get_statement_id(self) -> Optional[str]: - return self._obj.statement_id + # Extract statement_id - direct attribute access + try: + data["statement_id"] = handler.statement_id + except (AttributeError, Exception): + data["statement_id"] = None - def get_is_compressed(self) -> bool: - return self._obj.settings.is_lz4_compressed + # Extract is_compressed - direct attribute access + try: + data["is_compressed"] = handler.settings.is_lz4_compressed + except (AttributeError, Exception): + data["is_compressed"] = False - def get_execution_result_format(self) -> ExecutionResultFormat: - return ExecutionResultFormat.EXTERNAL_LINKS + # execution_result is always EXTERNAL_LINKS for result set handlers + data["execution_result"] = ExecutionResultFormat.EXTERNAL_LINKS - def get_retry_count(self) -> Optional[int]: - # standard requests and urllib3 libraries don't expose retry count - return None + # retry_count is not available for result set handlers + data["retry_count"] = None + + # Extract chunk_id - direct attribute access + try: + data["chunk_id"] = handler.chunk_id + except (AttributeError, Exception): + data["chunk_id"] = None - def get_chunk_id(self) -> Optional[int]: - return self._obj.chunk_id + return data -def get_extractor(obj): +def _extract_telemetry_data(obj) -> Optional[Dict[str, Any]]: """ - Factory function to create the appropriate telemetry extractor for an object. + Extract telemetry data from an object based on its type. - Determines the object type and returns the corresponding specialized extractor - that can extract telemetry information from that object type. + OPTIMIZATION: Returns a simple dict instead of creating wrapper objects. + This dict will be used to create the SqlExecutionEvent in the background thread. Args: - obj: The object to create an extractor for. Can be a Cursor, - ResultSetDownloadHandler, or any other object. + obj: The object to extract data from (Cursor, ResultSetDownloadHandler, etc.) Returns: - TelemetryExtractor: A specialized extractor instance: - - CursorExtractor for Cursor objects - - ResultSetDownloadHandlerExtractor for ResultSetDownloadHandler objects - - None for all other objects + Dict with telemetry data, or None if object type is not supported """ - if obj.__class__.__name__ == "Cursor": - return CursorExtractor(obj) - elif obj.__class__.__name__ == "ResultSetDownloadHandler": - return ResultSetDownloadHandlerExtractor(obj) + obj_type = obj.__class__.__name__ + + if obj_type == "Cursor": + return _extract_cursor_data(obj) + elif obj_type == "ResultSetDownloadHandler": + return _extract_result_set_handler_data(obj) else: - logger.debug("No extractor found for %s", obj.__class__.__name__) + logger.debug("No telemetry extraction available for %s", obj_type) return None @@ -143,12 +157,6 @@ def log_latency(statement_type: StatementType = StatementType.NONE): data about the operation, including latency, statement information, and execution context. 
- The decorator automatically: - - Measures execution time using high-precision performance counters - - Extracts telemetry information from the method's object (self) - - Creates a SqlExecutionEvent with execution details - - Sends the telemetry data asynchronously via TelemetryClient - Args: statement_type (StatementType): The type of SQL statement being executed. @@ -162,54 +170,49 @@ def execute(self, query): function: A decorator that wraps methods to add latency logging. Note: - The wrapped method's object (self) must be compatible with the - telemetry extractor system (e.g., Cursor or ResultSet objects). + The wrapped method's object (self) must be a Cursor or + ResultSetDownloadHandler for telemetry data extraction. """ def decorator(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): - start_time = time.perf_counter() - result = None + start_time = time.monotonic() try: - result = func(self, *args, **kwargs) - return result + return func(self, *args, **kwargs) finally: - - def _safe_call(func_to_call): - """Calls a function and returns a default value on any exception.""" - try: - return func_to_call() - except Exception: - return None - - end_time = time.perf_counter() - duration_ms = int((end_time - start_time) * 1000) - - extractor = get_extractor(self) - - if extractor is not None: - session_id_hex = _safe_call(extractor.get_session_id_hex) - statement_id = _safe_call(extractor.get_statement_id) - - sql_exec_event = SqlExecutionEvent( - statement_type=statement_type, - is_compressed=_safe_call(extractor.get_is_compressed), - execution_result=_safe_call( - extractor.get_execution_result_format - ), - retry_count=_safe_call(extractor.get_retry_count), - chunk_id=_safe_call(extractor.get_chunk_id), - ) - - telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex - ) - telemetry_client.export_latency_log( - latency_ms=duration_ms, - sql_execution_event=sql_exec_event, - sql_statement_id=statement_id, - ) + duration_ms = int((time.monotonic() - start_time) * 1000) + + # Always log for debugging + logger.debug("%s completed in %dms", func.__name__, duration_ms) + + # Fast check: use cached telemetry_enabled flag from connection + # Avoids dictionary lookup + instance check on every operation + connection = getattr(self, "connection", None) + if connection and getattr(connection, "telemetry_enabled", False): + session_id_hex = connection.get_session_id_hex() + if session_id_hex: + # Telemetry enabled - extract and send + telemetry_data = _extract_telemetry_data(self) + if telemetry_data: + sql_exec_event = SqlExecutionEvent( + statement_type=statement_type, + is_compressed=telemetry_data.get("is_compressed"), + execution_result=telemetry_data.get("execution_result"), + retry_count=telemetry_data.get("retry_count"), + chunk_id=telemetry_data.get("chunk_id"), + ) + + telemetry_client = ( + TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + ) + telemetry_client.export_latency_log( + latency_ms=duration_ms, + sql_execution_event=sql_exec_event, + sql_statement_id=telemetry_data.get("statement_id"), + ) return wrapper diff --git a/src/databricks/sql/telemetry/models/event.py b/src/databricks/sql/telemetry/models/event.py index c7f9d9d17..2e6f63a6f 100644 --- a/src/databricks/sql/telemetry/models/event.py +++ b/src/databricks/sql/telemetry/models/event.py @@ -38,6 +38,25 @@ class DriverConnectionParameters(JsonSerializableMixin): auth_mech (AuthMech): The authentication mechanism used auth_flow (AuthFlow): The authentication flow type 
socket_timeout (int): Connection timeout in milliseconds + azure_workspace_resource_id (str): Azure workspace resource ID + azure_tenant_id (str): Azure tenant ID + use_proxy (bool): Whether proxy is being used + use_system_proxy (bool): Whether system proxy is being used + proxy_host_info (HostDetails): Proxy host details if configured + use_cf_proxy (bool): Whether CloudFlare proxy is being used + cf_proxy_host_info (HostDetails): CloudFlare proxy host details if configured + non_proxy_hosts (list): List of hosts that bypass proxy + allow_self_signed_support (bool): Whether self-signed certificates are allowed + use_system_trust_store (bool): Whether system trust store is used + enable_arrow (bool): Whether Arrow format is enabled + enable_direct_results (bool): Whether direct results are enabled + enable_sea_hybrid_results (bool): Whether SEA hybrid results are enabled + http_connection_pool_size (int): HTTP connection pool size + rows_fetched_per_block (int): Number of rows fetched per block + async_poll_interval_millis (int): Async polling interval in milliseconds + support_many_parameters (bool): Whether many parameters are supported + enable_complex_datatype_support (bool): Whether complex datatypes are supported + allowed_volume_ingestion_paths (str): Allowed paths for volume ingestion """ http_path: str @@ -46,6 +65,25 @@ class DriverConnectionParameters(JsonSerializableMixin): auth_mech: Optional[AuthMech] = None auth_flow: Optional[AuthFlow] = None socket_timeout: Optional[int] = None + azure_workspace_resource_id: Optional[str] = None + azure_tenant_id: Optional[str] = None + use_proxy: Optional[bool] = None + use_system_proxy: Optional[bool] = None + proxy_host_info: Optional[HostDetails] = None + use_cf_proxy: Optional[bool] = None + cf_proxy_host_info: Optional[HostDetails] = None + non_proxy_hosts: Optional[list] = None + allow_self_signed_support: Optional[bool] = None + use_system_trust_store: Optional[bool] = None + enable_arrow: Optional[bool] = None + enable_direct_results: Optional[bool] = None + enable_sea_hybrid_results: Optional[bool] = None + http_connection_pool_size: Optional[int] = None + rows_fetched_per_block: Optional[int] = None + async_poll_interval_millis: Optional[int] = None + support_many_parameters: Optional[bool] = None + enable_complex_datatype_support: Optional[bool] = None + allowed_volume_ingestion_paths: Optional[str] = None @dataclass @@ -111,6 +149,69 @@ class DriverErrorInfo(JsonSerializableMixin): stack_trace: str +@dataclass +class ChunkDetails(JsonSerializableMixin): + """ + Contains detailed metrics about chunk downloads during result fetching. + + These metrics are accumulated across all chunk downloads for a single statement. + + Attributes: + initial_chunk_latency_millis (int): Latency of the first chunk download + slowest_chunk_latency_millis (int): Latency of the slowest chunk download + total_chunks_present (int): Total number of chunks available + total_chunks_iterated (int): Number of chunks actually downloaded + sum_chunks_download_time_millis (int): Total time spent downloading all chunks + """ + + initial_chunk_latency_millis: Optional[int] = None + slowest_chunk_latency_millis: Optional[int] = None + total_chunks_present: Optional[int] = None + total_chunks_iterated: Optional[int] = None + sum_chunks_download_time_millis: Optional[int] = None + + +@dataclass +class ResultLatency(JsonSerializableMixin): + """ + Contains latency metrics for different phases of query execution. + + This tracks two distinct phases: + 1. 
result_set_ready_latency_millis: Time from query submission until results are available (execute phase) + - Set when execute() completes + 2. result_set_consumption_latency_millis: Time spent iterating/fetching results (fetch phase) + - Measured from first fetch call until no more rows available + - In Java: tracked via markResultSetConsumption(hasNext) method + - Records start time on first fetch, calculates total on last fetch + + Attributes: + result_set_ready_latency_millis (int): Time until query results are ready (execution phase) + result_set_consumption_latency_millis (int): Time spent fetching/consuming results (fetch phase) + + """ + + result_set_ready_latency_millis: Optional[int] = None + result_set_consumption_latency_millis: Optional[int] = None + + +@dataclass +class OperationDetail(JsonSerializableMixin): + """ + Contains detailed information about the operation being performed. + + Attributes: + n_operation_status_calls (int): Number of status polling calls made + operation_status_latency_millis (int): Total latency of all status calls + operation_type (str): Specific operation type (e.g., EXECUTE_STATEMENT, LIST_TABLES, CANCEL_STATEMENT) + is_internal_call (bool): Whether this is an internal driver operation + """ + + n_operation_status_calls: Optional[int] = None + operation_status_latency_millis: Optional[int] = None + operation_type: Optional[str] = None + is_internal_call: Optional[bool] = None + + @dataclass class SqlExecutionEvent(JsonSerializableMixin): """ @@ -122,7 +223,10 @@ class SqlExecutionEvent(JsonSerializableMixin): is_compressed (bool): Whether the result is compressed execution_result (ExecutionResultFormat): Format of the execution result retry_count (int): Number of retry attempts made - chunk_id (int): ID of the chunk if applicable + chunk_id (int): ID of the chunk if applicable (used for error tracking) + chunk_details (ChunkDetails): Aggregated chunk download metrics + result_latency (ResultLatency): Latency breakdown by execution phase + operation_detail (OperationDetail): Detailed operation information """ statement_type: StatementType @@ -130,6 +234,9 @@ class SqlExecutionEvent(JsonSerializableMixin): execution_result: ExecutionResultFormat retry_count: Optional[int] chunk_id: Optional[int] + chunk_details: Optional[ChunkDetails] = None + result_latency: Optional[ResultLatency] = None + operation_detail: Optional[OperationDetail] = None @dataclass diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index c7c4289ec..892485a4a 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -3,6 +3,9 @@ import logging import json from concurrent.futures import ThreadPoolExecutor, wait +from queue import Queue, Full +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future from datetime import datetime, timezone from typing import List, Dict, Any, Optional, TYPE_CHECKING from databricks.sql.telemetry.models.event import ( @@ -40,6 +43,11 @@ from databricks.sql.common.feature_flag import FeatureFlagsContextFactory from databricks.sql.common.unified_http_client import UnifiedHttpClient from databricks.sql.common.http import HttpMethod +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) if TYPE_CHECKING: from databricks.sql.client import Connection @@ -108,18 +116,21 @@ def get_auth_flow(auth_provider): 
@staticmethod def is_telemetry_enabled(connection: "Connection") -> bool: + # Fast path: force enabled - skip feature flag fetch entirely if connection.force_enable_telemetry: return True - if connection.enable_telemetry: - context = FeatureFlagsContextFactory.get_instance(connection) - flag_value = context.get_flag_value( - TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False - ) - return str(flag_value).lower() == "true" - else: + # Fast path: disabled - no need to check feature flag + if not connection.enable_telemetry: return False + # Only fetch feature flags when enable_telemetry=True and not forced + context = FeatureFlagsContextFactory.get_instance(connection) + flag_value = context.get_flag_value( + TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False + ) + return str(flag_value).lower() == "true" + class NoopTelemetryClient(BaseTelemetryClient): """ @@ -165,23 +176,26 @@ class TelemetryClient(BaseTelemetryClient): def __init__( self, - telemetry_enabled, - session_id_hex, + telemetry_enabled: bool, + session_id_hex: str, auth_provider, - host_url, + host_url: str, executor, - batch_size, + batch_size: int, client_context, - ): + ) -> None: logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled self._batch_size = batch_size self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None - self._events_batch = [] - self._lock = threading.RLock() self._pending_futures = set() + + # OPTIMIZATION: Use lock-free Queue instead of list + lock + # Queue is thread-safe internally and has better performance under concurrency + self._events_queue: Queue[TelemetryFrontendLog] = Queue(maxsize=batch_size * 2) + self._driver_connection_params = None self._host_url = host_url self._executor = executor @@ -189,12 +203,41 @@ def __init__( # Create own HTTP client from client context self._http_client = UnifiedHttpClient(client_context) + # Create telemetry push client based on circuit breaker enabled flag + if client_context.telemetry_circuit_breaker_enabled: + # Create circuit breaker telemetry push client + # (circuit breakers created on-demand) + self._telemetry_push_client: ITelemetryPushClient = ( + CircuitBreakerTelemetryPushClient( + TelemetryPushClient(self._http_client), + host_url, + ) + ) + else: + # Circuit breaker disabled - use direct telemetry push client + self._telemetry_push_client = TelemetryPushClient(self._http_client) + def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) - with self._lock: - self._events_batch.append(event) - if len(self._events_batch) >= self._batch_size: + + # OPTIMIZATION: Use non-blocking put with queue + # No explicit lock needed - Queue is thread-safe internally + try: + self._events_queue.put_nowait(event) + except Full: + # Queue is full, trigger immediate flush + logger.debug("Event queue full, triggering flush") + self._flush() + # Try again after flush + try: + self._events_queue.put_nowait(event) + except Full: + # Still full, drop event (acceptable for telemetry) + logger.debug("Dropped telemetry event - queue still full") + + # Check if we should flush based on queue size + if self._events_queue.qsize() >= self._batch_size: logger.debug( "Batch size limit reached (%s), flushing events", self._batch_size ) @@ -202,9 +245,16 @@ def _export_event(self, event): def _flush(self): """Flush the current batch of 
events to the server""" - with self._lock: - events_to_flush = self._events_batch.copy() - self._events_batch = [] + # OPTIMIZATION: Drain queue without locks + # Collect all events currently in the queue + events_to_flush = [] + while not self._events_queue.empty(): + try: + event = self._events_queue.get_nowait() + events_to_flush.append(event) + except: + # Queue is empty + break if events_to_flush: logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) @@ -257,7 +307,7 @@ def _send_telemetry(self, events): def _send_with_unified_client(self, url, data, headers, timeout=900): """Helper method to send telemetry using the unified HTTP client.""" try: - response = self._http_client.request( + response = self._telemetry_push_client.request( HttpMethod.POST, url, body=data, headers=headers, timeout=timeout ) return response @@ -393,9 +443,9 @@ class TelemetryClientFactory: It uses a thread pool to handle asynchronous operations and a single flush thread for all clients. """ - _clients: Dict[ - str, BaseTelemetryClient - ] = {} # Map of session_id_hex -> BaseTelemetryClient + _clients: Dict[str, BaseTelemetryClient] = ( + {} + ) # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.RLock() # Thread safety for factory operations @@ -406,7 +456,7 @@ class TelemetryClientFactory: # Shared flush thread for all clients _flush_thread = None _flush_event = threading.Event() - _flush_interval_seconds = 90 + _flush_interval_seconds = 300 # 5 minutes DEFAULT_BATCH_SIZE = 100 @@ -496,21 +546,21 @@ def initialize_telemetry_client( session_id_hex, ) if telemetry_enabled: - TelemetryClientFactory._clients[ - session_id_hex - ] = TelemetryClient( - telemetry_enabled=telemetry_enabled, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - executor=TelemetryClientFactory._executor, - batch_size=batch_size, - client_context=client_context, + TelemetryClientFactory._clients[session_id_hex] = ( + TelemetryClient( + telemetry_enabled=telemetry_enabled, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + batch_size=batch_size, + client_context=client_context, + ) ) else: - TelemetryClientFactory._clients[ - session_id_hex - ] = NoopTelemetryClient() + TelemetryClientFactory._clients[session_id_hex] = ( + NoopTelemetryClient() + ) except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail diff --git a/src/databricks/sql/telemetry/telemetry_push_client.py b/src/databricks/sql/telemetry/telemetry_push_client.py new file mode 100644 index 000000000..461a57738 --- /dev/null +++ b/src/databricks/sql/telemetry/telemetry_push_client.py @@ -0,0 +1,201 @@ +""" +Telemetry push client interface and implementations. + +This module provides an interface for telemetry push clients with two implementations: +1. TelemetryPushClient - Direct HTTP client implementation +2. 
CircuitBreakerTelemetryPushClient - Circuit breaker wrapper implementation +""" + +import logging +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional + +try: + from urllib3 import BaseHTTPResponse +except ImportError: + from urllib3 import HTTPResponse as BaseHTTPResponse +from pybreaker import CircuitBreakerError + +from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import ( + TelemetryRateLimitError, + TelemetryNonRateLimitError, + RequestError, +) +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + +logger = logging.getLogger(__name__) + + +class ITelemetryPushClient(ABC): + """Interface for telemetry push clients.""" + + @abstractmethod + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """Make an HTTP request.""" + pass + + +class TelemetryPushClient(ITelemetryPushClient): + """Direct HTTP client implementation for telemetry requests.""" + + def __init__(self, http_client: UnifiedHttpClient): + """ + Initialize the telemetry push client. + + Args: + http_client: The underlying HTTP client + """ + self._http_client = http_client + logger.debug("TelemetryPushClient initialized") + + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """Make an HTTP request using the underlying HTTP client.""" + return self._http_client.request(method, url, headers, **kwargs) + + +class CircuitBreakerTelemetryPushClient(ITelemetryPushClient): + """Circuit breaker wrapper implementation for telemetry requests.""" + + def __init__(self, delegate: ITelemetryPushClient, host: str): + """ + Initialize the circuit breaker telemetry push client. + + Args: + delegate: The underlying telemetry push client to wrap + host: The hostname for circuit breaker identification + """ + self._delegate = delegate + self._host = host + + # Get circuit breaker for this host (creates if doesn't exist) + self._circuit_breaker = CircuitBreakerManager.get_circuit_breaker(host) + + logger.debug( + "CircuitBreakerTelemetryPushClient initialized for host %s", + host, + ) + + def _make_request_and_check_status( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]], + **kwargs, + ) -> BaseHTTPResponse: + """ + Make the request and check response status. + + Raises TelemetryRateLimitError for 429/503 (circuit breaker counts these). + Wraps other errors in TelemetryNonRateLimitError (circuit breaker excludes these). 
+ + Args: + method: HTTP method + url: Request URL + headers: Request headers + **kwargs: Additional request parameters + + Returns: + HTTP response + + Raises: + TelemetryRateLimitError: For 429/503 status codes (circuit breaker counts) + TelemetryNonRateLimitError: For other errors (circuit breaker excludes) + """ + try: + response = self._delegate.request(method, url, headers, **kwargs) + + # Check for rate limiting or service unavailable + if response.status in [429, 503]: + logger.warning( + "Telemetry endpoint returned %d for host %s, triggering circuit breaker", + response.status, + self._host, + ) + raise TelemetryRateLimitError( + f"Telemetry endpoint rate limited or unavailable: {response.status}" + ) + + return response + + except Exception as e: + # Don't catch TelemetryRateLimitError - let it propagate to circuit breaker + if isinstance(e, TelemetryRateLimitError): + raise + + # Check if it's a RequestError with rate limiting status code (exhausted retries) + if isinstance(e, RequestError): + http_code = ( + e.context.get("http-code") + if hasattr(e, "context") and e.context + else None + ) + + if http_code in [429, 503]: + logger.debug( + "Telemetry retries exhausted with status %d for host %s, triggering circuit breaker", + http_code, + self._host, + ) + raise TelemetryRateLimitError( + f"Telemetry rate limited after retries: {http_code}" + ) + + # NOT rate limiting (500 errors, network errors, timeouts, etc.) + # Wrap in TelemetryNonRateLimitError so circuit breaker excludes it + logger.debug( + "Non-rate-limit telemetry error for host %s: %s, wrapping to exclude from circuit breaker", + self._host, + e, + ) + raise TelemetryNonRateLimitError(e) from e + + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """ + Make an HTTP request with circuit breaker protection. + + Circuit breaker only opens for TelemetryRateLimitError (429/503 responses). + Other errors are wrapped in TelemetryNonRateLimitError and excluded from circuit breaker. + All exceptions propagate to caller (TelemetryClient callback handles them). 
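+
+        Example (illustrative sketch, mirroring how TelemetryClient wires this
+        client up; http_client, host_url, url and headers are placeholders):
+
+            push_client = CircuitBreakerTelemetryPushClient(
+                TelemetryPushClient(http_client), host_url
+            )
+            response = push_client.request(HttpMethod.POST, url, headers=headers)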
+ """ + try: + # Use circuit breaker to protect the request + # TelemetryRateLimitError will trigger circuit breaker + # TelemetryNonRateLimitError is excluded from circuit breaker + return self._circuit_breaker.call( + self._make_request_and_check_status, + method, + url, + headers, + **kwargs, + ) + + except TelemetryNonRateLimitError as e: + # Unwrap and re-raise original exception + # Circuit breaker didn't count this, but caller should handle it + logger.debug( + "Non-rate-limit telemetry error for host %s, re-raising original: %s", + self._host, + e.original_exception, + ) + raise e.original_exception from e + # All other exceptions (TelemetryRateLimitError, CircuitBreakerError) propagate as-is diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 9f96e8743..b46784b10 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -922,4 +922,7 @@ def build_client_context(server_hostname: str, version: str, **kwargs): proxy_auth_method=kwargs.get("_proxy_auth_method"), pool_connections=kwargs.get("_pool_connections"), pool_maxsize=kwargs.get("_pool_maxsize"), + telemetry_circuit_breaker_enabled=kwargs.get( + "_telemetry_circuit_breaker_enabled" + ), ) diff --git a/tests/e2e/test_circuit_breaker.py b/tests/e2e/test_circuit_breaker.py new file mode 100644 index 000000000..45c494d19 --- /dev/null +++ b/tests/e2e/test_circuit_breaker.py @@ -0,0 +1,232 @@ +""" +E2E tests for circuit breaker functionality in telemetry. + +This test suite verifies: +1. Circuit breaker opens after rate limit failures (429/503) +2. Circuit breaker blocks subsequent calls while open +3. Circuit breaker does not trigger for non-rate-limit errors +4. Circuit breaker can be disabled via configuration flag +5. Circuit breaker closes after reset timeout + +Run with: + pytest tests/e2e/test_circuit_breaker.py -v -s +""" + +import time +from unittest.mock import patch, MagicMock + +import pytest +from pybreaker import STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN +from urllib3 import HTTPResponse + +import databricks.sql as sql +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + + +@pytest.fixture(autouse=True) +def aggressive_circuit_breaker_config(): + """ + Configure circuit breaker to be aggressive for faster testing. + Opens after 2 failures instead of 20, with 5 second timeout. 
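+
+    Patches the MINIMUM_CALLS and RESET_TIMEOUT module constants and clears the
+    cached CircuitBreakerManager instances; the original values are restored on
+    teardown.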
+ """ + from databricks.sql.telemetry import circuit_breaker_manager + + original_minimum_calls = circuit_breaker_manager.MINIMUM_CALLS + original_reset_timeout = circuit_breaker_manager.RESET_TIMEOUT + + circuit_breaker_manager.MINIMUM_CALLS = 2 + circuit_breaker_manager.RESET_TIMEOUT = 5 + + CircuitBreakerManager._instances.clear() + + yield + + circuit_breaker_manager.MINIMUM_CALLS = original_minimum_calls + circuit_breaker_manager.RESET_TIMEOUT = original_reset_timeout + CircuitBreakerManager._instances.clear() + + +class TestCircuitBreakerTelemetry: + """Tests for circuit breaker functionality with telemetry""" + + @pytest.fixture(autouse=True) + def get_details(self, connection_details): + """Get connection details from pytest fixture""" + self.arguments = connection_details.copy() + + def create_mock_response(self, status_code): + """Helper to create mock HTTP response.""" + response = MagicMock(spec=HTTPResponse) + response.status = status_code + response.data = { + 429: b"Too Many Requests", + 503: b"Service Unavailable", + 500: b"Internal Server Error", + }.get(status_code, b"Response") + return response + + @pytest.mark.parametrize("status_code,should_trigger", [ + (429, True), + (503, True), + (500, False), + ]) + def test_circuit_breaker_triggers_for_rate_limit_codes(self, status_code, should_trigger): + """ + Verify circuit breaker opens for rate-limit codes (429/503) but not others (500). + """ + request_count = {"count": 0} + + def mock_request(*args, **kwargs): + request_count["count"] += 1 + return self.create_mock_response(status_code) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=True, + ) as conn: + circuit_breaker = CircuitBreakerManager.get_circuit_breaker( + self.arguments["host"] + ) + + assert circuit_breaker.current_state == STATE_CLOSED + + cursor = conn.cursor() + + # Execute queries to trigger telemetry + for i in range(1, 6): + cursor.execute(f"SELECT {i}") + cursor.fetchone() + time.sleep(0.5) + + if should_trigger: + # Circuit should be OPEN after 2 rate-limit failures + assert circuit_breaker.current_state == STATE_OPEN + assert circuit_breaker.fail_counter == 2 + + # Track requests before another query + requests_before = request_count["count"] + cursor.execute("SELECT 99") + cursor.fetchone() + time.sleep(1) + + # No new telemetry requests (circuit is open) + assert request_count["count"] == requests_before + else: + # Circuit should remain CLOSED for non-rate-limit errors + assert circuit_breaker.current_state == STATE_CLOSED + assert circuit_breaker.fail_counter == 0 + assert request_count["count"] >= 5 + + def test_circuit_breaker_disabled_allows_all_calls(self): + """ + Verify that when circuit breaker is disabled, all calls go through + even with rate limit errors. 
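+
+        With _telemetry_circuit_breaker_enabled=False the TelemetryClient uses
+        TelemetryPushClient directly (no circuit breaker wrapper), so repeated
+        429 responses never block subsequent telemetry requests.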
+ """ + request_count = {"count": 0} + + def mock_rate_limited_request(*args, **kwargs): + request_count["count"] += 1 + return self.create_mock_response(429) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_rate_limited_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=False, # Disabled + ) as conn: + cursor = conn.cursor() + + for i in range(5): + cursor.execute(f"SELECT {i}") + cursor.fetchone() + time.sleep(0.3) + + assert request_count["count"] >= 5 + + def test_circuit_breaker_recovers_after_reset_timeout(self): + """ + Verify circuit breaker transitions to HALF_OPEN after reset timeout + and eventually CLOSES if requests succeed. + """ + request_count = {"count": 0} + fail_requests = {"enabled": True} + + def mock_conditional_request(*args, **kwargs): + request_count["count"] += 1 + status = 429 if fail_requests["enabled"] else 200 + return self.create_mock_response(status) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_conditional_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=True, + ) as conn: + circuit_breaker = CircuitBreakerManager.get_circuit_breaker( + self.arguments["host"] + ) + + cursor = conn.cursor() + + # Trigger failures to open circuit + cursor.execute("SELECT 1") + cursor.fetchone() + time.sleep(1) + + cursor.execute("SELECT 2") + cursor.fetchone() + time.sleep(2) + + assert circuit_breaker.current_state == STATE_OPEN + + # Wait for reset timeout (5 seconds in test) + time.sleep(6) + + # Now make requests succeed + fail_requests["enabled"] = False + + # Execute query to trigger HALF_OPEN state + cursor.execute("SELECT 3") + cursor.fetchone() + time.sleep(1) + + # Circuit should be recovering + assert circuit_breaker.current_state in [ + STATE_HALF_OPEN, + STATE_CLOSED, + ], f"Circuit should be recovering, but is {circuit_breaker.current_state}" + + # Execute more queries to fully recover + cursor.execute("SELECT 4") + cursor.fetchone() + time.sleep(1) + + current_state = circuit_breaker.current_state + assert current_state in [ + STATE_CLOSED, + STATE_HALF_OPEN, + ], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/e2e/test_telemetry_e2e.py b/tests/e2e/test_telemetry_e2e.py new file mode 100644 index 000000000..917c8e5eb --- /dev/null +++ b/tests/e2e/test_telemetry_e2e.py @@ -0,0 +1,343 @@ +""" +E2E test for telemetry - verifies telemetry behavior with different scenarios +""" +import time +import threading +import logging +from contextlib import contextmanager +from unittest.mock import patch +import pytest +from concurrent.futures import wait + +import databricks.sql as sql +from databricks.sql.telemetry.telemetry_client import ( + TelemetryClient, + TelemetryClientFactory, +) + +log = logging.getLogger(__name__) + + +class TelemetryTestBase: + """Simplified test base class for telemetry e2e tests""" + + @pytest.fixture(autouse=True) + def get_details(self, connection_details): + self.arguments = 
connection_details.copy() + + def connection_params(self): + return { + "server_hostname": self.arguments["host"], + "http_path": self.arguments["http_path"], + "access_token": self.arguments.get("access_token"), + } + + @contextmanager + def connection(self, extra_params=()): + connection_params = dict(self.connection_params(), **dict(extra_params)) + log.info("Connecting with args: {}".format(connection_params)) + conn = sql.connect(**connection_params) + try: + yield conn + finally: + conn.close() + + +class TestTelemetryE2E(TelemetryTestBase): + """E2E tests for telemetry scenarios""" + + @pytest.fixture(autouse=True) + def telemetry_setup_teardown(self): + """Clean up telemetry client state before and after each test""" + try: + yield + finally: + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._stop_flush_thread() + TelemetryClientFactory._initialized = False + + @pytest.fixture + def telemetry_interceptors(self): + """Setup reusable telemetry interceptors as a fixture""" + capture_lock = threading.Lock() + captured_events = [] + captured_futures = [] + + original_export = TelemetryClient._export_event + original_callback = TelemetryClient._telemetry_request_callback + + def export_wrapper(self_client, event): + with capture_lock: + captured_events.append(event) + return original_export(self_client, event) + + def callback_wrapper(self_client, future, sent_count): + with capture_lock: + captured_futures.append(future) + original_callback(self_client, future, sent_count) + + return captured_events, captured_futures, export_wrapper, callback_wrapper + + # ==================== ASSERTION HELPERS ==================== + + def assert_system_config(self, event): + """Assert system configuration fields""" + sys_config = event.entry.sql_driver_log.system_configuration + assert sys_config is not None + + # Check all required fields are non-empty + for field in ['driver_name', 'driver_version', 'os_name', 'os_version', + 'os_arch', 'runtime_name', 'runtime_version', 'runtime_vendor', + 'locale_name', 'char_set_encoding']: + value = getattr(sys_config, field) + assert value and len(value) > 0, f"{field} should not be None or empty" + + assert sys_config.driver_name == "Databricks SQL Python Connector" + + def assert_connection_params(self, event, expected_http_path=None): + """Assert connection parameters""" + conn_params = event.entry.sql_driver_log.driver_connection_params + assert conn_params is not None + assert conn_params.http_path + assert conn_params.host_info is not None + assert conn_params.auth_mech is not None + + if expected_http_path: + assert conn_params.http_path == expected_http_path + + if conn_params.socket_timeout is not None: + assert conn_params.socket_timeout > 0 + + def assert_statement_execution(self, event): + """Assert statement execution details""" + sql_op = event.entry.sql_driver_log.sql_operation + assert sql_op is not None + assert sql_op.statement_type is not None + assert sql_op.execution_result is not None + assert hasattr(sql_op, "retry_count") + + if sql_op.retry_count is not None: + assert sql_op.retry_count >= 0 + + latency = event.entry.sql_driver_log.operation_latency_ms + assert latency is not None and latency >= 0 + + def assert_error_info(self, event, expected_error_name=None): + """Assert error information""" + error_info = event.entry.sql_driver_log.error_info + assert error_info is not None + assert error_info.error_name and 
len(error_info.error_name) > 0 + assert error_info.stack_trace and len(error_info.stack_trace) > 0 + + if expected_error_name: + assert error_info.error_name == expected_error_name + + def verify_events(self, captured_events, captured_futures, expected_count): + """Common verification for event count and HTTP responses""" + if expected_count == 0: + assert len(captured_events) == 0, f"Expected 0 events, got {len(captured_events)}" + assert len(captured_futures) == 0, f"Expected 0 responses, got {len(captured_futures)}" + else: + assert len(captured_events) == expected_count, \ + f"Expected {expected_count} events, got {len(captured_events)}" + + time.sleep(2) + done, _ = wait(captured_futures, timeout=10) + assert len(done) == expected_count, \ + f"Expected {expected_count} responses, got {len(done)}" + + for future in done: + response = future.result() + assert 200 <= response.status < 300 + + # Assert common fields for all events + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + # ==================== PARAMETERIZED TESTS ==================== + + @pytest.mark.parametrize("enable_telemetry,force_enable,expected_count,test_id", [ + (True, False, 2, "enable_on_force_off"), + (False, True, 2, "enable_off_force_on"), + (False, False, 0, "both_off"), + (None, None, 0, "default_behavior"), + ]) + def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry, + force_enable, expected_count, test_id): + """Test telemetry behavior with different flag combinations""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + extra_params = {"telemetry_batch_size": 1} + if enable_telemetry is not None: + extra_params["enable_telemetry"] = enable_telemetry + if force_enable is not None: + extra_params["force_enable_telemetry"] = force_enable + + with self.connection(extra_params=extra_params) as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT 1") + cursor.fetchone() + + self.verify_events(captured_events, captured_futures, expected_count) + + # Assert statement execution on latency event (if events exist) + if expected_count > 0: + self.assert_statement_execution(captured_events[-1]) + + @pytest.mark.parametrize("query,expected_error", [ + ("SELECT * FROM WHERE INVALID SYNTAX 12345", "ServerOperationError"), + ("SELECT * FROM non_existent_table_xyz_12345", None), + ]) + def test_sql_errors(self, telemetry_interceptors, query, expected_error): + """Test telemetry captures error information for different SQL errors""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + }) as conn: + with conn.cursor() as cursor: + with pytest.raises(Exception): + cursor.execute(query) + cursor.fetchone() + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 1 + + # Find event with error_info + error_event = next((e for e in captured_events + if e.entry.sql_driver_log.error_info), None) + assert error_event is not None + + self.assert_system_config(error_event) + 
self.assert_connection_params(error_event, self.arguments["http_path"]) + self.assert_error_info(error_event, expected_error) + + def test_metadata_operation(self, telemetry_interceptors): + """Test telemetry for metadata operations (getCatalogs)""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + }) as conn: + with conn.cursor() as cursor: + catalogs = cursor.catalogs() + catalogs.fetchall() + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 1 + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + def test_direct_results(self, telemetry_interceptors): + """Test telemetry with direct results (use_cloud_fetch=False)""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + "use_cloud_fetch": False, + }) as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT 100") + result = cursor.fetchall() + assert len(result) == 1 and result[0][0] == 100 + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 2 + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + self.assert_statement_execution(captured_events[-1]) + + @pytest.mark.parametrize("close_type", [ + "context_manager", + "explicit_cursor", + "explicit_connection", + "implicit_fetchall", + ]) + def test_cloudfetch_with_different_close_patterns(self, telemetry_interceptors, + close_type): + """Test telemetry with cloud fetch using different resource closing patterns""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + if close_type == "explicit_connection": + # Test explicit connection close + conn = sql.connect( + **self.connection_params(), + force_enable_telemetry=True, + telemetry_batch_size=1, + use_cloud_fetch=True, + ) + cursor = conn.cursor() + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + conn.close() + else: + # Other patterns use connection context manager + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + "use_cloud_fetch": True, + }) as conn: + if close_type == "context_manager": + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + + elif close_type == "explicit_cursor": + cursor = conn.cursor() + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + cursor.close() + + elif close_type == "implicit_fetchall": + cursor = conn.cursor() + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 
1000 + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 2 + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + self.assert_statement_execution(captured_events[-1]) diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py new file mode 100644 index 000000000..d4f6a790a --- /dev/null +++ b/tests/e2e/test_transactions.py @@ -0,0 +1,598 @@ +""" +End-to-end integration tests for Multi-Statement Transaction (MST) APIs. + +These tests verify: +- autocommit property (getter/setter) +- commit() and rollback() methods +- get_transaction_isolation() and set_transaction_isolation() methods +- Transaction error handling + +Requirements: +- DBSQL warehouse that supports Multi-Statement Transactions (MST) +- Test environment configured via test.env file or environment variables + +Setup: +Set the following environment variables: +- DATABRICKS_SERVER_HOSTNAME +- DATABRICKS_HTTP_PATH +- DATABRICKS_ACCESS_TOKEN (or use OAuth) + +Usage: + pytest tests/e2e/test_transactions.py -v +""" + +import logging +import os +import pytest +from typing import Any, Dict + +import databricks.sql as sql +from databricks.sql import TransactionError, NotSupportedError, InterfaceError + +logger = logging.getLogger(__name__) + + +@pytest.mark.skip( + reason="Test environment does not yet support multi-statement transactions" +) +class TestTransactions: + """E2E tests for transaction control methods (MST support).""" + + # Test table name + TEST_TABLE_NAME = "transaction_test_table" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self, connection_details): + """Setup test environment before each test and cleanup after.""" + self.connection_params = { + "server_hostname": connection_details["host"], + "http_path": connection_details["http_path"], + "access_token": connection_details.get("access_token"), + "ignore_transactions": False, # Enable actual transaction functionality for these tests + } + + # Get catalog and schema from environment or use defaults + self.catalog = os.getenv("DATABRICKS_CATALOG", "main") + self.schema = os.getenv("DATABRICKS_SCHEMA", "default") + + # Create connection for setup + self.connection = sql.connect(**self.connection_params) + + # Setup: Create test table + self._create_test_table() + + yield + + # Teardown: Cleanup + self._cleanup() + + def _get_fully_qualified_table_name(self) -> str: + """Get the fully qualified table name.""" + return f"{self.catalog}.{self.schema}.{self.TEST_TABLE_NAME}" + + def _create_test_table(self): + """Create the test table with Delta format and MST support.""" + fq_table_name = self._get_fully_qualified_table_name() + cursor = self.connection.cursor() + + try: + # Drop if exists + cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") + + # Create table with Delta and catalog-owned feature for MST compatibility + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table_name} + (id INT, value STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + + logger.info(f"Created test table: {fq_table_name}") + finally: + cursor.close() + + def _cleanup(self): + """Cleanup after test: rollback pending transactions, drop table, close connection.""" + try: + # Try to rollback any pending transaction + if ( + self.connection + and self.connection.open + and not self.connection.autocommit + ): + try: + self.connection.rollback() + except Exception as e: + logger.debug( + 
f"Rollback during cleanup failed (may be expected): {e}" + ) + + # Reset to autocommit mode + try: + self.connection.autocommit = True + except Exception as e: + logger.debug(f"Reset autocommit during cleanup failed: {e}") + + # Drop test table + if self.connection and self.connection.open: + fq_table_name = self._get_fully_qualified_table_name() + cursor = self.connection.cursor() + try: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") + logger.info(f"Dropped test table: {fq_table_name}") + except Exception as e: + logger.warning(f"Failed to drop test table: {e}") + finally: + cursor.close() + + finally: + # Close connection + if self.connection: + self.connection.close() + + # ==================== BASIC AUTOCOMMIT TESTS ==================== + + def test_default_autocommit_is_true(self): + """Test that new connection defaults to autocommit=true.""" + assert ( + self.connection.autocommit is True + ), "New connection should have autocommit=true by default" + + def test_set_autocommit_to_false(self): + """Test successfully setting autocommit to false.""" + self.connection.autocommit = False + assert ( + self.connection.autocommit is False + ), "autocommit should be false after setting to false" + + def test_set_autocommit_to_true(self): + """Test successfully setting autocommit back to true.""" + # First disable + self.connection.autocommit = False + assert self.connection.autocommit is False + + # Then enable + self.connection.autocommit = True + assert ( + self.connection.autocommit is True + ), "autocommit should be true after setting to true" + + # ==================== COMMIT TESTS ==================== + + def test_commit_single_insert(self): + """Test successfully committing a transaction with single INSERT.""" + fq_table_name = self._get_fully_qualified_table_name() + + # Start transaction + self.connection.autocommit = False + + # Insert data + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'test_value')" + ) + cursor.close() + + # Commit + self.connection.commit() + + # Verify data is persisted using a new connection + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result is not None, "Should find inserted row after commit" + assert result[0] == "test_value", "Value should match inserted value" + finally: + verify_conn.close() + + def test_commit_multiple_inserts(self): + """Test successfully committing a transaction with multiple INSERTs.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # Insert multiple rows + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'value1')") + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'value2')") + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'value3')") + cursor.close() + + self.connection.commit() + + # Verify all rows persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name}") + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result[0] == 3, "Should have 3 rows after commit" + finally: + verify_conn.close() + + # ==================== ROLLBACK TESTS ==================== + + def 
test_rollback_single_insert(self): + """Test successfully rolling back a transaction.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # Insert data + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (100, 'rollback_test')" + ) + cursor.close() + + # Rollback + self.connection.rollback() + + # Verify data is NOT persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 100" + ) + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result[0] == 0, "Rolled back data should not be persisted" + finally: + verify_conn.close() + + # ==================== SEQUENTIAL TRANSACTION TESTS ==================== + + def test_multiple_sequential_transactions(self): + """Test executing multiple sequential transactions (commit, commit, rollback).""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'txn1')") + cursor.close() + self.connection.commit() + + # Second transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'txn2')") + cursor.close() + self.connection.commit() + + # Third transaction - rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'txn3')") + cursor.close() + self.connection.rollback() + + # Verify only first two transactions persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table_name} WHERE id IN (1, 2)" + ) + result = verify_cursor.fetchone() + assert result[0] == 2, "Should have 2 committed rows" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 3") + result = verify_cursor.fetchone() + assert result[0] == 0, "Rolled back row should not exist" + verify_cursor.close() + finally: + verify_conn.close() + + def test_auto_start_transaction_after_commit(self): + """Test that new transaction automatically starts after commit.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") + cursor.close() + self.connection.commit() + + # New transaction should start automatically - insert and rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") + cursor.close() + self.connection.rollback() + + # Verify: first committed, second rolled back + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == 1, "First insert should be committed" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") + result = verify_cursor.fetchone() + assert result[0] == 0, "Second insert should be rolled back" + verify_cursor.close() + finally: + verify_conn.close() + + def test_auto_start_transaction_after_rollback(self): + """Test that new transaction automatically 
starts after rollback.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") + cursor.close() + self.connection.rollback() + + # New transaction should start automatically - insert and commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") + cursor.close() + self.connection.commit() + + # Verify: first rolled back, second committed + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == 0, "First insert should be rolled back" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") + result = verify_cursor.fetchone() + assert result[0] == 1, "Second insert should be committed" + verify_cursor.close() + finally: + verify_conn.close() + + # ==================== UPDATE/DELETE OPERATION TESTS ==================== + + def test_update_in_transaction(self): + """Test UPDATE operation in transaction.""" + fq_table_name = self._get_fully_qualified_table_name() + + # First insert a row with autocommit + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'original')" + ) + cursor.close() + + # Start transaction and update + self.connection.autocommit = False + cursor = self.connection.cursor() + cursor.execute(f"UPDATE {fq_table_name} SET value = 'updated' WHERE id = 1") + cursor.close() + self.connection.commit() + + # Verify update persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == "updated", "Value should be updated after commit" + verify_cursor.close() + finally: + verify_conn.close() + + # ==================== MULTI-TABLE TRANSACTION TESTS ==================== + + def test_multi_table_transaction_commit(self): + """Test atomic commit across multiple tables.""" + fq_table1_name = self._get_fully_qualified_table_name() + table2_name = self.TEST_TABLE_NAME + "_2" + fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" + + # Create second table + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table2_name} + (id INT, category STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + cursor.close() + + try: + # Start transaction and insert into both tables + self.connection.autocommit = False + + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table1_name} (id, value) VALUES (10, 'table1_data')" + ) + cursor.execute( + f"INSERT INTO {fq_table2_name} (id, category) VALUES (10, 'table2_data')" + ) + cursor.close() + + # Commit both atomically + self.connection.commit() + + # Verify both inserts persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 10" + ) + result = verify_cursor.fetchone() + assert result[0] == 1, "Table1 insert should be committed" + + verify_cursor.execute( + f"SELECT 
COUNT(*) FROM {fq_table2_name} WHERE id = 10" + ) + result = verify_cursor.fetchone() + assert result[0] == 1, "Table2 insert should be committed" + + verify_cursor.close() + finally: + verify_conn.close() + + finally: + # Cleanup second table + self.connection.autocommit = True + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.close() + + def test_multi_table_transaction_rollback(self): + """Test atomic rollback across multiple tables.""" + fq_table1_name = self._get_fully_qualified_table_name() + table2_name = self.TEST_TABLE_NAME + "_2" + fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" + + # Create second table + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table2_name} + (id INT, category STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + cursor.close() + + try: + # Start transaction and insert into both tables + self.connection.autocommit = False + + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table1_name} (id, value) VALUES (20, 'rollback1')" + ) + cursor.execute( + f"INSERT INTO {fq_table2_name} (id, category) VALUES (20, 'rollback2')" + ) + cursor.close() + + # Rollback both atomically + self.connection.rollback() + + # Verify both inserts were rolled back + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 20" + ) + result = verify_cursor.fetchone() + assert result[0] == 0, "Table1 insert should be rolled back" + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 20" + ) + result = verify_cursor.fetchone() + assert result[0] == 0, "Table2 insert should be rolled back" + + verify_cursor.close() + finally: + verify_conn.close() + + finally: + # Cleanup second table + self.connection.autocommit = True + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.close() + + # ==================== ERROR HANDLING TESTS ==================== + + def test_set_autocommit_during_active_transaction(self): + """Test that setting autocommit during an active transaction throws error.""" + fq_table_name = self._get_fully_qualified_table_name() + + # Start transaction + self.connection.autocommit = False + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (99, 'test')") + cursor.close() + + # Try to set autocommit=True during active transaction + with pytest.raises(TransactionError) as exc_info: + self.connection.autocommit = True + + # Verify error message mentions autocommit or active transaction + error_msg = str(exc_info.value).lower() + assert ( + "autocommit" in error_msg or "active transaction" in error_msg + ), "Error should mention autocommit or active transaction" + + # Cleanup - rollback the transaction + self.connection.rollback() + + def test_commit_without_active_transaction_throws_error(self): + """Test that commit() throws error when autocommit=true (no active transaction).""" + # Ensure autocommit is true (default) + assert self.connection.autocommit is True + + # Attempt commit without active transaction should throw + with pytest.raises(TransactionError) as exc_info: + self.connection.commit() + + # Verify error message indicates no active transaction + error_message = 
str(exc_info.value) + assert ( + "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION" in error_message + or "no active transaction" in error_message.lower() + ), "Error should indicate no active transaction" + + def test_rollback_without_active_transaction_is_safe(self): + """Test that rollback() without active transaction is a safe no-op.""" + # With autocommit=true (no active transaction) + assert self.connection.autocommit is True + + # ROLLBACK should be safe (no exception) + self.connection.rollback() + + # Verify connection is still usable + assert self.connection.autocommit is True + assert self.connection.open is True + + # ==================== TRANSACTION ISOLATION TESTS ==================== + + def test_get_transaction_isolation_returns_repeatable_read(self): + """Test that get_transaction_isolation() returns REPEATABLE_READ.""" + isolation_level = self.connection.get_transaction_isolation() + assert ( + isolation_level == "REPEATABLE_READ" + ), "Databricks MST should use REPEATABLE_READ (Snapshot Isolation)" + + def test_set_transaction_isolation_accepts_repeatable_read(self): + """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" + # Should not raise - these are all valid formats + self.connection.set_transaction_isolation("REPEATABLE_READ") + self.connection.set_transaction_isolation("REPEATABLE READ") + self.connection.set_transaction_isolation("repeatable_read") + self.connection.set_transaction_isolation("repeatable read") + + def test_set_transaction_isolation_rejects_unsupported_level(self): + """Test that set_transaction_isolation() rejects unsupported levels.""" + with pytest.raises(NotSupportedError) as exc_info: + self.connection.set_transaction_isolation("READ_COMMITTED") + + error_message = str(exc_info.value) + assert "not supported" in error_message.lower() + assert "READ_COMMITTED" in error_message diff --git a/tests/unit/test_circuit_breaker_http_client.py b/tests/unit/test_circuit_breaker_http_client.py new file mode 100644 index 000000000..432ca1be3 --- /dev/null +++ b/tests/unit/test_circuit_breaker_http_client.py @@ -0,0 +1,208 @@ +""" +Unit tests for telemetry push client functionality. 
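+
+Covers TelemetryPushClient delegation to the underlying HTTP client and the
+error classification / circuit-opening behavior of CircuitBreakerTelemetryPushClient.
+
+Run with:
+    pytest tests/unit/test_circuit_breaker_http_client.py -v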
+""" + +import pytest +from unittest.mock import Mock, patch, MagicMock + +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from pybreaker import CircuitBreakerError + + +class TestTelemetryPushClient: + """Test cases for TelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_http_client = Mock() + self.client = TelemetryPushClient(self.mock_http_client) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._http_client == self.mock_http_client + + def test_request_delegates_to_http_client(self): + """Test that request delegates to underlying HTTP client.""" + mock_response = Mock() + self.mock_http_client.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_http_client.request.assert_called_once() + + def test_direct_client_has_no_circuit_breaker(self): + """Test that direct client does not have circuit breaker functionality.""" + # Direct client should work without circuit breaker + assert isinstance(self.client, TelemetryPushClient) + + +class TestCircuitBreakerTelemetryPushClient: + """Test cases for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock(spec=ITelemetryPushClient) + self.host = "test-host.example.com" + self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._delegate == self.mock_delegate + assert self.client._host == self.host + assert self.client._circuit_breaker is not None + + def test_request_enabled_success(self): + """Test successful request when circuit breaker is enabled.""" + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_enabled_circuit_breaker_error(self): + """Test request when circuit breaker is open - should raise CircuitBreakerError.""" + # Mock circuit breaker to raise CircuitBreakerError + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + # Circuit breaker open should raise (caller handles it) + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_enabled_other_error(self): + """Test request when other error occurs - should raise original exception.""" + # Mock delegate to raise a different error (not rate limiting) + self.mock_delegate.request.side_effect = ValueError("Network error") + + # Non-rate-limit errors are unwrapped and raised + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_is_circuit_breaker_enabled(self): + """Test checking if circuit breaker is enabled.""" + assert self.client._circuit_breaker is not None + + def test_circuit_breaker_state_logging(self): + """Test that circuit breaker errors are raised (no longer silent).""" + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + # Should raise CircuitBreakerError 
(caller handles it) + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_other_error_logging(self): + """Test that other errors are wrapped, logged, then unwrapped and raised.""" + with patch( + "databricks.sql.telemetry.telemetry_push_client.logger" + ) as mock_logger: + self.mock_delegate.request.side_effect = ValueError("Network error") + + # Should raise the original ValueError + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + # Check that debug was logged (for wrapping and/or unwrapping) + assert mock_logger.debug.call_count >= 1 + + +class TestCircuitBreakerTelemetryPushClientIntegration: + """Integration tests for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock() + self.host = "test-host.example.com" + + def test_circuit_breaker_opens_after_failures(self): + """Test that circuit breaker opens after repeated failures (429/503 errors).""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + ) + from databricks.sql.exc import TelemetryRateLimitError + + # Clear any existing state + CircuitBreakerManager._instances.clear() + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate rate limit failures (429) + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + # All calls should raise TelemetryRateLimitError + # After MINIMUM_CALLS failures, circuit breaker opens + rate_limit_error_count = 0 + circuit_breaker_error_count = 0 + + for i in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except TelemetryRateLimitError: + rate_limit_error_count += 1 + except CircuitBreakerError: + circuit_breaker_error_count += 1 + + # Should have some rate limit errors before circuit opens, then circuit breaker errors + assert rate_limit_error_count >= MINIMUM_CALLS - 1 + assert circuit_breaker_error_count > 0 + + def test_circuit_breaker_recovers_after_success(self): + """Test that circuit breaker recovers after successful calls.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + ) + import time + + # Clear any existing state + CircuitBreakerManager._instances.clear() + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate rate limit failures first (429) + from databricks.sql.exc import TelemetryRateLimitError + from pybreaker import CircuitBreakerError + + mock_rate_limit_response = Mock() + mock_rate_limit_response.status = 429 + self.mock_delegate.request.return_value = mock_rate_limit_response + + # Trigger enough rate limit failures to open circuit + for i in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except (TelemetryRateLimitError, CircuitBreakerError): + pass # Expected - circuit breaker opens after MINIMUM_CALLS failures + + # Circuit should be open now - raises CircuitBreakerError + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Simulate successful calls (200 response) + mock_success_response = Mock() + mock_success_response.status = 200 + self.mock_delegate.request.return_value = mock_success_response + + # 
Should work again with actual success response + response = client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 200 diff --git a/tests/unit/test_circuit_breaker_manager.py b/tests/unit/test_circuit_breaker_manager.py new file mode 100644 index 000000000..e8ed4e809 --- /dev/null +++ b/tests/unit/test_circuit_breaker_manager.py @@ -0,0 +1,160 @@ +""" +Unit tests for circuit breaker manager functionality. +""" + +import pytest +import threading +import time +from unittest.mock import Mock, patch + +from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + NAME_PREFIX as CIRCUIT_BREAKER_NAME, +) +from pybreaker import CircuitBreakerError + + +class TestCircuitBreakerManager: + """Test cases for CircuitBreakerManager.""" + + def setup_method(self): + """Set up test fixtures.""" + CircuitBreakerManager._instances.clear() + + def teardown_method(self): + """Clean up after tests.""" + CircuitBreakerManager._instances.clear() + + def test_get_circuit_breaker_creates_instance(self): + """Test getting circuit breaker creates instance with correct config.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker.name == "telemetry-circuit-breaker-test-host" + assert breaker.fail_max == MINIMUM_CALLS + + def test_get_circuit_breaker_same_host_returns_same_instance(self): + """Test that same host returns same circuit breaker instance.""" + breaker1 = CircuitBreakerManager.get_circuit_breaker("test-host") + breaker2 = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker1 is breaker2 + + def test_get_circuit_breaker_different_hosts_return_different_instances(self): + """Test that different hosts return different circuit breaker instances.""" + breaker1 = CircuitBreakerManager.get_circuit_breaker("host1") + breaker2 = CircuitBreakerManager.get_circuit_breaker("host2") + + assert breaker1 is not breaker2 + assert breaker1.name != breaker2.name + + def test_thread_safety(self): + """Test thread safety of circuit breaker manager.""" + results = [] + + def get_breaker(host): + breaker = CircuitBreakerManager.get_circuit_breaker(host) + results.append(breaker) + + threads = [] + for i in range(10): + thread = threading.Thread(target=get_breaker, args=(f"host{i % 3}",)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(results) == 10 + + # All breakers for same host should be same instance + host0_breakers = [b for b in results if b.name.endswith("host0")] + assert all(b is host0_breakers[0] for b in host0_breakers) + + +class TestCircuitBreakerIntegration: + """Integration tests for circuit breaker functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + CircuitBreakerManager._instances.clear() + + def teardown_method(self): + """Clean up after tests.""" + CircuitBreakerManager._instances.clear() + + def test_circuit_breaker_state_transitions(self): + """Test circuit breaker state transitions from closed to open.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker.current_state == "closed" + + def failing_func(): + raise Exception("Simulated failure") + + # Trigger failures up to the threshold (MINIMUM_CALLS = 20) + for _ in range(MINIMUM_CALLS): + with pytest.raises(Exception): + breaker.call(failing_func) + + # Next call should fail with CircuitBreakerError (circuit is now open) + with pytest.raises(CircuitBreakerError): + 
breaker.call(failing_func) + + assert breaker.current_state == "open" + + def test_circuit_breaker_recovery(self): + """Test circuit breaker recovery after failures.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + def failing_func(): + raise Exception("Simulated failure") + + # Trigger failures up to the threshold + for _ in range(MINIMUM_CALLS): + with pytest.raises(Exception): + breaker.call(failing_func) + + assert breaker.current_state == "open" + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Try successful call to close circuit breaker + def successful_func(): + return "success" + + try: + result = breaker.call(successful_func) + assert result == "success" + except CircuitBreakerError: + pass # Circuit might still be open, acceptable + + assert breaker.current_state in ["closed", "half-open", "open"] + + @pytest.mark.parametrize("old_state,new_state", [ + ("closed", "open"), + ("open", "half-open"), + ("half-open", "closed"), + ("closed", "half-open"), + ]) + def test_circuit_breaker_state_listener_transitions(self, old_state, new_state): + """Test circuit breaker state listener logs all state transitions.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerStateListener, + ) + + listener = CircuitBreakerStateListener() + mock_cb = Mock() + mock_cb.name = "test-breaker" + + mock_old_state = Mock() + mock_old_state.name = old_state + + mock_new_state = Mock() + mock_new_state.name = new_state + + with patch("databricks.sql.telemetry.circuit_breaker_manager.logger") as mock_logger: + listener.state_change(mock_cb, mock_old_state, mock_new_state) + mock_logger.info.assert_called() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 19375cde3..b515756e8 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -22,7 +22,13 @@ import databricks.sql import databricks.sql.client as client -from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError +from databricks.sql import ( + InterfaceError, + DatabaseError, + Error, + NotSupportedError, + TransactionError, +) from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState @@ -439,11 +445,6 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set(self): "last operation", ) - @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) - def test_commit_a_noop(self, mock_thrift_backend_class): - c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - c.commit() - def test_setinputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setinputsizes(1) @@ -452,12 +453,6 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) - def test_rollback_not_supported(self, mock_thrift_backend_class): - c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - with self.assertRaises(NotSupportedError): - c.rollback() - @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_row_number_respected(self, mock_thrift_backend_class): @@ -639,11 +634,469 @@ def mock_close_normal(): ) +class TransactionTestSuite(unittest.TestCase): + """ + Unit tests for transaction control methods (MST support). 
+ """ + + PACKAGE_NAME = "databricks.sql" + DUMMY_CONNECTION_ARGS = { + "server_hostname": "foo", + "http_path": "dummy_path", + "access_token": "tok", + } + + def _create_mock_connection(self, mock_session_class): + """Helper to create a mocked connection for transaction tests.""" + # Mock session + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session.get_autocommit.return_value = True + mock_session_class.return_value = mock_session + + # Create connection with ignore_transactions=False to test actual transaction functionality + conn = client.Connection( + ignore_transactions=False, **self.DUMMY_CONNECTION_ARGS + ) + return conn + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_getter_returns_cached_value(self, mock_session_class): + """Test that autocommit property returns cached session value by default.""" + conn = self._create_mock_connection(mock_session_class) + + # Get autocommit (should use cached value) + result = conn.autocommit + + conn.session.get_autocommit.assert_called_once() + self.assertTrue(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_executes_sql(self, mock_session_class): + """Test that setting autocommit executes SET AUTOCOMMIT command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.autocommit = False + + # Verify SQL was executed + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT = FALSE") + mock_cursor.close.assert_called_once() + + conn.session.set_autocommit.assert_called_once_with(False) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_with_true_value(self, mock_session_class): + """Test setting autocommit to True.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.autocommit = True + + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT = TRUE") + conn.session.set_autocommit.assert_called_once_with(True) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_wraps_database_error(self, mock_session_class): + """Test that autocommit setter wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION", + context={"sql_state": "25000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.autocommit = False + + self.assertIn("Failed to set autocommit", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "set_autocommit") + self.assertEqual(ctx.exception.context["autocommit_value"], False) + + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): + """Test that exception chaining is preserved.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + original_error = DatabaseError( + "Original error", session_id_hex="test-session-id" + ) + mock_cursor.execute.side_effect = original_error + + with patch.object(conn, "cursor", return_value=mock_cursor): 
+ with self.assertRaises(TransactionError) as ctx: + conn.autocommit = False + + self.assertEqual(ctx.exception.__cause__, original_error) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_executes_sql(self, mock_session_class): + """Test that commit() executes COMMIT command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.commit() + + mock_cursor.execute.assert_called_once_with("COMMIT") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_wraps_database_error(self, mock_session_class): + """Test that commit() wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION", + context={"sql_state": "25000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.commit() + + self.assertIn("Failed to commit", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "commit") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that commit() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.commit() + + self.assertIn("Cannot commit on closed connection", str(ctx.exception)) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_executes_sql(self, mock_session_class): + """Test that rollback() executes ROLLBACK command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.rollback() + + mock_cursor.execute.assert_called_once_with("ROLLBACK") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_wraps_database_error(self, mock_session_class): + """Test that rollback() wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "Unexpected rollback error", + context={"sql_state": "HY000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.rollback() + + self.assertIn("Failed to rollback", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "rollback") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that rollback() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.rollback() + + self.assertIn("Cannot rollback on closed connection", str(ctx.exception)) + + @patch("%s.client.Session" % PACKAGE_NAME) + def 
test_get_transaction_isolation_returns_repeatable_read( + self, mock_session_class + ): + """Test that get_transaction_isolation() returns REPEATABLE_READ.""" + conn = self._create_mock_connection(mock_session_class) + + result = conn.get_transaction_isolation() + + self.assertEqual(result, "REPEATABLE_READ") + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_get_transaction_isolation_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that get_transaction_isolation() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.get_transaction_isolation() + + self.assertIn( + "Cannot get transaction isolation on closed connection", str(ctx.exception) + ) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_accepts_repeatable_read( + self, mock_session_class + ): + """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" + conn = self._create_mock_connection(mock_session_class) + + # Should not raise + conn.set_transaction_isolation("REPEATABLE_READ") + conn.set_transaction_isolation("REPEATABLE READ") # With space + conn.set_transaction_isolation("repeatable_read") # Lowercase with underscore + conn.set_transaction_isolation("repeatable read") # Lowercase with space + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_rejects_other_levels(self, mock_session_class): + """Test that set_transaction_isolation() rejects non-REPEATABLE_READ levels.""" + conn = self._create_mock_connection(mock_session_class) + + with self.assertRaises(NotSupportedError) as ctx: + conn.set_transaction_isolation("READ_COMMITTED") + + self.assertIn("not supported", str(ctx.exception)) + self.assertIn("READ_COMMITTED", str(ctx.exception)) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that set_transaction_isolation() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.set_transaction_isolation("REPEATABLE_READ") + + self.assertIn( + "Cannot set transaction isolation on closed connection", str(ctx.exception) + ) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_queries_server(self, mock_session_class): + """Test that fetch_autocommit_from_server=True queries server.""" + # Create connection with fetch_autocommit_from_server=True + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, + ignore_transactions=False, + **self.DUMMY_CONNECTION_ARGS, + ) + + mock_cursor = Mock() + mock_row = Mock() + mock_row.__getitem__ = Mock(return_value="true") + mock_cursor.fetchone.return_value = mock_row + + with patch.object(conn, "cursor", return_value=mock_cursor): + result = conn.autocommit + + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT") + mock_cursor.fetchone.assert_called_once() + mock_cursor.close.assert_called_once() + + conn.session.set_autocommit.assert_called_once_with(True) + + self.assertTrue(result) + + conn.close() + + @patch("%s.client.Session" % 
PACKAGE_NAME) + def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_class): + """Test that fetch_autocommit_from_server correctly parses false value.""" + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, + ignore_transactions=False, + **self.DUMMY_CONNECTION_ARGS, + ) + + mock_cursor = Mock() + mock_row = Mock() + mock_row.__getitem__ = Mock(return_value="false") + mock_cursor.fetchone.return_value = mock_row + + with patch.object(conn, "cursor", return_value=mock_cursor): + result = conn.autocommit + + conn.session.set_autocommit.assert_called_once_with(False) + self.assertFalse(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_class): + """Test that fetch_autocommit_from_server raises error when no result.""" + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, + ignore_transactions=False, + **self.DUMMY_CONNECTION_ARGS, + ) + + mock_cursor = Mock() + mock_cursor.fetchone.return_value = None + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + _ = conn.autocommit + + self.assertIn("No result returned", str(ctx.exception)) + mock_cursor.close.assert_called_once() + + conn.close() + + # ==================== IGNORE_TRANSACTIONS TESTS ==================== + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class): + """Test that commit() is a no-op when ignore_transactions=True.""" + + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + # Create connection with ignore_transactions=True (default) + conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) + + # Verify ignore_transactions is True by default + self.assertTrue(conn.ignore_transactions) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + # Call commit - should be no-op + conn.commit() + + # Verify that execute was NOT called (no-op) + mock_cursor.execute.assert_not_called() + mock_cursor.close.assert_not_called() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_raises_not_supported_when_ignore_transactions_true( + self, mock_session_class + ): + """Test that rollback() raises NotSupportedError when ignore_transactions=True.""" + + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + # Create connection with ignore_transactions=True (default) + conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) + + # Verify ignore_transactions is True by default + self.assertTrue(conn.ignore_transactions) + + # Call rollback - should raise NotSupportedError + with self.assertRaises(NotSupportedError) as ctx: + conn.rollback() + + self.assertIn("Transactions are not supported", str(ctx.exception)) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_is_noop_when_ignore_transactions_true( + self, mock_session_class + ): + """Test that autocommit setter is a no-op when 
ignore_transactions=True.""" + + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + # Create connection with ignore_transactions=True (default) + conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) + + # Verify ignore_transactions is True by default + self.assertTrue(conn.ignore_transactions) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + # Set autocommit - should be no-op + conn.autocommit = False + + # Verify that execute was NOT called (no-op) + mock_cursor.execute.assert_not_called() + mock_cursor.close.assert_not_called() + + # Session set_autocommit should also not be called + conn.session.set_autocommit.assert_not_called() + + conn.close() + + if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) loader = unittest.TestLoader() test_classes = [ ClientTestSuite, + TransactionTestSuite, FetchTests, ThriftBackendTestSuite, ArrowQueueSuite, diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 2ff82cee5..96a2f87d8 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -2,6 +2,7 @@ import pytest from unittest.mock import patch, MagicMock import json +from dataclasses import asdict from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, @@ -9,7 +10,20 @@ TelemetryClientFactory, TelemetryHelper, ) -from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow +from databricks.sql.common.feature_flag import ( + FeatureFlagsContextFactory, + FeatureFlagsContext, +) +from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow, DatabricksClientType +from databricks.sql.telemetry.models.event import ( + TelemetryEvent, + DriverConnectionParameters, + DriverSystemConfiguration, + SqlExecutionEvent, + DriverErrorInfo, + DriverVolumeOperation, + HostDetails, +) from databricks.sql.auth.authenticators import ( AccessTokenAuthProvider, DatabricksOAuthProvider, @@ -27,7 +41,9 @@ def mock_telemetry_client(): client_context = MagicMock() # Patch the _setup_pool_manager method to avoid SSL file loading - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -70,12 +86,12 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): client._export_event("event1") client._export_event("event2") mock_send.assert_not_called() - assert len(client._events_batch) == 2 + assert client._events_queue.qsize() == 2 # Third event should trigger flush client._export_event("event3") mock_send.assert_called_once() - assert len(client._events_batch) == 0 # Batch cleared after flush + assert client._events_queue.qsize() == 0 # Queue cleared after flush @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_network_request_flow(self, mock_http_request, mock_telemetry_client): @@ -85,7 +101,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client): mock_response.status = 200 mock_response.status_code = 200 mock_http_request.return_value = mock_response - + client = mock_telemetry_client # Create mock events @@ -221,7 +237,9 @@ def test_client_lifecycle_flow(self): client_context = MagicMock() # Initialize enabled client - with 
patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, @@ -289,7 +307,9 @@ def test_factory_shutdown_flow(self): client_context = MagicMock() # Initialize multiple clients - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, @@ -372,8 +392,10 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -400,8 +422,10 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -428,8 +452,10 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -446,3 +472,416 @@ def test_telemetry_disabled_when_flag_request_fails( mock_http_request.assert_called_once() client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail") assert isinstance(client, NoopTelemetryClient) + + +class TestTelemetryEventModels: + """Tests for telemetry event model data structures and JSON serialization.""" + + def test_host_details_serialization(self): + """Test HostDetails model serialization.""" + host = HostDetails(host_url="test-host.com", port=443) + + # Test JSON string generation + json_str = host.to_json() + assert isinstance(json_str, str) + parsed = json.loads(json_str) + assert parsed["host_url"] == "test-host.com" + assert parsed["port"] == 443 + + def test_driver_connection_parameters_all_fields(self): + """Test DriverConnectionParameters with all fields populated.""" + host_info = HostDetails(host_url="workspace.databricks.com", port=443) + proxy_info = HostDetails(host_url="proxy.company.com", port=8080) + cf_proxy_info = HostDetails(host_url="cf-proxy.company.com", port=8080) + + params = DriverConnectionParameters( + http_path="/sql/1.0/warehouses/abc123", + 
mode=DatabricksClientType.SEA, + host_info=host_info, + auth_mech=AuthMech.OAUTH, + auth_flow=AuthFlow.BROWSER_BASED_AUTHENTICATION, + socket_timeout=30000, + azure_workspace_resource_id="/subscriptions/test/resourceGroups/test", + azure_tenant_id="tenant-123", + use_proxy=True, + use_system_proxy=True, + proxy_host_info=proxy_info, + use_cf_proxy=False, + cf_proxy_host_info=cf_proxy_info, + non_proxy_hosts=["localhost", "127.0.0.1"], + allow_self_signed_support=False, + use_system_trust_store=True, + enable_arrow=True, + enable_direct_results=True, + enable_sea_hybrid_results=True, + http_connection_pool_size=100, + rows_fetched_per_block=100000, + async_poll_interval_millis=2000, + support_many_parameters=True, + enable_complex_datatype_support=True, + allowed_volume_ingestion_paths="/Volumes/catalog/schema/volume", + ) + + # Serialize to JSON and parse back + json_str = params.to_json() + json_dict = json.loads(json_str) + + # Verify all new fields are in JSON + assert json_dict["http_path"] == "/sql/1.0/warehouses/abc123" + assert json_dict["mode"] == "SEA" + assert json_dict["host_info"]["host_url"] == "workspace.databricks.com" + assert json_dict["auth_mech"] == "OAUTH" + assert json_dict["auth_flow"] == "BROWSER_BASED_AUTHENTICATION" + assert json_dict["socket_timeout"] == 30000 + assert json_dict["azure_workspace_resource_id"] == "/subscriptions/test/resourceGroups/test" + assert json_dict["azure_tenant_id"] == "tenant-123" + assert json_dict["use_proxy"] is True + assert json_dict["use_system_proxy"] is True + assert json_dict["proxy_host_info"]["host_url"] == "proxy.company.com" + assert json_dict["use_cf_proxy"] is False + assert json_dict["cf_proxy_host_info"]["host_url"] == "cf-proxy.company.com" + assert json_dict["non_proxy_hosts"] == ["localhost", "127.0.0.1"] + assert json_dict["allow_self_signed_support"] is False + assert json_dict["use_system_trust_store"] is True + assert json_dict["enable_arrow"] is True + assert json_dict["enable_direct_results"] is True + assert json_dict["enable_sea_hybrid_results"] is True + assert json_dict["http_connection_pool_size"] == 100 + assert json_dict["rows_fetched_per_block"] == 100000 + assert json_dict["async_poll_interval_millis"] == 2000 + assert json_dict["support_many_parameters"] is True + assert json_dict["enable_complex_datatype_support"] is True + assert json_dict["allowed_volume_ingestion_paths"] == "/Volumes/catalog/schema/volume" + + def test_driver_connection_parameters_minimal_fields(self): + """Test DriverConnectionParameters with only required fields.""" + host_info = HostDetails(host_url="workspace.databricks.com", port=443) + + params = DriverConnectionParameters( + http_path="/sql/1.0/warehouses/abc123", + mode=DatabricksClientType.THRIFT, + host_info=host_info, + ) + + # Note: to_json() filters out None values, so we need to check asdict for complete structure + json_str = params.to_json() + json_dict = json.loads(json_str) + + # Required fields should be present + assert json_dict["http_path"] == "/sql/1.0/warehouses/abc123" + assert json_dict["mode"] == "THRIFT" + assert json_dict["host_info"]["host_url"] == "workspace.databricks.com" + + # Optional fields with None are filtered out by to_json() + # This is expected behavior - None values are excluded from JSON output + + def test_driver_system_configuration_serialization(self): + """Test DriverSystemConfiguration model serialization.""" + sys_config = DriverSystemConfiguration( + driver_name="Databricks SQL Connector for Python", + driver_version="3.0.0", + 
runtime_name="CPython", + runtime_version="3.11.0", + runtime_vendor="Python Software Foundation", + os_name="Darwin", + os_version="23.0.0", + os_arch="arm64", + char_set_encoding="utf-8", + locale_name="en_US", + client_app_name="MyApp", + ) + + json_str = sys_config.to_json() + json_dict = json.loads(json_str) + + assert json_dict["driver_name"] == "Databricks SQL Connector for Python" + assert json_dict["driver_version"] == "3.0.0" + assert json_dict["runtime_name"] == "CPython" + assert json_dict["runtime_version"] == "3.11.0" + assert json_dict["runtime_vendor"] == "Python Software Foundation" + assert json_dict["os_name"] == "Darwin" + assert json_dict["os_version"] == "23.0.0" + assert json_dict["os_arch"] == "arm64" + assert json_dict["locale_name"] == "en_US" + assert json_dict["char_set_encoding"] == "utf-8" + assert json_dict["client_app_name"] == "MyApp" + + def test_telemetry_event_complete_serialization(self): + """Test complete TelemetryEvent serialization with all nested objects.""" + host_info = HostDetails(host_url="workspace.databricks.com", port=443) + proxy_info = HostDetails(host_url="proxy.company.com", port=8080) + + connection_params = DriverConnectionParameters( + http_path="/sql/1.0/warehouses/abc123", + mode=DatabricksClientType.SEA, + host_info=host_info, + auth_mech=AuthMech.OAUTH, + use_proxy=True, + proxy_host_info=proxy_info, + enable_arrow=True, + rows_fetched_per_block=100000, + ) + + sys_config = DriverSystemConfiguration( + driver_name="Databricks SQL Connector for Python", + driver_version="3.0.0", + runtime_name="CPython", + runtime_version="3.11.0", + runtime_vendor="Python Software Foundation", + os_name="Darwin", + os_version="23.0.0", + os_arch="arm64", + char_set_encoding="utf-8", + ) + + error_info = DriverErrorInfo( + error_name="ConnectionError", + stack_trace="Traceback...", + ) + + event = TelemetryEvent( + session_id="test-session-123", + sql_statement_id="test-stmt-456", + operation_latency_ms=1500, + auth_type="OAUTH", + system_configuration=sys_config, + driver_connection_params=connection_params, + error_info=error_info, + ) + + # Test JSON serialization + json_str = event.to_json() + assert isinstance(json_str, str) + + # Parse and verify structure + parsed = json.loads(json_str) + assert parsed["session_id"] == "test-session-123" + assert parsed["sql_statement_id"] == "test-stmt-456" + assert parsed["operation_latency_ms"] == 1500 + assert parsed["auth_type"] == "OAUTH" + + # Verify nested objects + assert parsed["system_configuration"]["driver_name"] == "Databricks SQL Connector for Python" + assert parsed["driver_connection_params"]["http_path"] == "/sql/1.0/warehouses/abc123" + assert parsed["driver_connection_params"]["use_proxy"] is True + assert parsed["driver_connection_params"]["proxy_host_info"]["host_url"] == "proxy.company.com" + assert parsed["error_info"]["error_name"] == "ConnectionError" + + def test_json_serialization_excludes_none_values(self): + """Test that JSON serialization properly excludes None values.""" + host_info = HostDetails(host_url="workspace.databricks.com", port=443) + + params = DriverConnectionParameters( + http_path="/sql/1.0/warehouses/abc123", + mode=DatabricksClientType.SEA, + host_info=host_info, + # All optional fields left as None + ) + + json_str = params.to_json() + parsed = json.loads(json_str) + + # Required fields present + assert parsed["http_path"] == "/sql/1.0/warehouses/abc123" + + # None values should be EXCLUDED from JSON (not included as null) + # This is the behavior of 
JsonSerializableMixin + assert "auth_mech" not in parsed + assert "azure_tenant_id" not in parsed + assert "proxy_host_info" not in parsed + + +@patch("databricks.sql.client.Session") +@patch("databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers") +class TestConnectionParameterTelemetry: + """Tests for connection parameter population in telemetry.""" + + def test_connection_with_proxy_populates_telemetry(self, mock_setup_pools, mock_session): + """Test that proxy configuration is captured in telemetry.""" + mock_session_instance = MagicMock() + mock_session_instance.guid_hex = "test-session-proxy" + mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False + mock_session_instance.use_sea = True + mock_session_instance.port = 443 + mock_session_instance.host = "workspace.databricks.com" + mock_session.return_value = mock_session_instance + + with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + conn = sql.connect( + server_hostname="workspace.databricks.com", + http_path="/sql/1.0/warehouses/test", + access_token="test-token", + enable_telemetry=True, + force_enable_telemetry=True, + ) + + # Verify export was called + mock_export.assert_called_once() + call_args = mock_export.call_args + + # Extract driver_connection_params + driver_params = call_args.kwargs.get("driver_connection_params") + assert driver_params is not None + assert isinstance(driver_params, DriverConnectionParameters) + + # Verify fields are populated + assert driver_params.http_path == "/sql/1.0/warehouses/test" + assert driver_params.mode == DatabricksClientType.SEA + assert driver_params.host_info.host_url == "workspace.databricks.com" + assert driver_params.host_info.port == 443 + + def test_connection_with_azure_params_populates_telemetry(self, mock_setup_pools, mock_session): + """Test that Azure-specific parameters are captured in telemetry.""" + mock_session_instance = MagicMock() + mock_session_instance.guid_hex = "test-session-azure" + mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False + mock_session_instance.use_sea = False + mock_session_instance.port = 443 + mock_session_instance.host = "workspace.azuredatabricks.net" + mock_session.return_value = mock_session_instance + + with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + conn = sql.connect( + server_hostname="workspace.azuredatabricks.net", + http_path="/sql/1.0/warehouses/test", + access_token="test-token", + azure_workspace_resource_id="/subscriptions/test/resourceGroups/test", + azure_tenant_id="tenant-123", + enable_telemetry=True, + force_enable_telemetry=True, + ) + + mock_export.assert_called_once() + driver_params = mock_export.call_args.kwargs.get("driver_connection_params") + + # Verify Azure fields + assert driver_params.azure_workspace_resource_id == "/subscriptions/test/resourceGroups/test" + assert driver_params.azure_tenant_id == "tenant-123" + + def test_connection_populates_arrow_and_performance_params(self, mock_setup_pools, mock_session): + """Test that Arrow and performance parameters are captured in telemetry.""" + mock_session_instance = MagicMock() + mock_session_instance.guid_hex = "test-session-perf" + mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False + mock_session_instance.use_sea = True + 
mock_session_instance.port = 443 + mock_session_instance.host = "workspace.databricks.com" + mock_session.return_value = mock_session_instance + + with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + # Import pyarrow availability check + try: + import pyarrow + arrow_available = True + except ImportError: + arrow_available = False + + conn = sql.connect( + server_hostname="workspace.databricks.com", + http_path="/sql/1.0/warehouses/test", + access_token="test-token", + pool_maxsize=200, + enable_telemetry=True, + force_enable_telemetry=True, + ) + + mock_export.assert_called_once() + driver_params = mock_export.call_args.kwargs.get("driver_connection_params") + + # Verify performance fields + assert driver_params.enable_arrow == arrow_available + assert driver_params.enable_direct_results is True + assert driver_params.http_connection_pool_size == 200 + assert driver_params.rows_fetched_per_block == 100000 # DEFAULT_ARRAY_SIZE + assert driver_params.async_poll_interval_millis == 2000 + assert driver_params.support_many_parameters is True + + def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_session): + """Test that CloudFlare proxy fields default to False/None (not yet supported).""" + mock_session_instance = MagicMock() + mock_session_instance.guid_hex = "test-session-cfproxy" + mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + mock_session_instance.is_open = False + mock_session_instance.use_sea = True + mock_session_instance.port = 443 + mock_session_instance.host = "workspace.databricks.com" + mock_session.return_value = mock_session_instance + + with patch("databricks.sql.telemetry.telemetry_client.TelemetryClient.export_initial_telemetry_log") as mock_export: + conn = sql.connect( + server_hostname="workspace.databricks.com", + http_path="/sql/1.0/warehouses/test", + access_token="test-token", + enable_telemetry=True, + force_enable_telemetry=True, + ) + + mock_export.assert_called_once() + driver_params = mock_export.call_args.kwargs.get("driver_connection_params") + + # CF proxy not yet supported - should be False/None + assert driver_params.use_cf_proxy is False + assert driver_params.cf_proxy_host_info is None + + +class TestFeatureFlagsContextFactory: + """Tests for FeatureFlagsContextFactory host-level caching.""" + + @pytest.fixture(autouse=True) + def reset_factory(self): + """Reset factory state before/after each test.""" + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + yield + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + + @pytest.mark.parametrize( + "hosts,expected_contexts", + [ + (["host1.com", "host1.com"], 1), # Same host shares context + (["host1.com", "host2.com"], 2), # Different hosts get separate contexts + (["host1.com", "host1.com", "host2.com"], 2), # Mixed scenario + ], + ) + def test_host_level_caching(self, hosts, expected_contexts): + """Test that contexts are cached by host correctly.""" + contexts = [] + for host in hosts: + conn = MagicMock() + conn.session.host = host + conn.session.http_client = MagicMock() + contexts.append(FeatureFlagsContextFactory.get_instance(conn)) + + assert len(FeatureFlagsContextFactory._context_map) == 
expected_contexts + if expected_contexts == 1: + assert all(ctx is contexts[0] for ctx in contexts) + + def test_remove_instance_and_executor_cleanup(self): + """Test removal uses host key and cleans up executor when empty.""" + conn1 = MagicMock() + conn1.session.host = "host1.com" + conn1.session.http_client = MagicMock() + + conn2 = MagicMock() + conn2.session.host = "host2.com" + conn2.session.http_client = MagicMock() + + FeatureFlagsContextFactory.get_instance(conn1) + FeatureFlagsContextFactory.get_instance(conn2) + assert FeatureFlagsContextFactory._executor is not None + + FeatureFlagsContextFactory.remove_instance(conn1) + assert len(FeatureFlagsContextFactory._context_map) == 1 + assert FeatureFlagsContextFactory._executor is not None + + FeatureFlagsContextFactory.remove_instance(conn2) + assert len(FeatureFlagsContextFactory._context_map) == 0 + assert FeatureFlagsContextFactory._executor is None diff --git a/tests/unit/test_telemetry_push_client.py b/tests/unit/test_telemetry_push_client.py new file mode 100644 index 000000000..0e9455e1f --- /dev/null +++ b/tests/unit/test_telemetry_push_client.py @@ -0,0 +1,213 @@ +""" +Unit tests for telemetry push client functionality. +""" + +import pytest +from unittest.mock import Mock, patch + +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import TelemetryRateLimitError +from pybreaker import CircuitBreakerError + + +class TestTelemetryPushClient: + """Test cases for TelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_http_client = Mock() + self.client = TelemetryPushClient(self.mock_http_client) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._http_client == self.mock_http_client + + def test_request_delegates_to_http_client(self): + """Test that request delegates to underlying HTTP client.""" + mock_response = Mock() + self.mock_http_client.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_http_client.request.assert_called_once() + + +class TestCircuitBreakerTelemetryPushClient: + """Test cases for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock(spec=ITelemetryPushClient) + self.host = "test-host.example.com" + self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._delegate == self.mock_delegate + assert self.client._host == self.host + assert self.client._circuit_breaker is not None + + def test_request_success(self): + """Test successful request when circuit breaker is enabled.""" + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_circuit_breaker_open(self): + """Test request when circuit breaker is open raises CircuitBreakerError.""" + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + 
def test_request_other_error(self): + """Test request when other error occurs raises original exception.""" + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + @pytest.mark.parametrize("status_code,expected_error", [ + (429, TelemetryRateLimitError), + (503, TelemetryRateLimitError), + ]) + def test_request_rate_limit_codes(self, status_code, expected_error): + """Test that rate-limit status codes raise TelemetryRateLimitError.""" + mock_response = Mock() + mock_response.status = status_code + self.mock_delegate.request.return_value = mock_response + + with pytest.raises(expected_error): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_non_rate_limit_code(self): + """Test that non-rate-limit status codes return response.""" + mock_response = Mock() + mock_response.status = 500 + mock_response.data = b'Server error' + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 500 + + def test_rate_limit_error_logging(self): + """Test that rate limit errors are logged with circuit breaker context.""" + with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + with pytest.raises(TelemetryRateLimitError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + mock_logger.warning.assert_called() + warning_args = mock_logger.warning.call_args[0] + assert "429" in str(warning_args) + assert "circuit breaker" in warning_args[0] + + def test_other_error_logging(self): + """Test that other errors are logged during wrapping/unwrapping.""" + with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert mock_logger.debug.call_count >= 1 + + +class TestCircuitBreakerTelemetryPushClientIntegration: + """Integration tests for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock() + self.host = "test-host.example.com" + from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + CircuitBreakerManager._instances.clear() + + def test_circuit_breaker_opens_after_failures(self): + """Test that circuit breaker opens after repeated failures (429/503 errors).""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + ) + + CircuitBreakerManager._instances.clear() + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + rate_limit_error_count = 0 + circuit_breaker_error_count = 0 + + for _ in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except TelemetryRateLimitError: + rate_limit_error_count += 1 + except CircuitBreakerError: + circuit_breaker_error_count += 1 + + assert rate_limit_error_count >= MINIMUM_CALLS - 1 + assert circuit_breaker_error_count > 0 + + def 
test_circuit_breaker_recovers_after_success(self): + """Test that circuit breaker recovers after successful calls.""" + import time + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + ) + + CircuitBreakerManager._instances.clear() + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Trigger failures + mock_rate_limit_response = Mock() + mock_rate_limit_response.status = 429 + self.mock_delegate.request.return_value = mock_rate_limit_response + + for _ in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except (TelemetryRateLimitError, CircuitBreakerError): + pass + + # Circuit should be open + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Simulate success + mock_success_response = Mock() + mock_success_response.status = 200 + self.mock_delegate.request.return_value = mock_success_response + + response = client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 200 + + def test_urllib3_import_fallback(self): + """Test that the urllib3 import fallback works correctly.""" + from databricks.sql.telemetry.telemetry_push_client import BaseHTTPResponse + assert BaseHTTPResponse is not None diff --git a/tests/unit/test_telemetry_request_error_handling.py b/tests/unit/test_telemetry_request_error_handling.py new file mode 100644 index 000000000..aa31f6628 --- /dev/null +++ b/tests/unit/test_telemetry_request_error_handling.py @@ -0,0 +1,96 @@ +""" +Unit tests specifically for telemetry_push_client RequestError handling +with http-code context extraction for rate limiting detection. 
+"""
+
+import pytest
+from unittest.mock import Mock
+
+from databricks.sql.telemetry.telemetry_push_client import (
+    CircuitBreakerTelemetryPushClient,
+    TelemetryPushClient,
+)
+from databricks.sql.common.http import HttpMethod
+from databricks.sql.exc import RequestError, TelemetryRateLimitError
+from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager
+
+
+class TestTelemetryPushClientRequestErrorHandling:
+    """Test RequestError handling and http-code context extraction."""
+
+    @pytest.fixture
+    def setup_circuit_breaker(self):
+        """Setup circuit breaker for testing."""
+        CircuitBreakerManager._instances.clear()
+        yield
+        CircuitBreakerManager._instances.clear()
+
+    @pytest.fixture
+    def mock_delegate(self):
+        """Create mock delegate client."""
+        return Mock(spec=TelemetryPushClient)
+
+    @pytest.fixture
+    def client(self, mock_delegate, setup_circuit_breaker):
+        """Create CircuitBreakerTelemetryPushClient instance."""
+        return CircuitBreakerTelemetryPushClient(mock_delegate, "test-host.example.com")
+
+    @pytest.mark.parametrize("status_code", [429, 503])
+    def test_request_error_with_rate_limit_codes(self, client, mock_delegate, status_code):
+        """Test that RequestError with rate-limit codes raises TelemetryRateLimitError."""
+        request_error = RequestError("HTTP request failed", context={"http-code": status_code})
+        mock_delegate.request.side_effect = request_error
+
+        with pytest.raises(TelemetryRateLimitError):
+            client.request(HttpMethod.POST, "https://test.com", {})
+
+    @pytest.mark.parametrize("status_code", [500, 400, 404])
+    def test_request_error_with_non_rate_limit_codes(self, client, mock_delegate, status_code):
+        """Test that RequestError with non-rate-limit codes raises original RequestError."""
+        request_error = RequestError("HTTP request failed", context={"http-code": status_code})
+        mock_delegate.request.side_effect = request_error
+
+        with pytest.raises(RequestError, match="HTTP request failed"):
+            client.request(HttpMethod.POST, "https://test.com", {})
+
+    @pytest.mark.parametrize("context", [{}, None, "429"])
+    def test_request_error_with_invalid_context(self, client, mock_delegate, context):
+        """Test RequestError with invalid/missing context raises original error."""
+        request_error = RequestError("HTTP request failed")
+        if context == "429":
+            # Edge case: http-code as string instead of int
+            request_error.context = {"http-code": context}
+        else:
+            request_error.context = context
+        mock_delegate.request.side_effect = request_error
+
+        with pytest.raises(RequestError, match="HTTP request failed"):
+            client.request(HttpMethod.POST, "https://test.com", {})
+
+    def test_request_error_missing_context_attribute(self, client, mock_delegate):
+        """Test RequestError without context attribute raises original error."""
+        request_error = RequestError("HTTP request failed")
+        if hasattr(request_error, "context"):
+            delattr(request_error, "context")
+        mock_delegate.request.side_effect = request_error
+
+        with pytest.raises(RequestError, match="HTTP request failed"):
+            client.request(HttpMethod.POST, "https://test.com", {})
+
+    def test_http_code_extraction_prioritization(self, client, mock_delegate):
+        """Test that http-code from RequestError context is correctly extracted."""
+        request_error = RequestError(
+            "HTTP request failed after retries", context={"http-code": 503}
+        )
+        mock_delegate.request.side_effect = request_error
+
+        with pytest.raises(TelemetryRateLimitError):
+            client.request(HttpMethod.POST, "https://test.com", {})
+
+    def test_non_request_error_exceptions_raised(self, client, mock_delegate):
+        """Test that non-RequestError exceptions are wrapped then unwrapped."""
+        generic_error = ValueError("Network timeout")
+        mock_delegate.request.side_effect = generic_error
+
+        with pytest.raises(ValueError, match="Network timeout"):
+            client.request(HttpMethod.POST, "https://test.com", {})
diff --git a/tests/unit/test_unified_http_client.py b/tests/unit/test_unified_http_client.py
new file mode 100644
index 000000000..4e9ce1bbf
--- /dev/null
+++ b/tests/unit/test_unified_http_client.py
@@ -0,0 +1,136 @@
+"""
+Unit tests for UnifiedHttpClient, specifically testing MaxRetryError handling
+and HTTP status code extraction.
+"""
+
+import pytest
+from unittest.mock import Mock, patch
+from urllib3.exceptions import MaxRetryError
+
+from databricks.sql.common.unified_http_client import UnifiedHttpClient
+from databricks.sql.common.http import HttpMethod
+from databricks.sql.exc import RequestError
+from databricks.sql.auth.common import ClientContext
+from databricks.sql.types import SSLOptions
+
+
+class TestUnifiedHttpClientMaxRetryError:
+    """Test MaxRetryError handling and HTTP status code extraction."""
+
+    @pytest.fixture
+    def client_context(self):
+        """Create a minimal ClientContext for testing."""
+        context = Mock(spec=ClientContext)
+        context.hostname = "https://test.databricks.com"
+        context.ssl_options = SSLOptions(
+            tls_verify=True,
+            tls_verify_hostname=True,
+            tls_trusted_ca_file=None,
+            tls_client_cert_file=None,
+            tls_client_cert_key_file=None,
+            tls_client_cert_key_password=None,
+        )
+        context.socket_timeout = 30
+        context.retry_stop_after_attempts_count = 3
+        context.retry_delay_min = 1.0
+        context.retry_delay_max = 10.0
+        context.retry_stop_after_attempts_duration = 300.0
+        context.retry_delay_default = 5.0
+        context.retry_dangerous_codes = []
+        context.proxy_auth_method = None
+        context.pool_connections = 10
+        context.pool_maxsize = 20
+        context.user_agent = "test-agent"
+        return context
+
+    @pytest.fixture
+    def http_client(self, client_context):
+        """Create UnifiedHttpClient instance."""
+        return UnifiedHttpClient(client_context)
+
+    @pytest.mark.parametrize("status_code,path", [
+        (429, "reason.response"),
+        (503, "reason.response"),
+        (500, "direct_response"),
+    ])
+    def test_max_retry_error_with_status_codes(self, http_client, status_code, path):
+        """Test MaxRetryError with various status codes and response paths."""
+        mock_pool = Mock()
+        max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com")
+
+        if path == "reason.response":
+            max_retry_error.reason = Mock()
+            max_retry_error.reason.response = Mock()
+            max_retry_error.reason.response.status = status_code
+        else:  # direct_response
+            max_retry_error.response = Mock()
+            max_retry_error.response.status = status_code
+
+        with patch.object(
+            http_client._direct_pool_manager, "request", side_effect=max_retry_error
+        ):
+            with pytest.raises(RequestError) as exc_info:
+                http_client.request(
+                    HttpMethod.POST, "http://test.com", headers={"test": "header"}
+                )
+
+        error = exc_info.value
+        assert hasattr(error, "context")
+        assert "http-code" in error.context
+        assert error.context["http-code"] == status_code
+
+    @pytest.mark.parametrize("setup_func", [
+        lambda e: None,  # No setup - error with no attributes
+        lambda e: setattr(e, "reason", None),  # reason=None
+        lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)),  # reason.response=None
+        lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", Mock(spec=[]))),  # No status attr
+    ])
+    def test_max_retry_error_missing_status(self, http_client, setup_func):
+        """Test MaxRetryError without status code (no crash, empty context)."""
+        mock_pool = Mock()
+        max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com")
+        setup_func(max_retry_error)
+
+        with patch.object(
+            http_client._direct_pool_manager, "request", side_effect=max_retry_error
+        ):
+            with pytest.raises(RequestError) as exc_info:
+                http_client.request(HttpMethod.GET, "http://test.com")
+
+        error = exc_info.value
+        assert error.context == {}
+
+    def test_max_retry_error_prefers_reason_response(self, http_client):
+        """Test that e.reason.response.status is preferred over e.response.status."""
+        mock_pool = Mock()
+        max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com")
+
+        # Set both structures with different status codes
+        max_retry_error.reason = Mock()
+        max_retry_error.reason.response = Mock()
+        max_retry_error.reason.response.status = 429  # Should use this
+
+        max_retry_error.response = Mock()
+        max_retry_error.response.status = 500  # Should be ignored
+
+        with patch.object(
+            http_client._direct_pool_manager, "request", side_effect=max_retry_error
+        ):
+            with pytest.raises(RequestError) as exc_info:
+                http_client.request(HttpMethod.GET, "http://test.com")
+
+        error = exc_info.value
+        assert error.context["http-code"] == 429
+
+    def test_generic_exception_no_crash(self, http_client):
+        """Test that generic exceptions don't crash when checking for status code."""
+        generic_error = Exception("Network error")
+
+        with patch.object(
+            http_client._direct_pool_manager, "request", side_effect=generic_error
+        ):
+            with pytest.raises(RequestError) as exc_info:
+                http_client.request(HttpMethod.POST, "http://test.com")
+
+        error = exc_info.value
+        assert "HTTP request error" in str(error)

From f7af379316c4c1d86fc76cc59a92f512a52c0df6 Mon Sep 17 00:00:00 2001
From: Varun0157
Date: Mon, 1 Dec 2025 01:25:08 +0530
Subject: [PATCH 33/35] simplify diff

---
 src/databricks/sql/backend/sea/backend.py    |  4 ++--
 .../sql/telemetry/telemetry_client.py        | 19 ++++++++++---------
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py
index 9c5c63033..bd41fe293 100644
--- a/src/databricks/sql/backend/sea/backend.py
+++ b/src/databricks/sql/backend/sea/backend.py
@@ -287,9 +287,9 @@ def close_session(self, session_id: SessionId) -> None:
 
         logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id)
 
-        sea_session_id = session_id.to_sea_session_id()
-        if sea_session_id is None:
+        if session_id.backend_type != BackendType.SEA:
             raise ValueError("Not a valid SEA session ID")
+        sea_session_id = session_id.to_sea_session_id()
 
         request_data = DeleteSessionRequest(
             warehouse_id=self.warehouse_id,
diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 892485a4a..8e5860b65 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -190,6 +190,7 @@ def __init__(
         self._session_id_hex = session_id_hex
         self._auth_provider = auth_provider
         self._user_agent = None
+        self._lock = threading.RLock()
        self._pending_futures = set()
 
         # OPTIMIZATION: Use lock-free Queue instead of list + lock
@@ -443,9 +444,9 @@ class TelemetryClientFactory:
     It uses a thread pool to handle asynchronous operations and a single flush thread for all clients.
""" - _clients: Dict[str, BaseTelemetryClient] = ( - {} - ) # Map of session_id_hex -> BaseTelemetryClient + _clients: Dict[ + str, BaseTelemetryClient + ] = {} # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.RLock() # Thread safety for factory operations @@ -546,8 +547,9 @@ def initialize_telemetry_client( session_id_hex, ) if telemetry_enabled: - TelemetryClientFactory._clients[session_id_hex] = ( - TelemetryClient( + TelemetryClientFactory._clients[ + session_id_hex + ] = TelemetryClient( telemetry_enabled=telemetry_enabled, session_id_hex=session_id_hex, auth_provider=auth_provider, @@ -556,11 +558,10 @@ def initialize_telemetry_client( batch_size=batch_size, client_context=client_context, ) - ) else: - TelemetryClientFactory._clients[session_id_hex] = ( - NoopTelemetryClient() - ) + TelemetryClientFactory._clients[ + session_id_hex + ] = NoopTelemetryClient() except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail From de2f848a2a803a40879e722d4f76b2f2d682b800 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Mon, 1 Dec 2025 01:26:14 +0530 Subject: [PATCH 34/35] black src --- .../sql/telemetry/telemetry_client.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 8e5860b65..acf38d375 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -446,7 +446,7 @@ class TelemetryClientFactory: _clients: Dict[ str, BaseTelemetryClient - ] = {} # Map of session_id_hex -> BaseTelemetryClient + ] = {} # Map of session_id_hex -> BaseTelemetryClient _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.RLock() # Thread safety for factory operations @@ -550,14 +550,14 @@ def initialize_telemetry_client( TelemetryClientFactory._clients[ session_id_hex ] = TelemetryClient( - telemetry_enabled=telemetry_enabled, - session_id_hex=session_id_hex, - auth_provider=auth_provider, - host_url=host_url, - executor=TelemetryClientFactory._executor, - batch_size=batch_size, - client_context=client_context, - ) + telemetry_enabled=telemetry_enabled, + session_id_hex=session_id_hex, + auth_provider=auth_provider, + host_url=host_url, + executor=TelemetryClientFactory._executor, + batch_size=batch_size, + client_context=client_context, + ) else: TelemetryClientFactory._clients[ session_id_hex From ec991fe20d3764df60863c9c41fe6488e34a1906 Mon Sep 17 00:00:00 2001 From: Varun0157 Date: Mon, 1 Dec 2025 01:28:11 +0530 Subject: [PATCH 35/35] remove excess ThreadPoolExcecutor import --- src/databricks/sql/telemetry/telemetry_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index acf38d375..b5e279c42 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -2,7 +2,7 @@ import time import logging import json -from concurrent.futures import ThreadPoolExecutor, wait +from concurrent.futures import wait from queue import Queue, Full from concurrent.futures import ThreadPoolExecutor from concurrent.futures import Future