diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 2a2ac0e6b539..a5cc4fb3dd19 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -727,6 +727,9 @@ def __init__(
         # cleanup ml cache if possible
         atexit.register(self._cleanup_ml_cache)
 
+        self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = []
+        self.global_user_context_extensions_lock = threading.Lock()
+
     @property
     def _stub(self) -> grpc_lib.SparkConnectServiceStub:
         if self.is_closed:
@@ -1277,6 +1280,24 @@ def token(self) -> Optional[str]:
         """
         return self._builder.token
 
+    def _update_request_with_user_context_extensions(
+        self,
+        req: Union[
+            pb2.AnalyzePlanRequest,
+            pb2.ConfigRequest,
+            pb2.ExecutePlanRequest,
+            pb2.FetchErrorDetailsRequest,
+            pb2.InterruptRequest,
+        ],
+    ) -> None:
+        with self.global_user_context_extensions_lock:
+            for _, extension in self.global_user_context_extensions:
+                req.user_context.extensions.append(extension)
+        if not hasattr(self.thread_local, "user_context_extensions"):
+            return
+        for _, extension in self.thread_local.user_context_extensions:
+            req.user_context.extensions.append(extension)
+
     def _execute_plan_request_with_metadata(
         self, operation_id: Optional[str] = None
     ) -> pb2.ExecutePlanRequest:
@@ -1307,6 +1328,7 @@ def _execute_plan_request_with_metadata(
                     messageParameters={"arg_name": "operation_id", "origin": str(ve)},
                 )
             req.operation_id = operation_id
+        self._update_request_with_user_context_extensions(req)
         return req
 
     def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1317,6 +1339,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
         req.client_type = self._builder.userAgent
         if self._user_id:
             req.user_context.user_id = self._user_id
+        self._update_request_with_user_context_extensions(req)
         return req
 
     def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
@@ -1731,6 +1754,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
         req.client_type = self._builder.userAgent
         if self._user_id:
             req.user_context.user_id = self._user_id
+        self._update_request_with_user_context_extensions(req)
         return req
 
     def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
@@ -1807,6 +1831,7 @@ def _interrupt_request(
             )
         if self._user_id:
             req.user_context.user_id = self._user_id
+        self._update_request_with_user_context_extensions(req)
         return req
 
     def interrupt_all(self) -> Optional[List[str]]:
@@ -1905,6 +1930,38 @@ def _throw_if_invalid_tag(self, tag: str) -> None:
                 messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag},
             )
 
+    def add_threadlocal_user_context_extension(self, extension: any_pb2.Any) -> str:
+        if not hasattr(self.thread_local, "user_context_extensions"):
+            self.thread_local.user_context_extensions = list()
+        extension_id = "threadlocal_" + str(uuid.uuid4())
+        self.thread_local.user_context_extensions.append((extension_id, extension))
+        return extension_id
+
+    def add_global_user_context_extension(self, extension: any_pb2.Any) -> str:
+        extension_id = "global_" + str(uuid.uuid4())
+        with self.global_user_context_extensions_lock:
+            self.global_user_context_extensions.append((extension_id, extension))
+        return extension_id
+
+    def remove_user_context_extension(self, extension_id: str) -> None:
+        if extension_id.find("threadlocal_") == 0:
+            if not hasattr(self.thread_local, "user_context_extensions"):
+                return
+            self.thread_local.user_context_extensions = list(
+                filter(lambda ex: ex[0] != extension_id, self.thread_local.user_context_extensions)
+            )
+        elif extension_id.find("global_") == 0:
+            with self.global_user_context_extensions_lock:
+                self.global_user_context_extensions = list(
+                    filter(lambda ex: ex[0] != extension_id, self.global_user_context_extensions)
+                )
+
+    def clear_user_context_extensions(self) -> None:
+        if hasattr(self.thread_local, "user_context_extensions"):
+            self.thread_local.user_context_extensions = list()
+        with self.global_user_context_extensions_lock:
+            self.global_user_context_extensions = list()
+
     def _handle_error(self, error: Exception) -> NoReturn:
         """
         Handle errors that occur during RPC calls.
@@ -1945,6 +2002,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
         req.client_observed_server_side_session_id = self._server_session_id
         if self._user_id:
             req.user_context.user_id = self._user_id
+        self._update_request_with_user_context_extensions(req)
 
         try:
             return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata())
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index c189f996cbe4..39abc769b673 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -26,6 +26,7 @@
 if should_test_connect:
     import grpc
     import google.protobuf.any_pb2 as any_pb2
+    import google.protobuf.wrappers_pb2 as wrappers_pb2
     from google.rpc import status_pb2
     from google.rpc.error_details_pb2 import ErrorInfo
     import pandas as pd
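For reference, a minimal usage sketch of the API added above (not part of the patch): it packs a protobuf message into an Any and registers it so that every outgoing request carries it in user_context.extensions. The `client` handle is assumed to be an existing SparkConnectClient, and the StringValue payload is an arbitrary illustration.

    import google.protobuf.any_pb2 as any_pb2
    import google.protobuf.wrappers_pb2 as wrappers_pb2

    # Pack an arbitrary protobuf message into an Any to serve as the extension.
    extension = any_pb2.Any()
    extension.Pack(wrappers_pb2.StringValue(value="my-extension-value"))

    # Global extensions ride on every request from any thread; thread-local
    # ones only on requests issued by the registering thread.
    ext_id = client.add_global_user_context_extension(extension)

    # The returned id is prefixed "global_" or "threadlocal_", which is how
    # remove_user_context_extension decides where to look it up.
    client.remove_user_context_extension(ext_id)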