From cc4da754276d172b6eab74a823cacd3d6fccf9f7 Mon Sep 17 00:00:00 2001 From: Susheel Aroskar Date: Thu, 13 Nov 2025 17:31:02 -0800 Subject: [PATCH 1/6] Improve Server-Side debuggability in Spark Connect by capturing client application's file name and line numbers in PySpark --- python/pyspark/sql/connect/client/core.py | 62 +++ .../client/test_client_call_stack_trace.py | 433 ++++++++++++++++++ 2 files changed, 495 insertions(+) create mode 100644 python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 48e07642e157..5334f53239e0 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -22,6 +22,7 @@ import atexit +import pyspark from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) @@ -35,6 +36,7 @@ import uuid import sys import time +import traceback from typing import ( Iterable, Iterator, @@ -65,6 +67,8 @@ from pyspark.util import is_remote_only from pyspark.accumulators import SpecialAccumulatorIds from pyspark.version import __version__ +from pyspark import traceback_utils +from pyspark.traceback_utils import CallSite from pyspark.resource.information import ResourceInformation from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics from pyspark.sql.connect.client.artifact import ArtifactManager @@ -114,6 +118,9 @@ from pyspark.sql.datasource import DataSource +PYSPARK_ROOT = os.path.dirname(pyspark.__file__) + + def _import_zstandard_if_available() -> Optional[Any]: """ Import zstandard if available, otherwise return None. @@ -605,6 +612,49 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult": warnings=list(pb.warnings), ) +def _is_pyspark_source(filename: str) -> bool: + """Check if the given filename is from the pyspark package.""" + return filename.startswith(PYSPARK_ROOT) + + +def _retrieve_stack_frames() -> List[CallSite]: + """ + Return a list of CallSites representing the relevant stack frames in the callstack. + """ + frames = traceback.extract_stack() + + filtered_stack_frames = [] + for i, frame in enumerate(frames): + filename, lineno, func, _ = frame + if _is_pyspark_source(filename): + break + if i+1 < len(frames): + _, _, func, _ = frames[i+1] + filtered_stack_frames.append(CallSite(function=func, file=filename, linenum=lineno)) + + return filtered_stack_frames + +def _build_call_stack_trace() -> List[any_pb2.Any]: + """ + Build a call stack trace for the current Spark Connect action + Returns + ------- + List[any_pb2.Any]: A list of Any objects, each representing a stack frame in the call stack trace in the user code. 
+ """ + call_stack_trace = [] + if os.getenv("SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK", "false").lower() in ("true", "1"): + stack_frames = _retrieve_stack_frames() + for i, call_site in enumerate(stack_frames): + stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement() + stack_trace_element.declaring_class = "" # unknown information + stack_trace_element.method_name = call_site.function + stack_trace_element.file_name = call_site.file + stack_trace_element.line_number = call_site.linenum + stack_frame = any_pb2.Any() + stack_frame.Pack(stack_trace_element) + call_stack_trace.append(stack_frame) + return call_stack_trace + class SparkConnectClient(object): """ @@ -1280,6 +1330,7 @@ def token(self) -> Optional[str]: """ return self._builder.token + def _update_request_with_user_context_extensions( self, req: Union[ @@ -1298,6 +1349,7 @@ def _update_request_with_user_context_extensions( for _, extension in self.thread_local.user_context_extensions: req.user_context.extensions.append(extension) + def _execute_plan_request_with_metadata( self, operation_id: Optional[str] = None ) -> pb2.ExecutePlanRequest: @@ -1329,6 +1381,10 @@ def _execute_plan_request_with_metadata( ) req.operation_id = operation_id self._update_request_with_user_context_extensions(req) + + call_stack_trace = _build_call_stack_trace() + if call_stack_trace: + req.user_context.extensions.extend(call_stack_trace) return req def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: @@ -1340,6 +1396,9 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: if self._user_id: req.user_context.user_id = self._user_id self._update_request_with_user_context_extensions(req) + call_stack_trace = _build_call_stack_trace() + if call_stack_trace: + req.user_context.extensions.extend(call_stack_trace) return req def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: @@ -1755,6 +1814,9 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest: if self._user_id: req.user_context.user_id = self._user_id self._update_request_with_user_context_extensions(req) + call_stack_trace = _build_call_stack_trace() + if call_stack_trace: + req.user_context.extensions.extend(call_stack_trace) return req def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: diff --git a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py new file mode 100644 index 000000000000..ff8e8bc0fbae --- /dev/null +++ b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py @@ -0,0 +1,433 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import os
+import sys
+import unittest
+from unittest.mock import patch
+
+import pyspark
+from pyspark.sql import functions
+from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+
+if should_test_connect:
+    import pyspark.sql.connect.proto as pb2
+    from pyspark.sql.connect.client import SparkConnectClient, core
+    from pyspark.sql.connect.client.core import (
+        _is_pyspark_source,
+        _retrieve_stack_frames,
+        _build_call_stack_trace,
+    )
+    from pyspark.traceback_utils import CallSite
+    from google.protobuf import any_pb2
+
+
+    # The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster)
+    # and it blocks the test process exiting because it is registered as the atexit handler
+    # in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test.
+    SparkConnectClient._cleanup_ml_cache = lambda _: None
+
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class CallStackTraceTestCase(unittest.TestCase):
+    """Test cases for call stack trace functionality in Spark Connect client."""
+
+    def setUp(self):
+        # Since this test itself is under the pyspark module path, stack frames for test functions inside
+        # this file - for example, user_function() - will normally be filtered out. So here we
+        # set the PYSPARK_ROOT to the more specific pyspark.sql.connect that doesn't include this
+        # test file to ensure that the stack frames for user functions inside this test file are
+        # not filtered out.
+        self.original_pyspark_root = core.PYSPARK_ROOT
+        core.PYSPARK_ROOT = os.path.dirname(pyspark.sql.connect.__file__)
+
+    def tearDown(self):
+        # Restore the original PYSPARK_ROOT
+        core.PYSPARK_ROOT = self.original_pyspark_root
+
+    def test_is_pyspark_source_with_pyspark_file(self):
+        """Test that _is_pyspark_source correctly identifies PySpark files."""
+        # Get a known pyspark file path
+        from pyspark import sql
+
+        pyspark_file = sql.connect.client.__file__
+        self.assertTrue(_is_pyspark_source(pyspark_file))
+
+    def test_is_pyspark_source_with_non_pyspark_file(self):
+        """Test that _is_pyspark_source correctly identifies non-PySpark files."""
+        # Use the current test file which is in pyspark but we'll simulate a non-pyspark path
+        non_pyspark_file = "/tmp/user_script.py"
+        self.assertFalse(_is_pyspark_source(non_pyspark_file))
+
+        # Test with stdlib file
+        stdlib_file = os.__file__
+        self.assertFalse(_is_pyspark_source(stdlib_file))
+
+    def test_is_pyspark_source_with_relative_path(self):
+        """Test _is_pyspark_source with various path formats."""
+        from pyspark import sql
+
+        # Test with absolute path to pyspark file
+        pyspark_sql_file = sql.connect.client.__file__
+        self.assertTrue(_is_pyspark_source(pyspark_sql_file))
+
+        # Test with non-pyspark absolute path
+        self.assertFalse(_is_pyspark_source("/home/user/my_script.py"))
+
+    def test_retrieve_stack_frames_filters_pyspark_frames(self):
+        """Test that _retrieve_stack_frames filters out PySpark internal frames."""
+        def user_function():
+            return _retrieve_stack_frames()
+
+        stack_frames = user_function()
+
+        # Verify we have at least some frames
+        self.assertGreater(len(stack_frames), 0, "Expected at least some stack frames")
+
+        # Verify that none of the returned frames are from PySpark internal code
+        for frame in stack_frames:
+            # Check that this frame is not from pyspark internal code
+            self.assertFalse(
+                _is_pyspark_source(frame.file),
+                f"Expected frame from {frame.file} (function: {frame.function}) to be filtered out as PySpark internal frame"
+            
) + + # Verify that user function names are present (confirming user frames are included) + function_names = [frame.function for frame in stack_frames] + expected_functions = ["user_function", "test_retrieve_stack_frames_filters_pyspark_frames"] + self.assertTrue("user_function" in function_names, f"Expected user function names not found in: {function_names}") + self.assertTrue("test_retrieve_stack_frames_filters_pyspark_frames" in function_names, f"Expected user function names not found in: {function_names}") + + def test_retrieve_stack_frames_includes_user_frames(self): + """Test that _retrieve_stack_frames includes user code frames.""" + def user_function(): + """Simulate a user function.""" + return _retrieve_stack_frames() + + def another_user_function(): + """Another level of user code.""" + return user_function() + + stack_frames = another_user_function() + + # We should have at least some frames from the test + self.assertGreater(len(stack_frames), 0) + + # Check that we have frames with function names we expect + function_names = [frame.function for frame in stack_frames] + # At least one of our test functions should be in the stack + self.assertTrue("user_function" in function_names, f"Expected user function names not found in: {function_names}") + self.assertTrue("another_user_function" in function_names, f"Expected user function names not found in: {function_names}") + self.assertTrue("test_retrieve_stack_frames_includes_user_frames" in function_names, f"Expected user function names not found in: {function_names}") + + def test_retrieve_stack_frames_captures_correct_info(self): + """Test that _retrieve_stack_frames captures correct frame information.""" + def user_function(): + return _retrieve_stack_frames() + + stack_frames = user_function() + + # Verify each frame has the expected attributes + functions = set() + files = set() + for frame in stack_frames: + functions.add(frame.function) + files.add(frame.file) + self.assertIsNotNone(frame.function) + self.assertIsNotNone(frame.file) + self.assertIsNotNone(frame.linenum) + self.assertIsInstance(frame.function, str) + self.assertIsInstance(frame.file, str) + self.assertIsInstance(frame.linenum, int) + self.assertGreater(frame.linenum, 0) + + self.assertTrue("user_function" in functions, f"Expected user function names not found in: {functions}") + self.assertTrue("test_retrieve_stack_frames_captures_correct_info" in functions, f"Expected user function names not found in: {functions}") + self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}") + + def test_build_call_stack_trace_without_env_var(self): + """Test that _build_call_stack_trace returns empty list when env var is not set.""" + # Make sure the env var is not set + with patch.dict(os.environ, {}, clear=False): + if "SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK" in os.environ: + del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"] + + call_stack = _build_call_stack_trace() + + self.assertIsInstance(call_stack, list) + self.assertEqual(len(call_stack), 0, "Expected empty list when env var is not set") + + def test_build_call_stack_trace_with_env_var_set(self): + """Test that _build_call_stack_trace builds trace when env var is set.""" + # Set the env var to enable call stack tracing + with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}): + call_stack = _build_call_stack_trace() + + self.assertIsInstance(call_stack, list) + # Should have at least one frame (this test function) + self.assertGreater( + len(call_stack), 0, 
"Expected non-empty list when env var is set" + ) + + # Verify each element is an Any protobuf message + functions = set() + files = set() + for stack_frame in call_stack: + self.assertIsInstance(stack_frame, any_pb2.Any) + + # Unpack and verify it's a StackTraceElement + stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement() + stack_frame.Unpack(stack_trace_element) + functions.add(stack_trace_element.method_name) + files.add(stack_trace_element.file_name) + + # Verify the fields are populated + self.assertIsInstance(stack_trace_element.method_name, str) + self.assertIsInstance(stack_trace_element.file_name, str) + self.assertIsInstance(stack_trace_element.line_number, int) + self.assertEqual( + stack_trace_element.declaring_class, "", "declaring_class should be empty" + ) + + self.assertTrue("test_build_call_stack_trace_with_env_var_set" in functions, f"Expected user function names not found in: {functions}") + self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}") + + def test_build_call_stack_trace_with_env_var_empty_string(self): + """Test that _build_call_stack_trace returns empty list when env var is empty string.""" + with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": ""}): + call_stack = _build_call_stack_trace() + + self.assertIsInstance(call_stack, list) + self.assertEqual( + len(call_stack), 0, "Expected empty list when env var is empty string" + ) + + def test_build_call_stack_trace_with_various_env_var_values(self): + """Test _build_call_stack_trace behavior with various env var values.""" + test_cases = [ + ("0", 0, "zero string should be treated as falsy"), + ("false", 0, "non-empty string 'false' should be falsy"), + ("TRUE", 1, "string 'TRUE' should be truthy"), + ("true", 1, "string 'true' should be truthy"), + ("1", 1, "string '1' should be truthy"), + ("any_value", 0, "any non-empty string should be falsy"), + ] + + for env_value, expected_behavior, message in test_cases: + with self.subTest(env_value=env_value): + with patch.dict( + os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": env_value} + ): + call_stack = _build_call_stack_trace() + if expected_behavior == 0: + self.assertEqual(len(call_stack), 0, message) + else: + self.assertGreater(len(call_stack), 0, message) + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class CallStackTraceIntegrationTestCase(unittest.TestCase): + """Integration tests for call stack trace in client request methods.""" + + def setUp(self): + """Set up test fixtures.""" + self.client = SparkConnectClient("sc://localhost:15002", use_reattachable_execute=False) + # Since this test itself is under pyspark module path, stack frames for test functions inside + # this file - for example, user_function() - will normally be filtered out. So here we + # set the PYSPARK_ROOT to more specific pyspaark.sql.connect that doesn't include this + # test file to ensure that the stack frames for user functions inside this test file are + # not filtered out. 
+ self.original_pyspark_root = core.PYSPARK_ROOT + core.PYSPARK_ROOT = os.path.dirname(pyspark.sql.connect.__file__) + + def tearDown(self): + # Restore the original PYSPARK_ROOT + core.PYSPARK_ROOT = self.original_pyspark_root + + def test_execute_plan_request_includes_call_stack_without_env_var(self): + """Test that _execute_plan_request_with_metadata doesn't include call stack without env var.""" + with patch.dict(os.environ, {}, clear=False): + if "SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK" in os.environ: + del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"] + + req = self.client._execute_plan_request_with_metadata() + + # Should have no extensions when env var is not set + self.assertEqual( + len(req.user_context.extensions), + 0, + "Expected no extensions without env var", + ) + + def test_execute_plan_request_includes_call_stack_with_env_var(self): + """Test that _execute_plan_request_with_metadata includes call stack with env var.""" + with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}): + req = self.client._execute_plan_request_with_metadata() + + # Should have extensions when env var is set + self.assertGreater( + len(req.user_context.extensions), + 0, + "Expected extensions with env var set", + ) + + # Verify each extension can be unpacked as StackTraceElement + files = set() + functions = set() + for extension in req.user_context.extensions: + stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement() + extension.Unpack(stack_trace_element) + functions.add(stack_trace_element.method_name) + files.add(stack_trace_element.file_name) + self.assertIsInstance(stack_trace_element.method_name, str) + self.assertIsInstance(stack_trace_element.file_name, str) + + self.assertTrue("test_execute_plan_request_includes_call_stack_with_env_var" in functions, f"Expected user function names not found in: {functions}") + self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}") + + def test_analyze_plan_request_includes_call_stack_without_env_var(self): + """Test that _analyze_plan_request_with_metadata doesn't include call stack without env var.""" + with patch.dict(os.environ, {}, clear=False): + if "SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK" in os.environ: + del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"] + + req = self.client._analyze_plan_request_with_metadata() + + # Should have no extensions when env var is not set + self.assertEqual( + len(req.user_context.extensions), + 0, + "Expected no extensions without env var", + ) + + def test_analyze_plan_request_includes_call_stack_with_env_var(self): + """Test that _analyze_plan_request_with_metadata includes call stack with env var.""" + with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}): + req = self.client._analyze_plan_request_with_metadata() + + # Should have extensions when env var is set + self.assertGreater( + len(req.user_context.extensions), + 0, + "Expected extensions with env var set", + ) + + # Verify each extension can be unpacked as StackTraceElement + files = set() + functions = set() + for extension in req.user_context.extensions: + stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement() + extension.Unpack(stack_trace_element) + functions.add(stack_trace_element.method_name) + files.add(stack_trace_element.file_name) + self.assertIsInstance(stack_trace_element.method_name, str) + self.assertIsInstance(stack_trace_element.file_name, str) + + self.assertTrue("test_analyze_plan_request_includes_call_stack_with_env_var" 
in functions, f"Expected user function names not found in: {functions}") + self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}") + + def test_config_request_includes_call_stack_without_env_var(self): + """Test that _config_request_with_metadata doesn't include call stack without env var.""" + with patch.dict(os.environ, {}, clear=False): + if "SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK" in os.environ: + del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"] + + req = self.client._config_request_with_metadata() + + # Should have no extensions when env var is not set + self.assertEqual( + len(req.user_context.extensions), + 0, + "Expected no extensions without env var", + ) + + def test_config_request_includes_call_stack_with_env_var(self): + """Test that _config_request_with_metadata includes call stack with env var.""" + with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}): + req = self.client._config_request_with_metadata() + + # Should have extensions when env var is set + self.assertGreater( + len(req.user_context.extensions), + 0, + "Expected extensions with env var set", + ) + + # Verify each extension can be unpacked as StackTraceElement + files = set() + functions = set() + for extension in req.user_context.extensions: + stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement() + extension.Unpack(stack_trace_element) + functions.add(stack_trace_element.method_name) + files.add(stack_trace_element.file_name) + self.assertIsInstance(stack_trace_element.method_name, str) + self.assertIsInstance(stack_trace_element.file_name, str) + + self.assertTrue("test_config_request_includes_call_stack_with_env_var" in functions, f"Expected user function names not found in: {functions}") + self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}") + + def test_call_stack_trace_captures_correct_calling_context(self): + """Test that call stack trace captures the correct calling context.""" + + def level3(): + """Third level function.""" + with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}): + req = self.client._execute_plan_request_with_metadata() + return req + + def level2(): + """Second level function.""" + return level3() + + def level1(): + """First level function.""" + return level2() + + req = level1() + + # Verify we captured frames from our nested functions + self.assertGreater(len(req.user_context.extensions), 0) + + # Unpack and check that we have function names from our call chain + functions = set() + files = set() + for extension in req.user_context.extensions: + stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement() + extension.Unpack(stack_trace_element) + functions.add(stack_trace_element.method_name) + files.add(stack_trace_element.file_name) + self.assertGreater(stack_trace_element.line_number, 0, f"Expected line number to be greater than 0, got: {stack_trace_element.line_number}") + + self.assertTrue("level1" in functions, f"Expected user function names not found in: {functions}") + self.assertTrue("level2" in functions, f"Expected user function names not found in: {functions}") + self.assertTrue("level3" in functions, f"Expected user function names not found in: {functions}") + self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}") + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.client.test_client_call_stack_trace import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + 
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) + From 2dfc9faaeea801c79503f6076f26d68002906304 Mon Sep 17 00:00:00 2001 From: susheel-aroskar Date: Thu, 13 Nov 2025 18:00:11 -0800 Subject: [PATCH 2/6] Add descriptive comment on test, link it to the Jira --- .../sql/tests/connect/client/test_client_call_stack_trace.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py index ff8e8bc0fbae..2ebd47c60d61 100644 --- a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py +++ b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py @@ -41,7 +41,9 @@ # in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test. SparkConnectClient._cleanup_ml_cache = lambda _: None - +# SPARK-54314: Improve Server-Side debuggability in Spark Connect by capturing client application's +# file name and line numbers in PySpark +# https://issues.apache.org/jira/browse/SPARK-54314 @unittest.skipIf(not should_test_connect, connect_requirement_message) class CallStackTraceTestCase(unittest.TestCase): From f0bc7de7ce55dbdcdc095c6450e0e3d849fbb763 Mon Sep 17 00:00:00 2001 From: susheel-aroskar Date: Fri, 14 Nov 2025 14:21:47 -0800 Subject: [PATCH 3/6] Fix linter warnings --- python/pyspark/sql/connect/client/core.py | 6 +- .../client/test_client_call_stack_trace.py | 144 ++++++++++++------ 2 files changed, 99 insertions(+), 51 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 5334f53239e0..49582f94ccd7 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -612,6 +612,7 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult": warnings=list(pb.warnings), ) + def _is_pyspark_source(filename: str) -> bool: """Check if the given filename is from the pyspark package.""" return filename.startswith(PYSPARK_ROOT) @@ -628,12 +629,13 @@ def _retrieve_stack_frames() -> List[CallSite]: filename, lineno, func, _ = frame if _is_pyspark_source(filename): break - if i+1 < len(frames): - _, _, func, _ = frames[i+1] + if i + 1 < len(frames): + _, _, func, _ = frames[i + 1] filtered_stack_frames.append(CallSite(function=func, file=filename, linenum=lineno)) return filtered_stack_frames + def _build_call_stack_trace() -> List[any_pb2.Any]: """ Build a call stack trace for the current Spark Connect action diff --git a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py index 2ebd47c60d61..c06fbdc7b1ac 100644 --- a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py +++ b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py @@ -35,29 +35,29 @@ from pyspark.traceback_utils import CallSite from google.protobuf import any_pb2 - # The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster) # and it blocks the test process exiting because it is registered as the atexit handler # in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test. 
 SparkConnectClient._cleanup_ml_cache = lambda _: None
 
-# SPARK-54314: Improve Server-Side debuggability in Spark Connect by capturing client application's
+# SPARK-54314: Improve Server-Side debuggability in Spark Connect by capturing client application's
 # file name and line numbers in PySpark
 # https://issues.apache.org/jira/browse/SPARK-54314
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class CallStackTraceTestCase(unittest.TestCase):
     """Test cases for call stack trace functionality in Spark Connect client."""
 
     def setUp(self):
-        # Since this test itself is under the pyspark module path, stack frames for test functions inside
-        # this file - for example, user_function() - will normally be filtered out. So here we
-        # set the PYSPARK_ROOT to the more specific pyspark.sql.connect that doesn't include this
-        # test file to ensure that the stack frames for user functions inside this test file are
+        # Since this test itself is under the pyspark module path, stack frames for test functions inside
+        # this file - for example, user_function() - will normally be filtered out. So here we
+        # set the PYSPARK_ROOT to the more specific pyspark.sql.connect that doesn't include this
+        # test file to ensure that the stack frames for user functions inside this test file are
         # not filtered out.
         self.original_pyspark_root = core.PYSPARK_ROOT
         core.PYSPARK_ROOT = os.path.dirname(pyspark.sql.connect.__file__)
-
+
     def tearDown(self):
         # Restore the original PYSPARK_ROOT
         core.PYSPARK_ROOT = self.original_pyspark_root
@@ -93,30 +93,38 @@ def test_is_pyspark_source_with_relative_path(self):
 
     def test_retrieve_stack_frames_filters_pyspark_frames(self):
         """Test that _retrieve_stack_frames filters out PySpark internal frames."""
+
         def user_function():
             return _retrieve_stack_frames()
 
         stack_frames = user_function()
-
+
         # Verify we have at least some frames
         self.assertGreater(len(stack_frames), 0, "Expected at least some stack frames")
-
+
         # Verify that none of the returned frames are from PySpark internal code
         for frame in stack_frames:
             # Check that this frame is not from pyspark internal code
             self.assertFalse(
                 _is_pyspark_source(frame.file),
-                f"Expected frame from {frame.file} (function: {frame.function}) to be filtered out as PySpark internal frame"
+                f"Expected frame from {frame.file} (function: {frame.function}) to be filtered out as PySpark internal frame",
             )
-
+
         # Verify that user function names are present (confirming user frames are included)
         function_names = [frame.function for frame in stack_frames]
         expected_functions = ["user_function", "test_retrieve_stack_frames_filters_pyspark_frames"]
-        self.assertTrue("user_function" in function_names, f"Expected user function names not found in: {function_names}")
-        self.assertTrue("test_retrieve_stack_frames_filters_pyspark_frames" in function_names, f"Expected user function names not found in: {function_names}")
+        self.assertTrue(
+            "user_function" in function_names,
+            f"Expected user function names not found in: {function_names}",
+        )
+        self.assertTrue(
+            "test_retrieve_stack_frames_filters_pyspark_frames" in function_names,
+            f"Expected user function names not found in: {function_names}",
+        )
 
     def test_retrieve_stack_frames_includes_user_frames(self):
         """Test that _retrieve_stack_frames includes user code frames."""
+
        def user_function():
             """Simulate a user function."""
             return _retrieve_stack_frames()
@@ -124,7 +132,7 @@ def user_function():
         def another_user_function():
             """Another level of user code."""
             return user_function()
-
+
         stack_frames = 
another_user_function() # We should have at least some frames from the test @@ -133,12 +141,22 @@ def another_user_function(): # Check that we have frames with function names we expect function_names = [frame.function for frame in stack_frames] # At least one of our test functions should be in the stack - self.assertTrue("user_function" in function_names, f"Expected user function names not found in: {function_names}") - self.assertTrue("another_user_function" in function_names, f"Expected user function names not found in: {function_names}") - self.assertTrue("test_retrieve_stack_frames_includes_user_frames" in function_names, f"Expected user function names not found in: {function_names}") + self.assertTrue( + "user_function" in function_names, + f"Expected user function names not found in: {function_names}", + ) + self.assertTrue( + "another_user_function" in function_names, + f"Expected user function names not found in: {function_names}", + ) + self.assertTrue( + "test_retrieve_stack_frames_includes_user_frames" in function_names, + f"Expected user function names not found in: {function_names}", + ) def test_retrieve_stack_frames_captures_correct_info(self): """Test that _retrieve_stack_frames captures correct frame information.""" + def user_function(): return _retrieve_stack_frames() @@ -157,9 +175,14 @@ def user_function(): self.assertIsInstance(frame.file, str) self.assertIsInstance(frame.linenum, int) self.assertGreater(frame.linenum, 0) - - self.assertTrue("user_function" in functions, f"Expected user function names not found in: {functions}") - self.assertTrue("test_retrieve_stack_frames_captures_correct_info" in functions, f"Expected user function names not found in: {functions}") + + self.assertTrue( + "user_function" in functions, f"Expected user function names not found in: {functions}" + ) + self.assertTrue( + "test_retrieve_stack_frames_captures_correct_info" in functions, + f"Expected user function names not found in: {functions}", + ) self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}") def test_build_call_stack_trace_without_env_var(self): @@ -182,9 +205,7 @@ def test_build_call_stack_trace_with_env_var_set(self): self.assertIsInstance(call_stack, list) # Should have at least one frame (this test function) - self.assertGreater( - len(call_stack), 0, "Expected non-empty list when env var is set" - ) + self.assertGreater(len(call_stack), 0, "Expected non-empty list when env var is set") # Verify each element is an Any protobuf message functions = set() @@ -205,9 +226,14 @@ def test_build_call_stack_trace_with_env_var_set(self): self.assertEqual( stack_trace_element.declaring_class, "", "declaring_class should be empty" ) - - self.assertTrue("test_build_call_stack_trace_with_env_var_set" in functions, f"Expected user function names not found in: {functions}") - self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}") + + self.assertTrue( + "test_build_call_stack_trace_with_env_var_set" in functions, + f"Expected user function names not found in: {functions}", + ) + self.assertTrue( + __file__ in files, f"Expected user function names not found in: {files}" + ) def test_build_call_stack_trace_with_env_var_empty_string(self): """Test that _build_call_stack_trace returns empty list when env var is empty string.""" @@ -215,9 +241,7 @@ def test_build_call_stack_trace_with_env_var_empty_string(self): call_stack = _build_call_stack_trace() self.assertIsInstance(call_stack, list) - self.assertEqual( - 
len(call_stack), 0, "Expected empty list when env var is empty string"
-            )
+            self.assertEqual(len(call_stack), 0, "Expected empty list when env var is empty string")
 
     def test_build_call_stack_trace_with_various_env_var_values(self):
         """Test _build_call_stack_trace behavior with various env var values."""
@@ -232,9 +256,7 @@ def test_build_call_stack_trace_with_various_env_var_values(self):
 
         for env_value, expected_behavior, message in test_cases:
             with self.subTest(env_value=env_value):
-                with patch.dict(
-                    os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": env_value}
-                ):
+                with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": env_value}):
                     call_stack = _build_call_stack_trace()
                     if expected_behavior == 0:
                         self.assertEqual(len(call_stack), 0, message)
@@ -249,14 +271,14 @@ class CallStackTraceIntegrationTestCase(unittest.TestCase):
     def setUp(self):
         """Set up test fixtures."""
         self.client = SparkConnectClient("sc://localhost:15002", use_reattachable_execute=False)
-        # Since this test itself is under the pyspark module path, stack frames for test functions inside
-        # this file - for example, user_function() - will normally be filtered out. So here we
-        # set the PYSPARK_ROOT to the more specific pyspark.sql.connect that doesn't include this
-        # test file to ensure that the stack frames for user functions inside this test file are
+        # Since this test itself is under the pyspark module path, stack frames for test functions inside
+        # this file - for example, user_function() - will normally be filtered out. So here we
+        # set the PYSPARK_ROOT to the more specific pyspark.sql.connect that doesn't include this
+        # test file to ensure that the stack frames for user functions inside this test file are
         # not filtered out.
         self.original_pyspark_root = core.PYSPARK_ROOT
         core.PYSPARK_ROOT = os.path.dirname(pyspark.sql.connect.__file__)
-
+
     def tearDown(self):
         # Restore the original PYSPARK_ROOT
         core.PYSPARK_ROOT = self.original_pyspark_root
@@ -298,9 +320,14 @@ def test_execute_plan_request_includes_call_stack_with_env_var(self):
             files.add(stack_trace_element.file_name)
             self.assertIsInstance(stack_trace_element.method_name, str)
             self.assertIsInstance(stack_trace_element.file_name, str)
-
-        self.assertTrue("test_execute_plan_request_includes_call_stack_with_env_var" in functions, f"Expected user function names not found in: {functions}")
-        self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}")
+
+        self.assertTrue(
+            "test_execute_plan_request_includes_call_stack_with_env_var" in functions,
+            f"Expected user function names not found in: {functions}",
+        )
+        self.assertTrue(
+            __file__ in files, f"Expected user function names not found in: {files}"
+        )
 
     def test_analyze_plan_request_includes_call_stack_without_env_var(self):
         """Test that _analyze_plan_request_with_metadata doesn't include call stack without env var."""
@@ -340,8 +367,13 @@ def test_analyze_plan_request_includes_call_stack_with_env_var(self):
             self.assertIsInstance(stack_trace_element.method_name, str)
             self.assertIsInstance(stack_trace_element.file_name, str)
 
-        self.assertTrue("test_analyze_plan_request_includes_call_stack_with_env_var" in functions, f"Expected user function names not found in: {functions}")
-        self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}")
+        self.assertTrue(
+            "test_analyze_plan_request_includes_call_stack_with_env_var" in functions,
+            f"Expected user function names not found in: {functions}",
+        )
+        self.assertTrue(
+            __file__ in files, f"Expected 
user function names not found in: {files}"
+        )
 
     def test_config_request_includes_call_stack_without_env_var(self):
         """Test that _config_request_with_metadata doesn't include call stack without env var."""
@@ -381,8 +413,13 @@ def test_config_request_includes_call_stack_with_env_var(self):
             self.assertIsInstance(stack_trace_element.method_name, str)
             self.assertIsInstance(stack_trace_element.file_name, str)
 
-        self.assertTrue("test_config_request_includes_call_stack_with_env_var" in functions, f"Expected user function names not found in: {functions}")
-        self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}")
+        self.assertTrue(
+            "test_config_request_includes_call_stack_with_env_var" in functions,
+            f"Expected user function names not found in: {functions}",
+        )
+        self.assertTrue(
+            __file__ in files, f"Expected user function names not found in: {files}"
+        )
 
     def test_call_stack_trace_captures_correct_calling_context(self):
         """Test that call stack trace captures the correct calling context."""
@@ -414,11 +451,21 @@ def level1():
             extension.Unpack(stack_trace_element)
             functions.add(stack_trace_element.method_name)
             files.add(stack_trace_element.file_name)
-            self.assertGreater(stack_trace_element.line_number, 0, f"Expected line number to be greater than 0, got: {stack_trace_element.line_number}")
+            self.assertGreater(
+                stack_trace_element.line_number,
+                0,
+                f"Expected line number to be greater than 0, got: {stack_trace_element.line_number}",
+            )
 
-        self.assertTrue("level1" in functions, f"Expected user function names not found in: {functions}")
-        self.assertTrue("level2" in functions, f"Expected user function names not found in: {functions}")
-        self.assertTrue("level3" in functions, f"Expected user function names not found in: {functions}")
+        self.assertTrue(
+            "level1" in functions, f"Expected user function names not found in: {functions}"
+        )
+        self.assertTrue(
+            "level2" in functions, f"Expected user function names not found in: {functions}"
+        )
+        self.assertTrue(
+            "level3" in functions, f"Expected user function names not found in: {functions}"
+        )
         self.assertTrue(__file__ in files, f"Expected user function names not found in: {files}")
 
 
@@ -432,4 +479,3 @@ def level1():
     except ImportError:
         testRunner = None
     unittest.main(testRunner=testRunner, verbosity=2)
-

From ed5ceb01535e1284a2ae363656012959f06dd94b Mon Sep 17 00:00:00 2001
From: susheel-aroskar
Date: Tue, 18 Nov 2025 17:45:37 -0800
Subject: [PATCH 4/6] Add individual call stack frames to a wrapper error
 object before adding them to user_context.extensions

---
 python/pyspark/sql/connect/client/core.py     | 25 ++---
 .../client/test_client_call_stack_trace.py    | 98 ++++++++++---------
 2 files changed, 64 insertions(+), 59 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 49582f94ccd7..4c3c14d950ea 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -23,6 +23,7 @@
 import atexit
 
 import pyspark
+from pyspark.sql.connect.proto.base_pb2 import FetchErrorDetailsResponse
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -636,26 +637,28 @@ def _retrieve_stack_frames() -> List[CallSite]:
     return filtered_stack_frames
 
 
-def _build_call_stack_trace() -> List[any_pb2.Any]:
+def _build_call_stack_trace() -> Optional[any_pb2.Any]:
     """
     Build a call stack trace for the current Spark Connect action
     Returns
     -------
-    List[any_pb2.Any]: A list of Any objects, each representing a stack 
frame in the call stack trace in the user code. 
+    Optional[any_pb2.Any]: An Any protobuf packing a FetchErrorDetailsResponse.Error that holds the stack frames of the user code, or None if call stack capture is disabled.
     """
-    call_stack_trace = []
     if os.getenv("SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK", "false").lower() in ("true", "1"):
         stack_frames = _retrieve_stack_frames()
-        for i, call_site in enumerate(stack_frames):
+        call_stack = FetchErrorDetailsResponse.Error()
+        for call_site in stack_frames:
             stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement()
             stack_trace_element.declaring_class = ""  # unknown information
             stack_trace_element.method_name = call_site.function
             stack_trace_element.file_name = call_site.file
             stack_trace_element.line_number = call_site.linenum
-            stack_frame = any_pb2.Any()
-            stack_frame.Pack(stack_trace_element)
-            call_stack_trace.append(stack_frame)
-    return call_stack_trace
+            call_stack.stack_trace.append(stack_trace_element)
+        if len(call_stack.stack_trace) > 0:
+            call_stack_details = any_pb2.Any()
+            call_stack_details.Pack(call_stack)
+            return call_stack_details
+    return None
 
 
 class SparkConnectClient(object):
@@ -1386,7 +1389,7 @@ def _execute_plan_request_with_metadata(
 
         call_stack_trace = _build_call_stack_trace()
         if call_stack_trace:
-            req.user_context.extensions.extend(call_stack_trace)
+            req.user_context.extensions.append(call_stack_trace)
         return req
 
     def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1400,7 +1403,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
         self._update_request_with_user_context_extensions(req)
         call_stack_trace = _build_call_stack_trace()
         if call_stack_trace:
-            req.user_context.extensions.extend(call_stack_trace)
+            req.user_context.extensions.append(call_stack_trace)
         return req
 
     def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
@@ -1818,7 +1821,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
         self._update_request_with_user_context_extensions(req)
         call_stack_trace = _build_call_stack_trace()
         if call_stack_trace:
-            req.user_context.extensions.extend(call_stack_trace)
+            req.user_context.extensions.append(call_stack_trace)
         return req
 
     def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
diff --git a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py
index c06fbdc7b1ac..dda39346fe2a 100644
--- a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py
+++ b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py
@@ -22,6 +22,7 @@
 
 import pyspark
 from pyspark.sql import functions
+from pyspark.sql.connect.proto.base_pb2 import FetchErrorDetailsResponse
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 
 if should_test_connect:
@@ -193,29 +194,24 @@ def test_build_call_stack_trace_without_env_var(self):
             del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"]
 
         call_stack = _build_call_stack_trace()
-
-        self.assertIsInstance(call_stack, list)
-        self.assertEqual(len(call_stack), 0, "Expected empty list when env var is not set")
+        self.assertIsNone(call_stack, "Expected None when env var is not set")
 
     def test_build_call_stack_trace_with_env_var_set(self):
         """Test that _build_call_stack_trace builds trace when env var is set."""
         # Set the env var to enable call stack tracing
         with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
-            call_stack = _build_call_stack_trace()
-
-            self.assertIsInstance(call_stack, 
list)
+            stack_trace_details = _build_call_stack_trace()
+            self.assertIsNotNone(stack_trace_details, "Expected non-None call stack when env var is set")
+            error = pb2.FetchErrorDetailsResponse.Error()
+            if not stack_trace_details.Unpack(error):
+                self.fail("Expected to unpack stack trace details into Error object")
             # Should have at least one frame (this test function)
-            self.assertGreater(len(call_stack), 0, "Expected non-empty list when env var is set")
+            self.assertGreater(len(error.stack_trace), 0, "Expected > 0 call stack frames when env var is set")
 
             # Verify each element is an Any protobuf message
             functions = set()
             files = set()
-            for stack_frame in call_stack:
-                self.assertIsInstance(stack_frame, any_pb2.Any)
-
-                # Unpack and verify it's a StackTraceElement
-                stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement()
-                stack_frame.Unpack(stack_trace_element)
+            for stack_trace_element in error.stack_trace:
                 functions.add(stack_trace_element.method_name)
                 files.add(stack_trace_element.file_name)
 
@@ -239,9 +235,7 @@ def test_build_call_stack_trace_with_env_var_empty_string(self):
         """Test that _build_call_stack_trace returns empty list when env var is empty string."""
         with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": ""}):
             call_stack = _build_call_stack_trace()
-
-            self.assertIsInstance(call_stack, list)
-            self.assertEqual(len(call_stack), 0, "Expected empty list when env var is empty string")
+            self.assertIsNone(call_stack, "Expected None when env var is empty string")
 
     def test_build_call_stack_trace_with_various_env_var_values(self):
         """Test _build_call_stack_trace behavior with various env var values."""
@@ -259,9 +253,9 @@ def test_build_call_stack_trace_with_various_env_var_values(self):
                 with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": env_value}):
                     call_stack = _build_call_stack_trace()
                     if expected_behavior == 0:
-                        self.assertEqual(len(call_stack), 0, message)
+                        self.assertIsNone(call_stack, message)
                     else:
-                        self.assertGreater(len(call_stack), 0, message)
+                        self.assertIsNotNone(call_stack, message)
 
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
@@ -310,16 +304,18 @@ def test_execute_plan_request_includes_call_stack_with_env_var(self):
             "Expected extensions with env var set",
         )
 
-        # Verify each extension can be unpacked as StackTraceElement
+        # Verify each extension can be unpacked as Error containing StackTraceElements
         files = set()
         functions = set()
         for extension in req.user_context.extensions:
-            stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement()
-            extension.Unpack(stack_trace_element)
-            functions.add(stack_trace_element.method_name)
-            files.add(stack_trace_element.file_name)
-            self.assertIsInstance(stack_trace_element.method_name, str)
-            self.assertIsInstance(stack_trace_element.file_name, str)
+            error = pb2.FetchErrorDetailsResponse.Error()
+            if extension.Unpack(error):
+                # Process stack trace elements within the Error
+                for stack_trace_element in error.stack_trace:
+                    functions.add(stack_trace_element.method_name)
+                    files.add(stack_trace_element.file_name)
+                    self.assertIsInstance(stack_trace_element.method_name, str)
+                    self.assertIsInstance(stack_trace_element.file_name, str)
 
         self.assertTrue(
             "test_execute_plan_request_includes_call_stack_with_env_var" in functions,
@@ -356,16 +352,18 @@ def test_analyze_plan_request_includes_call_stack_with_env_var(self):
             "Expected extensions with env var set",
         )
 
-        # Verify each extension can be unpacked as StackTraceElement
+        # 
Verify each extension can be unpacked as Error containing StackTraceElements files = set() functions = set() for extension in req.user_context.extensions: - stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement() - extension.Unpack(stack_trace_element) - functions.add(stack_trace_element.method_name) - files.add(stack_trace_element.file_name) - self.assertIsInstance(stack_trace_element.method_name, str) - self.assertIsInstance(stack_trace_element.file_name, str) + error = pb2.FetchErrorDetailsResponse.Error() + if extension.Unpack(error): + # Process stack trace elements within the Error + for stack_trace_element in error.stack_trace: + functions.add(stack_trace_element.method_name) + files.add(stack_trace_element.file_name) + self.assertIsInstance(stack_trace_element.method_name, str) + self.assertIsInstance(stack_trace_element.file_name, str) self.assertTrue( "test_analyze_plan_request_includes_call_stack_with_env_var" in functions, @@ -402,16 +400,18 @@ def test_config_request_includes_call_stack_with_env_var(self): "Expected extensions with env var set", ) - # Verify each extension can be unpacked as StackTraceElement + # Verify each extension can be unpacked as Error containing StackTraceElements files = set() functions = set() for extension in req.user_context.extensions: - stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement() - extension.Unpack(stack_trace_element) - functions.add(stack_trace_element.method_name) - files.add(stack_trace_element.file_name) - self.assertIsInstance(stack_trace_element.method_name, str) - self.assertIsInstance(stack_trace_element.file_name, str) + error = pb2.FetchErrorDetailsResponse.Error() + if extension.Unpack(error): + # Process stack trace elements within the Error + for stack_trace_element in error.stack_trace: + functions.add(stack_trace_element.method_name) + files.add(stack_trace_element.file_name) + self.assertIsInstance(stack_trace_element.method_name, str) + self.assertIsInstance(stack_trace_element.file_name, str) self.assertTrue( "test_config_request_includes_call_stack_with_env_var" in functions, @@ -447,15 +447,17 @@ def level1(): functions = set() files = set() for extension in req.user_context.extensions: - stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement() - extension.Unpack(stack_trace_element) - functions.add(stack_trace_element.method_name) - files.add(stack_trace_element.file_name) - self.assertGreater( - stack_trace_element.line_number, - 0, - f"Expected line number to be greater than 0, got: {stack_trace_element.line_number}", - ) + error = pb2.FetchErrorDetailsResponse.Error() + if extension.Unpack(error): + # Process stack trace elements within the Error + for stack_trace_element in error.stack_trace: + functions.add(stack_trace_element.method_name) + files.add(stack_trace_element.file_name) + self.assertGreater( + stack_trace_element.line_number, + 0, + f"Expected line number to be greater than 0, got: {stack_trace_element.line_number}", + ) self.assertTrue( "level1" in functions, f"Expected user function names not found in: {functions}" From 5db628d198b938e8565ad1febe7b21ff25d1f5df Mon Sep 17 00:00:00 2001 From: susheel-aroskar Date: Tue, 18 Nov 2025 17:54:08 -0800 Subject: [PATCH 5/6] Mollify linter --- python/pyspark/sql/connect/client/core.py | 2 -- .../tests/connect/client/test_client_call_stack_trace.py | 8 ++++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py 
index 4c3c14d950ea..d74b147dae60 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1335,7 +1335,6 @@ def token(self) -> Optional[str]:
         """
         return self._builder.token
 
-
     def _update_request_with_user_context_extensions(
         self,
         req: Union[
@@ -1354,7 +1353,6 @@ def _update_request_with_user_context_extensions(
             for _, extension in self.thread_local.user_context_extensions:
                 req.user_context.extensions.append(extension)
 
-
     def _execute_plan_request_with_metadata(
         self, operation_id: Optional[str] = None
     ) -> pb2.ExecutePlanRequest:
diff --git a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py
index dda39346fe2a..8c5fd1adb862 100644
--- a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py
+++ b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py
@@ -201,12 +201,16 @@ def test_build_call_stack_trace_with_env_var_set(self):
         # Set the env var to enable call stack tracing
         with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
             stack_trace_details = _build_call_stack_trace()
-            self.assertIsNotNone(stack_trace_details, "Expected non-None call stack when env var is set")
+            self.assertIsNotNone(
+                stack_trace_details, "Expected non-None call stack when env var is set"
+            )
             error = pb2.FetchErrorDetailsResponse.Error()
             if not stack_trace_details.Unpack(error):
                 self.fail("Expected to unpack stack trace details into Error object")
             # Should have at least one frame (this test function)
-            self.assertGreater(len(error.stack_trace), 0, "Expected > 0 call stack frames when env var is set")
+            self.assertGreater(
+                len(error.stack_trace), 0, "Expected > 0 call stack frames when env var is set"
+            )
 
             # Verify each element is an Any protobuf message

From 69a5e0823b94c30b781421b78b82619a616f6fab Mon Sep 17 00:00:00 2001
From: susheel-aroskar
Date: Tue, 18 Nov 2025 18:02:33 -0800
Subject: [PATCH 6/6] Incorporate code review comments

---
 python/pyspark/sql/connect/client/core.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index d74b147dae60..9da3d3a06922 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -629,6 +629,7 @@ def _retrieve_stack_frames() -> List[CallSite]:
     for i, frame in enumerate(frames):
         filename, lineno, func, _ = frame
         if _is_pyspark_source(filename):
+            # Do not include PySpark internal frames as they are not user application code
             break
         if i + 1 < len(frames):
             _, _, func, _ = frames[i + 1]
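
Usage sketch (not part of the patch series): a minimal, illustrative example of the
feature these six patches add, written only against the APIs shown above. The
sc://localhost:15002 endpoint and the _cleanup_ml_cache stub mirror the tests; the
request is merely constructed locally, so no live server is contacted, and
_execute_plan_request_with_metadata is a private method used here purely for
illustration.

import os

import pyspark.sql.connect.proto as pb2
from pyspark.sql.connect.client import SparkConnectClient

# As in the tests above, stub the ML-cache cleanup so the interpreter can exit
# cleanly when no live cluster is reachable.
SparkConnectClient._cleanup_ml_cache = lambda _: None

# Opt in to call stack capture; it is off by default and read per request.
os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"] = "1"

client = SparkConnectClient("sc://localhost:15002", use_reattachable_execute=False)
req = client._execute_plan_request_with_metadata()  # built locally, nothing is sent

for extension in req.user_context.extensions:
    error = pb2.FetchErrorDetailsResponse.Error()
    if extension.Unpack(error):  # Unpack() returns False for unrelated extension types
        for element in error.stack_trace:
            # declaring_class is intentionally left empty by the Python client.
            print(f"{element.file_name}:{element.line_number} in {element.method_name}()")

After PATCH 4, the whole user-code call stack travels as a single Any extension wrapping
a FetchErrorDetailsResponse.Error, so a server can unpack it with the same two-step
Unpack-then-iterate pattern used here and in the integration tests.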