diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 48e07642e157..9da3d3a06922 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -22,6 +22,8 @@
 import atexit
+import pyspark
+from pyspark.sql.connect.proto.base_pb2 import FetchErrorDetailsResponse
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
 
@@ -35,6 +37,7 @@
 import uuid
 import sys
 import time
+import traceback
 from typing import (
     Iterable,
     Iterator,
@@ -65,6 +68,8 @@
 from pyspark.util import is_remote_only
 from pyspark.accumulators import SpecialAccumulatorIds
 from pyspark.version import __version__
+from pyspark import traceback_utils
+from pyspark.traceback_utils import CallSite
 from pyspark.resource.information import ResourceInformation
 from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics
 from pyspark.sql.connect.client.artifact import ArtifactManager
@@ -114,6 +119,9 @@
 from pyspark.sql.datasource import DataSource
 
+PYSPARK_ROOT = os.path.dirname(pyspark.__file__)
+
+
 def _import_zstandard_if_available() -> Optional[Any]:
     """
     Import zstandard if available, otherwise return None.
@@ -606,6 +614,54 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
         )
 
 
+def _is_pyspark_source(filename: str) -> bool:
+    """Check if the given filename is from the pyspark package."""
+    return filename.startswith(PYSPARK_ROOT)
+
+
+def _retrieve_stack_frames() -> List[CallSite]:
+    """
+    Return a list of CallSites representing the relevant stack frames in the callstack.
+    """
+    frames = traceback.extract_stack()
+
+    filtered_stack_frames = []
+    for i, frame in enumerate(frames):
+        filename, lineno, func, _ = frame
+        if _is_pyspark_source(filename):
+            # Do not include PySpark internal frames as they are not user application code
+            break
+        if i + 1 < len(frames):
+            _, _, func, _ = frames[i + 1]
+        filtered_stack_frames.append(CallSite(function=func, file=filename, linenum=lineno))
+
+    return filtered_stack_frames
+
+
+def _build_call_stack_trace() -> Optional[any_pb2.Any]:
+    """
+    Build a call stack trace for the current Spark Connect action.
+    Returns
+    -------
+    Optional[any_pb2.Any]: the user-code stack frames packed as an Any protobuf, or None.
+ """ + if os.getenv("SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK", "false").lower() in ("true", "1"): + stack_frames = _retrieve_stack_frames() + call_stack = FetchErrorDetailsResponse.Error() + for call_site in stack_frames: + stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement() + stack_trace_element.declaring_class = "" # unknown information + stack_trace_element.method_name = call_site.function + stack_trace_element.file_name = call_site.file + stack_trace_element.line_number = call_site.linenum + call_stack.stack_trace.append(stack_trace_element) + if len(call_stack.stack_trace) > 0: + call_stack_details = any_pb2.Any() + call_stack_details.Pack(call_stack) + return call_stack_details + return None + + class SparkConnectClient(object): """ Conceptually the remote spark session that communicates with the server @@ -1329,6 +1385,10 @@ def _execute_plan_request_with_metadata( ) req.operation_id = operation_id self._update_request_with_user_context_extensions(req) + + call_stack_trace = _build_call_stack_trace() + if call_stack_trace: + req.user_context.extensions.append(call_stack_trace) return req def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: @@ -1340,6 +1400,9 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: if self._user_id: req.user_context.user_id = self._user_id self._update_request_with_user_context_extensions(req) + call_stack_trace = _build_call_stack_trace() + if call_stack_trace: + req.user_context.extensions.append(call_stack_trace) return req def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: @@ -1755,6 +1818,9 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest: if self._user_id: req.user_context.user_id = self._user_id self._update_request_with_user_context_extensions(req) + call_stack_trace = _build_call_stack_trace() + if call_stack_trace: + req.user_context.extensions.append(call_stack_trace) return req def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: diff --git a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py new file mode 100644 index 000000000000..8c5fd1adb862 --- /dev/null +++ b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py @@ -0,0 +1,487 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import os
+import sys
+import unittest
+from unittest.mock import patch
+
+import pyspark
+from pyspark.sql import functions
+from pyspark.sql.connect.proto.base_pb2 import FetchErrorDetailsResponse
+from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+
+if should_test_connect:
+    import pyspark.sql.connect.proto as pb2
+    from pyspark.sql.connect.client import SparkConnectClient, core
+    from pyspark.sql.connect.client.core import (
+        _is_pyspark_source,
+        _retrieve_stack_frames,
+        _build_call_stack_trace,
+    )
+    from pyspark.traceback_utils import CallSite
+    from google.protobuf import any_pb2
+
+    # The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster)
+    # and it blocks the test process exiting because it is registered as the atexit handler
+    # in the `SparkConnectClient` constructor. To bypass the issue, patch the method in the test.
+    SparkConnectClient._cleanup_ml_cache = lambda _: None
+
+# SPARK-54314: Improve Server-Side debuggability in Spark Connect by capturing client application's
+# file name and line numbers in PySpark
+# https://issues.apache.org/jira/browse/SPARK-54314
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class CallStackTraceTestCase(unittest.TestCase):
+    """Test cases for call stack trace functionality in Spark Connect client."""
+
+    def setUp(self):
+        # Since this test file itself lives under the pyspark module path, stack frames for test
+        # functions defined in this file - for example, user_function() - would normally be
+        # filtered out. So here we point PYSPARK_ROOT at the more specific pyspark.sql.connect
+        # package, which does not include this test file, to ensure that stack frames for user
+        # functions defined in this test file are not filtered out.
+        self.original_pyspark_root = core.PYSPARK_ROOT
+        core.PYSPARK_ROOT = os.path.dirname(pyspark.sql.connect.__file__)
+
+    def tearDown(self):
+        # Restore the original PYSPARK_ROOT
+        core.PYSPARK_ROOT = self.original_pyspark_root
+
+    def test_is_pyspark_source_with_pyspark_file(self):
+        """Test that _is_pyspark_source correctly identifies PySpark files."""
+        # Get a known pyspark file path
+        from pyspark import sql
+
+        pyspark_file = sql.connect.client.__file__
+        self.assertTrue(_is_pyspark_source(pyspark_file))
+
+    def test_is_pyspark_source_with_non_pyspark_file(self):
+        """Test that _is_pyspark_source correctly identifies non-PySpark files."""
+        # Use a path outside of the pyspark installation to simulate user code
+        non_pyspark_file = "/tmp/user_script.py"
+        self.assertFalse(_is_pyspark_source(non_pyspark_file))
+
+        # Test with stdlib file
+        stdlib_file = os.__file__
+        self.assertFalse(_is_pyspark_source(stdlib_file))
+
+    def test_is_pyspark_source_with_relative_path(self):
+        """Test _is_pyspark_source with various path formats."""
+        from pyspark import sql
+
+        # Test with absolute path to pyspark file
+        pyspark_sql_file = sql.connect.client.__file__
+        self.assertTrue(_is_pyspark_source(pyspark_sql_file))
+
+        # Test with non-pyspark absolute path
+        self.assertFalse(_is_pyspark_source("/home/user/my_script.py"))
+
+    def test_retrieve_stack_frames_filters_pyspark_frames(self):
+        """Test that _retrieve_stack_frames filters out PySpark internal frames."""
+
+        def user_function():
+            return _retrieve_stack_frames()
+
+        stack_frames = user_function()
+
+        # Verify we have at least some frames
+        self.assertGreater(len(stack_frames), 0, "Expected at least some stack frames")
+
+        # Verify that none of the returned frames are from PySpark internal code
+        for frame in stack_frames:
+            # Check that this frame is not from pyspark internal code
+            self.assertFalse(
+                _is_pyspark_source(frame.file),
+                f"Expected frame from {frame.file} (function: {frame.function}) to be filtered out as PySpark internal frame",
+            )
+
+        # Verify that user function names are present (confirming user frames are included)
+        function_names = [frame.function for frame in stack_frames]
+        expected_functions = ["user_function", "test_retrieve_stack_frames_filters_pyspark_frames"]
+        self.assertTrue(
+            "user_function" in function_names,
+            f"Expected user function names not found in: {function_names}",
+        )
+        self.assertTrue(
+            "test_retrieve_stack_frames_filters_pyspark_frames" in function_names,
+            f"Expected user function names not found in: {function_names}",
+        )
+
+    def test_retrieve_stack_frames_includes_user_frames(self):
+        """Test that _retrieve_stack_frames includes user code frames."""
+
+        def user_function():
+            """Simulate a user function."""
+            return _retrieve_stack_frames()
+
+        def another_user_function():
+            """Another level of user code."""
+            return user_function()
+
+        stack_frames = another_user_function()
+
+        # We should have at least some frames from the test
+        self.assertGreater(len(stack_frames), 0)
+
+        # Check that we have frames with function names we expect
+        function_names = [frame.function for frame in stack_frames]
+        # All of our helper functions and the test method itself should be in the stack
+        self.assertTrue(
+            "user_function" in function_names,
+            f"Expected user function names not found in: {function_names}",
+        )
+        self.assertTrue(
+            "another_user_function" in function_names,
+            f"Expected user function names not found in: {function_names}",
+        )
+        self.assertTrue(
+            "test_retrieve_stack_frames_includes_user_frames" in function_names,
+            f"Expected user function names not found in: {function_names}",
+        )
+
+    def test_retrieve_stack_frames_captures_correct_info(self):
+        """Test that _retrieve_stack_frames captures correct frame information."""
+
+        def user_function():
+            return _retrieve_stack_frames()
+
+        stack_frames = user_function()
+
+        # Verify each frame has the expected attributes
+        functions = set()
+        files = set()
+        for frame in stack_frames:
+            functions.add(frame.function)
+            files.add(frame.file)
+            self.assertIsNotNone(frame.function)
+            self.assertIsNotNone(frame.file)
+            self.assertIsNotNone(frame.linenum)
+            self.assertIsInstance(frame.function, str)
+            self.assertIsInstance(frame.file, str)
+            self.assertIsInstance(frame.linenum, int)
+            self.assertGreater(frame.linenum, 0)
+
+        self.assertTrue(
+            "user_function" in functions, f"Expected user function names not found in: {functions}"
+        )
+        self.assertTrue(
+            "test_retrieve_stack_frames_captures_correct_info" in functions,
+            f"Expected user function names not found in: {functions}",
+        )
+        self.assertTrue(__file__ in files, f"Expected test file not found in: {files}")
+
+    def test_build_call_stack_trace_without_env_var(self):
+        """Test that _build_call_stack_trace returns None when env var is not set."""
+        # Make sure the env var is not set
+        with patch.dict(os.environ, {}, clear=False):
+            if "SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK" in os.environ:
+                del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"]
+
+            call_stack = _build_call_stack_trace()
+            self.assertIsNone(call_stack, "Expected None when env var is not set")
+
+    def test_build_call_stack_trace_with_env_var_set(self):
+        """Test that _build_call_stack_trace builds trace when env var is set."""
+        # Set the env var to enable call stack tracing
+        with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
+            stack_trace_details = _build_call_stack_trace()
+            self.assertIsNotNone(
+                stack_trace_details, "Expected non-None call stack when env var is set"
+            )
+            error = pb2.FetchErrorDetailsResponse.Error()
+            unpacked = stack_trace_details.Unpack(error)
+            self.assertTrue(unpacked, "Expected to unpack stack trace details into Error object")
+            # Should have at least one frame (this test function)
+            self.assertGreater(
+                len(error.stack_trace), 0, "Expected > 0 call stack frames when env var is set"
+            )
+
+            # Verify each element is a populated StackTraceElement message
+            functions = set()
+            files = set()
+            for stack_trace_element in error.stack_trace:
+                functions.add(stack_trace_element.method_name)
+                files.add(stack_trace_element.file_name)
+
+                # Verify the fields are populated
+                self.assertIsInstance(stack_trace_element.method_name, str)
+                self.assertIsInstance(stack_trace_element.file_name, str)
+                self.assertIsInstance(stack_trace_element.line_number, int)
+                self.assertEqual(
+                    stack_trace_element.declaring_class, "", "declaring_class should be empty"
+                )
+
+            self.assertTrue(
+                "test_build_call_stack_trace_with_env_var_set" in functions,
+                f"Expected user function names not found in: {functions}",
+            )
+            self.assertTrue(
+                __file__ in files, f"Expected test file not found in: {files}"
+            )
+
+    def test_build_call_stack_trace_with_env_var_empty_string(self):
+        """Test that _build_call_stack_trace returns None when env var is an empty string."""
+        with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": ""}):
+            call_stack = _build_call_stack_trace()
+            self.assertIsNone(call_stack, "Expected None when env var is an empty string")
+
+    def test_build_call_stack_trace_with_various_env_var_values(self):
+        """Test _build_call_stack_trace behavior with various env var values."""
+        test_cases = [
+            ("0", 0, "string '0' should leave call stack capture disabled"),
+            ("false", 0, "string 'false' should leave call stack capture disabled"),
+            ("TRUE", 1, "string 'TRUE' should enable call stack capture"),
+            ("true", 1, "string 'true' should enable call stack capture"),
+            ("1", 1, "string '1' should enable call stack capture"),
+            ("any_value", 0, "unrecognized values should leave call stack capture disabled"),
+        ]
+
+        for env_value, expected_behavior, message in test_cases:
+            with self.subTest(env_value=env_value):
+                with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": env_value}):
+                    call_stack = _build_call_stack_trace()
+                    if expected_behavior == 0:
+                        self.assertIsNone(call_stack, message)
+                    else:
+                        self.assertIsNotNone(call_stack, message)
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class CallStackTraceIntegrationTestCase(unittest.TestCase):
+    """Integration tests for call stack trace in client request methods."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.client = SparkConnectClient("sc://localhost:15002", use_reattachable_execute=False)
+        # Since this test file itself lives under the pyspark module path, stack frames for test
+        # functions defined in this file - for example, user_function() - would normally be
+        # filtered out. So here we point PYSPARK_ROOT at the more specific pyspark.sql.connect
+        # package, which does not include this test file, to ensure that stack frames for user
+        # functions defined in this test file are not filtered out.
+        self.original_pyspark_root = core.PYSPARK_ROOT
+        core.PYSPARK_ROOT = os.path.dirname(pyspark.sql.connect.__file__)
+
+    def tearDown(self):
+        # Restore the original PYSPARK_ROOT
+        core.PYSPARK_ROOT = self.original_pyspark_root
+
+    def test_execute_plan_request_includes_call_stack_without_env_var(self):
+        """Test that _execute_plan_request_with_metadata doesn't include call stack without env var."""
+        with patch.dict(os.environ, {}, clear=False):
+            if "SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK" in os.environ:
+                del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"]
+
+            req = self.client._execute_plan_request_with_metadata()
+
+            # Should have no extensions when env var is not set
+            self.assertEqual(
+                len(req.user_context.extensions),
+                0,
+                "Expected no extensions without env var",
+            )
+
+    def test_execute_plan_request_includes_call_stack_with_env_var(self):
+        """Test that _execute_plan_request_with_metadata includes call stack with env var."""
+        with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
+            req = self.client._execute_plan_request_with_metadata()
+
+            # Should have extensions when env var is set
+            self.assertGreater(
+                len(req.user_context.extensions),
+                0,
+                "Expected extensions with env var set",
+            )
+
+            # Verify each extension can be unpacked as Error containing StackTraceElements
+            files = set()
+            functions = set()
+            for extension in req.user_context.extensions:
+                error = pb2.FetchErrorDetailsResponse.Error()
+                if extension.Unpack(error):
+                    # Process stack trace elements within the Error
+                    for stack_trace_element in error.stack_trace:
+                        functions.add(stack_trace_element.method_name)
+                        files.add(stack_trace_element.file_name)
+                        self.assertIsInstance(stack_trace_element.method_name, str)
+                        self.assertIsInstance(stack_trace_element.file_name, str)
+
+            self.assertTrue(
+                "test_execute_plan_request_includes_call_stack_with_env_var" in functions,
+                f"Expected user function names not found in: {functions}",
+            )
+            self.assertTrue(
+                __file__ in files, f"Expected test file not found in: {files}"
+            )
+
+    def test_analyze_plan_request_includes_call_stack_without_env_var(self):
+        """Test that _analyze_plan_request_with_metadata doesn't include call stack without env var."""
+        with patch.dict(os.environ, {}, clear=False):
+            if "SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK" in os.environ:
+                del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"]
+
+            req = self.client._analyze_plan_request_with_metadata()
+
+            # Should have no extensions when env var is not set
+            self.assertEqual(
+                len(req.user_context.extensions),
+                0,
+                "Expected no extensions without env var",
+            )
+
+    def test_analyze_plan_request_includes_call_stack_with_env_var(self):
+        """Test that _analyze_plan_request_with_metadata includes call stack with env var."""
+        with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
+            req = self.client._analyze_plan_request_with_metadata()
+
+            # Should have extensions when env var is set
+            self.assertGreater(
+                len(req.user_context.extensions),
+                0,
+                "Expected extensions with env var set",
+            )
+
+            # Verify each extension can be unpacked as Error containing StackTraceElements
+            files = set()
+            functions = set()
+            for extension in req.user_context.extensions:
+                error = pb2.FetchErrorDetailsResponse.Error()
+                if extension.Unpack(error):
+                    # Process stack trace elements within the Error
+                    for stack_trace_element in error.stack_trace:
+                        functions.add(stack_trace_element.method_name)
+                        files.add(stack_trace_element.file_name)
+                        self.assertIsInstance(stack_trace_element.method_name, str)
+                        self.assertIsInstance(stack_trace_element.file_name, str)
+
+            self.assertTrue(
+                "test_analyze_plan_request_includes_call_stack_with_env_var" in functions,
+                f"Expected user function names not found in: {functions}",
+            )
+            self.assertTrue(
+                __file__ in files, f"Expected test file not found in: {files}"
+            )
+
+    def test_config_request_includes_call_stack_without_env_var(self):
+        """Test that _config_request_with_metadata doesn't include call stack without env var."""
+        with patch.dict(os.environ, {}, clear=False):
+            if "SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK" in os.environ:
+                del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"]
+
+            req = self.client._config_request_with_metadata()
+
+            # Should have no extensions when env var is not set
+            self.assertEqual(
+                len(req.user_context.extensions),
+                0,
+                "Expected no extensions without env var",
+            )
+
+    def test_config_request_includes_call_stack_with_env_var(self):
+        """Test that _config_request_with_metadata includes call stack with env var."""
+        with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
+            req = self.client._config_request_with_metadata()
+
+            # Should have extensions when env var is set
+            self.assertGreater(
+                len(req.user_context.extensions),
+                0,
+                "Expected extensions with env var set",
+            )
+
+            # Verify each extension can be unpacked as Error containing StackTraceElements
+            files = set()
+            functions = set()
+            for extension in req.user_context.extensions:
+                error = pb2.FetchErrorDetailsResponse.Error()
+                if extension.Unpack(error):
+                    # Process stack trace elements within the Error
+                    for stack_trace_element in error.stack_trace:
+                        functions.add(stack_trace_element.method_name)
+                        files.add(stack_trace_element.file_name)
+                        self.assertIsInstance(stack_trace_element.method_name, str)
+                        self.assertIsInstance(stack_trace_element.file_name, str)
+
+            self.assertTrue(
+                "test_config_request_includes_call_stack_with_env_var" in functions,
+                f"Expected user function names not found in: {functions}",
+            )
+            self.assertTrue(
+                __file__ in files, f"Expected test file not found in: {files}"
+            )
+
+    def test_call_stack_trace_captures_correct_calling_context(self):
+        """Test that call stack trace captures the correct calling context."""
+
+        def level3():
+            """Third level function."""
+            with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
+                req = self.client._execute_plan_request_with_metadata()
+                return req
+
+        def level2():
+            """Second level function."""
+            return level3()
+
+        def level1():
+            """First level function."""
+            return level2()
+
+        req = level1()
+
+        # Verify we captured frames from our nested functions
+        self.assertGreater(len(req.user_context.extensions), 0)
+
+        # Unpack and check that we have function names from our call chain
+        functions = set()
+        files = set()
+        for extension in req.user_context.extensions:
+            error = pb2.FetchErrorDetailsResponse.Error()
+            if extension.Unpack(error):
+                # Process stack trace elements within the Error
+                for stack_trace_element in error.stack_trace:
+                    functions.add(stack_trace_element.method_name)
+                    files.add(stack_trace_element.file_name)
+                    self.assertGreater(
+                        stack_trace_element.line_number,
+                        0,
+                        f"Expected line number to be greater than 0, got: {stack_trace_element.line_number}",
+                    )
+
+        self.assertTrue(
+            "level1" in functions, f"Expected user function names not found in: {functions}"
+        )
+        self.assertTrue(
+            "level2" in functions, f"Expected user function names not found in: {functions}"
+        )
+        self.assertTrue(
+            "level3" in functions, f"Expected user function names not found in: {functions}"
+        )
+        self.assertTrue(__file__ in files, f"Expected test file not found in: {files}")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.client.test_client_call_stack_trace import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
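
Reviewer note: a minimal usage sketch of how the new capture is expected to be exercised end to end. The endpoint, the session setup, and the DataFrame code below are illustrative assumptions rather than part of this patch; only the environment variable name and the request fields it populates come from the change above.

    import os

    # Opt in explicitly; capture stays off unless the variable is "true" or "1".
    os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"] = "1"

    from pyspark.sql import SparkSession

    # Assumes a Spark Connect server is reachable at this address.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    # With the flag set, each ExecutePlan/AnalyzePlan/Config request built by the client
    # carries the user-code frames (file name, line number, function name) packed as a
    # FetchErrorDetailsResponse.Error in user_context.extensions, so the server can
    # attribute this action to the calling line of the application.
    spark.range(10).collect()

Gating the capture behind an environment variable keeps the extra stack walk off the request path unless a user explicitly opts in for debugging.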