From e94e9f68b21f802f6072143213e4b78d08a6e146 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Thu, 13 Nov 2025 16:45:14 -0800 Subject: [PATCH 1/8] core --- python/pyspark/sql/connect/client/core.py | 58 +++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 2a2ac0e6b539..a5cc4fb3dd19 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -727,6 +727,9 @@ def __init__( # cleanup ml cache if possible atexit.register(self._cleanup_ml_cache) + self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = [] + self.global_user_context_extensions_lock = threading.Lock() + @property def _stub(self) -> grpc_lib.SparkConnectServiceStub: if self.is_closed: @@ -1277,6 +1280,24 @@ def token(self) -> Optional[str]: """ return self._builder.token + def _update_request_with_user_context_extensions( + self, + req: Union[ + pb2.AnalyzePlanRequest, + pb2.ConfigRequest, + pb2.ExecutePlanRequest, + pb2.FetchErrorDetailsRequest, + pb2.InterruptRequest, + ], + ) -> None: + with self.global_user_context_extensions_lock: + for _, extension in self.global_user_context_extensions: + req.user_context.extensions.append(extension) + if not hasattr(self.thread_local, "user_context_extensions"): + return + for _, extension in self.thread_local.user_context_extensions: + req.user_context.extensions.append(extension) + def _execute_plan_request_with_metadata( self, operation_id: Optional[str] = None ) -> pb2.ExecutePlanRequest: @@ -1307,6 +1328,7 @@ def _execute_plan_request_with_metadata( messageParameters={"arg_name": "operation_id", "origin": str(ve)}, ) req.operation_id = operation_id + self._update_request_with_user_context_extensions(req) return req def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: @@ -1317,6 +1339,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: @@ -1731,6 +1754,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: @@ -1807,6 +1831,7 @@ def _interrupt_request( ) if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def interrupt_all(self) -> Optional[List[str]]: @@ -1905,6 +1930,38 @@ def _throw_if_invalid_tag(self, tag: str) -> None: messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag}, ) + def add_threadlocal_user_context_extension(self, extension: any_pb2.Any) -> str: + if not hasattr(self.thread_local, "user_context_extensions"): + self.thread_local.user_context_extensions = list() + extension_id = "threadlocal_" + str(uuid.uuid4()) + self.thread_local.user_context_extensions.append((extension_id, extension)) + return extension_id + + def add_global_user_context_extension(self, extension: any_pb2.Any) -> str: + extension_id = "global_" + str(uuid.uuid4()) + with self.global_user_context_extensions_lock: + self.global_user_context_extensions.append((extension_id, extension)) + return extension_id + + def remove_user_context_extension(self, extension_id: str) -> None: + if extension_id.find("threadlocal_") == 0: + if not hasattr(self.thread_local, "user_context_extensions"): + return + self.thread_local.user_context_extensions = list( + filter(lambda ex: ex[0] != extension_id, self.thread_local.user_context_extensions) + ) + elif extension_id.find("global_") == 0: + with self.global_user_context_extensions_lock: + self.global_user_context_extensions = list( + filter(lambda ex: ex[0] != extension_id, self.global_user_context_extensions) + ) + + def clear_user_context_extensions(self) -> None: + if hasattr(self.thread_local, "user_context_extensions"): + self.thread_local.user_context_extensions = list() + with self.global_user_context_extensions_lock: + self.global_user_context_extensions = list() + def _handle_error(self, error: Exception) -> NoReturn: """ Handle errors that occur during RPC calls. @@ -1945,6 +2002,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet req.client_observed_server_side_session_id = self._server_session_id if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) try: return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata()) From 878d56f28df732df78fb54d0b5bf06d9646d339f Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Thu, 13 Nov 2025 22:42:27 -0800 Subject: [PATCH 2/8] add test --- .../sql/tests/connect/client/test_client.py | 106 +++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index c189f996cbe4..b9db18f114ce 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -20,12 +20,14 @@ from collections.abc import Generator from typing import Optional, Any, Union +import google.protobuf.any_pb2 as any_pb2 +import google.protobuf.wrappers_pb2 as wrappers_pb2 + from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import eventually if should_test_connect: import grpc - import google.protobuf.any_pb2 as any_pb2 from google.rpc import status_pb2 from google.rpc.error_details_pb2 import ErrorInfo import pandas as pd @@ -136,9 +138,11 @@ class MockService: def __init__(self, session_id: str): self._session_id = session_id self.req = None + self.client_user_context_extensions = [] def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): self.req = req + self.client_user_context_extensions = req.user_context.extensions resp = proto.ExecutePlanResponse() resp.session_id = self._session_id resp.operation_id = req.operation_id @@ -159,12 +163,25 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): def Interrupt(self, req: proto.InterruptRequest, metadata): self.req = req + self.client_user_context_extensions = req.user_context.extensions resp = proto.InterruptResponse() resp.session_id = self._session_id return resp + def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): + # Always returns SemanticHash message + self.req = req + self.client_user_context_extensions = req.user_context.extensions + self._update_client_call_context() + resp = proto.AnalyzePlanResponse( + session_id=self._session_id, + semantic_hash=proto.AnalyzePlanResponse.SemanticHash(result=0), + ) + return resp + def Config(self, req: proto.ConfigRequest, metadata): self.req = req + self.client_user_context_extensions = req.user_context.extensions resp = proto.ConfigResponse() resp.session_id = self._session_id if req.operation.HasField("get"): @@ -229,6 +246,93 @@ def userId(self) -> Optional[str]: self.assertEqual(client._user_id, "abc") + def test_user_context_extension(self): + client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) + mock = MockService(client._session_id) + client._stub = mock + + exlocal = any_pb2.Any() + exlocal.Pack(wrappers_pb2.StringValue(value="abc")) + exlocal2 = any_pb2.Any() + exlocal2.Pack(wrappers_pb2.StringValue(value="def")) + exglobal = any_pb2.Any() + exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) + exglobal2 = any_pb2.Any() + exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) + + exlocal_id = client.add_threadlocal_user_context_extension(exlocal) + exglobal_id = client.add_global_user_context_extension(exglobal) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_threadlocal_user_context_extension(exlocal2) + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_global_user_context_extension(exglobal2) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exlocal_id) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exglobal_id) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.clear_user_context_extensions() + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + def test_interrupt_all(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id) From 8675ed22b90e6c845d53889bd4aa8ef4a3aabe38 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Thu, 13 Nov 2025 22:50:22 -0800 Subject: [PATCH 3/8] nit --- python/pyspark/sql/tests/connect/client/test_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index b9db18f114ce..c4c1d0586cb7 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -172,7 +172,6 @@ def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): # Always returns SemanticHash message self.req = req self.client_user_context_extensions = req.user_context.extensions - self._update_client_call_context() resp = proto.AnalyzePlanResponse( session_id=self._session_id, semantic_hash=proto.AnalyzePlanResponse.SemanticHash(result=0), From 6a3d7b18e6c424b1c8fedf024d1203c375c9bf2e Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Thu, 13 Nov 2025 23:42:01 -0800 Subject: [PATCH 4/8] rm test_user_context_extension --- .../sql/tests/connect/client/test_client.py | 87 ------------------- 1 file changed, 87 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index c4c1d0586cb7..75c2e74866db 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -245,93 +245,6 @@ def userId(self) -> Optional[str]: self.assertEqual(client._user_id, "abc") - def test_user_context_extension(self): - client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) - mock = MockService(client._session_id) - client._stub = mock - - exlocal = any_pb2.Any() - exlocal.Pack(wrappers_pb2.StringValue(value="abc")) - exlocal2 = any_pb2.Any() - exlocal2.Pack(wrappers_pb2.StringValue(value="def")) - exglobal = any_pb2.Any() - exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) - exglobal2 = any_pb2.Any() - exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) - - exlocal_id = client.add_threadlocal_user_context_extension(exlocal) - exglobal_id = client.add_global_user_context_extension(exglobal) - - mock.client_user_context_extensions = [] - command = proto.Command() - client.execute_command(command) - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - client.add_threadlocal_user_context_extension(exlocal2) - - mock.client_user_context_extensions = [] - plan = proto.Plan() - client.semantic_hash(plan) # use semantic_hash to test analyze - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - client.add_global_user_context_extension(exglobal2) - - mock.client_user_context_extensions = [] - client.interrupt_all() - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.remove_user_context_extension(exlocal_id) - - mock.client_user_context_extensions = [] - client.get_configs("foo", "bar") - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.remove_user_context_extension(exglobal_id) - - mock.client_user_context_extensions = [] - command = proto.Command() - client.execute_command(command) - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.clear_user_context_extensions() - - mock.client_user_context_extensions = [] - plan = proto.Plan() - client.semantic_hash(plan) # use semantic_hash to test analyze - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - mock.client_user_context_extensions = [] - client.interrupt_all() - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - mock.client_user_context_extensions = [] - client.get_configs("foo", "bar") - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - def test_interrupt_all(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id) From a7d3b43ce82fe96be67a5a389dd05b18a6b823a7 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Fri, 14 Nov 2025 07:14:42 -0800 Subject: [PATCH 5/8] nit --- python/pyspark/sql/tests/connect/client/test_client.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 75c2e74866db..72b616339881 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -168,16 +168,6 @@ def Interrupt(self, req: proto.InterruptRequest, metadata): resp.session_id = self._session_id return resp - def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): - # Always returns SemanticHash message - self.req = req - self.client_user_context_extensions = req.user_context.extensions - resp = proto.AnalyzePlanResponse( - session_id=self._session_id, - semantic_hash=proto.AnalyzePlanResponse.SemanticHash(result=0), - ) - return resp - def Config(self, req: proto.ConfigRequest, metadata): self.req = req self.client_user_context_extensions = req.user_context.extensions From 9007710da7e8fe2a5bfb8496c73f1c793cf84131 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Fri, 14 Nov 2025 08:00:00 -0800 Subject: [PATCH 6/8] rm import --- python/pyspark/sql/tests/connect/client/test_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 72b616339881..02453c071339 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -20,14 +20,12 @@ from collections.abc import Generator from typing import Optional, Any, Union -import google.protobuf.any_pb2 as any_pb2 -import google.protobuf.wrappers_pb2 as wrappers_pb2 - from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import eventually if should_test_connect: import grpc + import google.protobuf.any_pb2 as any_pb2 from google.rpc import status_pb2 from google.rpc.error_details_pb2 import ErrorInfo import pandas as pd From 5aae73fd12cdda0fb1923b27ffd66718b9dd01c2 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Fri, 14 Nov 2025 08:01:57 -0800 Subject: [PATCH 7/8] revert --- python/pyspark/sql/tests/connect/client/test_client.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 02453c071339..c189f996cbe4 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -136,11 +136,9 @@ class MockService: def __init__(self, session_id: str): self._session_id = session_id self.req = None - self.client_user_context_extensions = [] def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions resp = proto.ExecutePlanResponse() resp.session_id = self._session_id resp.operation_id = req.operation_id @@ -161,14 +159,12 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): def Interrupt(self, req: proto.InterruptRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions resp = proto.InterruptResponse() resp.session_id = self._session_id return resp def Config(self, req: proto.ConfigRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions resp = proto.ConfigResponse() resp.session_id = self._session_id if req.operation.HasField("get"): From 8dd23c247f541d74507a539bb6019f32f5416285 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Fri, 14 Nov 2025 08:08:59 -0800 Subject: [PATCH 8/8] add import --- python/pyspark/sql/tests/connect/client/test_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index c189f996cbe4..39abc769b673 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -26,6 +26,7 @@ if should_test_connect: import grpc import google.protobuf.any_pb2 as any_pb2 + import google.protobuf.wrappers_pb2 as wrappers_pb2 from google.rpc import status_pb2 from google.rpc.error_details_pb2 import ErrorInfo import pandas as pd