66 changes: 66 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -22,6 +22,8 @@

import atexit

import pyspark
from pyspark.sql.connect.proto.base_pb2 import FetchErrorDetailsResponse
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)
@@ -35,6 +37,7 @@
import uuid
import sys
import time
import traceback
from typing import (
    Iterable,
    Iterator,
@@ -65,6 +68,8 @@
from pyspark.util import is_remote_only
from pyspark.accumulators import SpecialAccumulatorIds
from pyspark.version import __version__
from pyspark import traceback_utils
from pyspark.traceback_utils import CallSite
from pyspark.resource.information import ResourceInformation
from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics
from pyspark.sql.connect.client.artifact import ArtifactManager
@@ -114,6 +119,9 @@
from pyspark.sql.datasource import DataSource


PYSPARK_ROOT = os.path.dirname(pyspark.__file__)


def _import_zstandard_if_available() -> Optional[Any]:
    """
    Import zstandard if available, otherwise return None.
@@ -606,6 +614,54 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
)


def _is_pyspark_source(filename: str) -> bool:
    """Check if the given filename is from the pyspark package."""
    return filename.startswith(PYSPARK_ROOT)


def _retrieve_stack_frames() -> List[CallSite]:
    """
    Return a list of CallSites representing the relevant stack frames in the call stack.
    """
    frames = traceback.extract_stack()

    filtered_stack_frames = []
    for i, frame in enumerate(frames):
        filename, lineno, func, _ = frame
        if _is_pyspark_source(filename):
            # Do not include PySpark internal frames as they are not user application code.
            break
        if i + 1 < len(frames):
            # Report the function entered at this call site (defined one frame deeper),
            # rather than the enclosing function in which the call appears.
            _, _, func, _ = frames[i + 1]
        filtered_stack_frames.append(CallSite(function=func, file=filename, linenum=lineno))

    return filtered_stack_frames
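For intuition, here is a minimal standalone sketch (not part of this diff) of the same frame-filtering idea, using only the standard library. The names CallSite, LIB_ROOT, and user_stack_frames are illustrative stand-ins for pyspark.traceback_utils.CallSite, PYSPARK_ROOT, and _retrieve_stack_frames:

import os
import traceback
from collections import namedtuple
from typing import List

# Stand-in for pyspark.traceback_utils.CallSite used by the real helper.
CallSite = namedtuple("CallSite", "function file linenum")

# Hypothetical "library" root; the real code uses PYSPARK_ROOT instead.
LIB_ROOT = os.path.dirname(os.__file__)


def user_stack_frames() -> List[CallSite]:
    frames = traceback.extract_stack()
    result = []
    for i, frame in enumerate(frames):
        filename, lineno, func, _ = frame
        if filename.startswith(LIB_ROOT):
            # Stop at the first library-internal frame; frames above it are user code.
            break
        if i + 1 < len(frames):
            # Report the function entered at this call site (one frame deeper).
            _, _, func, _ = frames[i + 1]
        result.append(CallSite(function=func, file=filename, linenum=lineno))
    return result


if __name__ == "__main__":
    for site in user_stack_frames():
        print(f"{site.file}:{site.linenum} in {site.function}")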


def _build_call_stack_trace() -> Optional[any_pb2.Any]:
    """
    Build a call stack trace for the current Spark Connect action.

    Returns
    -------
    Optional[any_pb2.Any]
        A FetchErrorDetailsResponse.Error containing the stack frames of the user code,
        packed as an Any protobuf, or None if no call stack should be attached.
    """
    if os.getenv("SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK", "false").lower() in ("true", "1"):
Contributor:
Why use a system env variable instead of a Spark configuration flag? Also, if we're adding a new configuration option we should probably document it somewhere (if it's a Spark conf flag we have the doc in-line, sort of, already).

Author:
I feel Spark configuration flags are mainly for things that affect Spark's - i.e. the Spark server's - behavior. This only affects the client's behavior. Also, if I set it in Spark conf then I'll have to make a network call (spark.conf.get()) every time to decide whether to include the client code call stack or not.
Anyway, that was my thinking, but I'm open to changing it to Spark conf if that's the convention / best practice.
        stack_frames = _retrieve_stack_frames()
        call_stack = FetchErrorDetailsResponse.Error()
        for call_site in stack_frames:
            stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement()
            stack_trace_element.declaring_class = ""  # unknown information
            stack_trace_element.method_name = call_site.function
            stack_trace_element.file_name = call_site.file
            stack_trace_element.line_number = call_site.linenum
            call_stack.stack_trace.append(stack_trace_element)
        if len(call_stack.stack_trace) > 0:
            call_stack_details = any_pb2.Any()
            call_stack_details.Pack(call_stack)
            return call_stack_details
    return None
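For reference, the packed extension can be recovered with the standard protobuf Any.Unpack API. The sketch below is illustrative only (the server-side handling is not part of this diff), and the helper name unpack_call_stack is hypothetical:

from typing import Optional

from google.protobuf import any_pb2

from pyspark.sql.connect.proto.base_pb2 import FetchErrorDetailsResponse


def unpack_call_stack(extension: any_pb2.Any) -> Optional[FetchErrorDetailsResponse.Error]:
    """Return the packed client call stack if the extension carries one, else None."""
    if extension.Is(FetchErrorDetailsResponse.Error.DESCRIPTOR):
        error = FetchErrorDetailsResponse.Error()
        extension.Unpack(error)
        return error
    return None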


class SparkConnectClient(object):
"""
Conceptually the remote spark session that communicates with the server
@@ -1329,6 +1385,10 @@ def _execute_plan_request_with_metadata(
)
req.operation_id = operation_id
        self._update_request_with_user_context_extensions(req)

        call_stack_trace = _build_call_stack_trace()
        if call_stack_trace:
            req.user_context.extensions.append(call_stack_trace)
        return req
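As a usage sketch (the connection URL and workload are illustrative), enabling the environment-variable gate before issuing actions makes each ExecutePlan, AnalyzePlan, and Config request carry the client call stack in user_context.extensions:

import os

# The gate is read per request via os.getenv, so set it before triggering actions.
os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"] = "true"

from pyspark.sql import SparkSession

# Illustrative connection URL; any Spark Connect endpoint works the same way.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
spark.range(10).count()  # this action's request now carries the caller's stack frames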

def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1340,6 +1400,9 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
        if self._user_id:
            req.user_context.user_id = self._user_id
        self._update_request_with_user_context_extensions(req)
        call_stack_trace = _build_call_stack_trace()
        if call_stack_trace:
            req.user_context.extensions.append(call_stack_trace)
        return req

def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
@@ -1755,6 +1818,9 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
        if self._user_id:
            req.user_context.user_id = self._user_id
        self._update_request_with_user_context_extensions(req)
        call_stack_trace = _build_call_stack_trace()
        if call_stack_trace:
            req.user_context.extensions.append(call_stack_trace)
        return req

def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: