
Commit 6227fba

antban authored and ueshin committed
[SPARK-54349][PYTHON] Refactor code a bit to simplify faulthandler integration extension
### What changes were proposed in this pull request?

PySpark integrates with faulthandler in many places, and each of them reimplements the same logic for dumping stack traces and recording thread dumps. To reduce the complexity of this integration and make it easier to extend, this refactors the code so that all of these places share the same implementation.

### Why are the changes needed?

Improves developer experience.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

By existing unit tests

### Was this patch authored or co-authored using generative AI tooling?

Closes #53016 from antban/simplify-faulthandler-integration.

Authored-by: antban <dmitry.sorokin@gmail.com>
Signed-off-by: Takuya Ueshin <ueshin@databricks.com>
1 parent 894a7e8 commit 6227fba

10 files changed: +138, -173 lines changed
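Every worker entry point below now imports `with_faulthandler` and `start_faulthandler_periodic_traceback` from `pyspark.util` instead of carrying its own faulthandler setup and teardown. The commit excerpt here only shows the call sites, not the shared helpers themselves; the sketch below is a plausible reconstruction assembled from the inline code being removed from each worker, so treat the bodies (and local names like `wrapper`, `log_dir`, `log_file`) as illustrative rather than the actual implementation in `pyspark/util.py`.

```python
# Minimal sketch, assuming the helpers mirror the inline logic removed below.
import faulthandler
import functools
import os
from typing import Any, Callable


def start_faulthandler_periodic_traceback() -> None:
    """Schedule repeated traceback dumps if the interval env var is set."""
    interval = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
    if interval is not None and int(interval) > 0:
        faulthandler.dump_traceback_later(int(interval), repeat=True)


def with_faulthandler(func: Callable[..., Any]) -> Callable[..., Any]:
    """Enable faulthandler around a worker entry point and clean up afterwards."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        log_dir = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
        log_path = None
        log_file = None
        if log_dir:
            # One log file per worker process, as the old inline code did.
            log_path = os.path.join(log_dir, str(os.getpid()))
            log_file = open(log_path, "w")
            faulthandler.enable(file=log_file)
        try:
            return func(*args, **kwargs)
        finally:
            # Cancel any pending periodic dump and remove the per-process log.
            faulthandler.cancel_dump_traceback_later()
            if log_path:
                faulthandler.disable()
                log_file.close()
                os.remove(log_path)

    return wrapper
```

With helpers along these lines, each worker's `main` only needs the decorator plus one call after `check_python_version`, which is exactly the shape of the diffs that follow.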

python/pyspark/sql/worker/analyze_udtf.py

Lines changed: 8 additions & 19 deletions
@@ -15,7 +15,6 @@
 # limitations under the License.
 #

-import faulthandler
 import inspect
 import os
 import sys
@@ -35,7 +34,12 @@
 from pyspark.sql.functions import OrderingColumn, PartitioningColumn, SelectedColumn
 from pyspark.sql.types import _parse_datatype_json_string, StructType
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+    handle_worker_exception,
+    local_connect_and_auth,
+    with_faulthandler,
+    start_faulthandler_periodic_traceback,
+)
 from pyspark.worker_util import (
     check_python_version,
     read_command,
@@ -100,6 +104,7 @@ def read_arguments(infile: IO) -> Tuple[List[AnalyzeArgument], Dict[str, Analyze
     return args, kwargs


+@with_faulthandler
 def main(infile: IO, outfile: IO) -> None:
     """
     Runs the Python UDTF's `analyze` static method.
@@ -108,18 +113,10 @@ def main(infile: IO, outfile: IO) -> None:
     in JVM and receive the Python UDTF and its arguments for the `analyze` static method,
     and call the `analyze` static method, and send back a AnalyzeResult as a result of the method.
     """
-    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
-    tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
     try:
-        if faulthandler_log_path:
-            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
-            faulthandler_log_file = open(faulthandler_log_path, "w")
-            faulthandler.enable(file=faulthandler_log_file)
-
         check_python_version(infile)

-        if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
-            faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+        start_faulthandler_periodic_traceback()

         memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
         setup_memory_limits(memory_limit_mb)
@@ -266,11 +263,6 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
-    finally:
-        if faulthandler_log_path:
-            faulthandler.disable()
-            faulthandler_log_file.close()
-            os.remove(faulthandler_log_path)

     send_accumulator_updates(outfile)

@@ -282,9 +274,6 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
         write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
         sys.exit(-1)

-    # Force to cancel dump_traceback_later
-    faulthandler.cancel_dump_traceback_later()
-

 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the environment.

python/pyspark/sql/worker/commit_data_source_write.py

Lines changed: 8 additions & 19 deletions
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import faulthandler
 import os
 import sys
 from typing import IO
@@ -29,7 +28,12 @@
     SpecialLengths,
 )
 from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+    handle_worker_exception,
+    local_connect_and_auth,
+    with_faulthandler,
+    start_faulthandler_periodic_traceback,
+)
 from pyspark.worker_util import (
     check_python_version,
     pickleSer,
@@ -40,6 +44,7 @@
 )


+@with_faulthandler
 def main(infile: IO, outfile: IO) -> None:
     """
     Main method for committing or aborting a data source write operation.
@@ -49,18 +54,10 @@ def main(infile: IO, outfile: IO) -> None:
     responsible for invoking either the `commit` or the `abort` method on a data source
     writer instance, given a list of commit messages.
     """
-    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
-    tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
     try:
-        if faulthandler_log_path:
-            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
-            faulthandler_log_file = open(faulthandler_log_path, "w")
-            faulthandler.enable(file=faulthandler_log_file)
-
         check_python_version(infile)

-        if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
-            faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+        start_faulthandler_periodic_traceback()

         memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
         setup_memory_limits(memory_limit_mb)
@@ -106,11 +103,6 @@ def main(infile: IO, outfile: IO) -> None:
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
-    finally:
-        if faulthandler_log_path:
-            faulthandler.disable()
-            faulthandler_log_file.close()
-            os.remove(faulthandler_log_path)

     send_accumulator_updates(outfile)

@@ -122,9 +114,6 @@ def main(infile: IO, outfile: IO) -> None:
         write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
         sys.exit(-1)

-    # Force to cancel dump_traceback_later
-    faulthandler.cancel_dump_traceback_later()
-

 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the environment.

python/pyspark/sql/worker/create_data_source.py

Lines changed: 8 additions & 19 deletions
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import faulthandler
 import inspect
 import os
 import sys
@@ -32,7 +31,12 @@
 )
 from pyspark.sql.datasource import DataSource, CaseInsensitiveDict
 from pyspark.sql.types import _parse_datatype_json_string, StructType
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+    handle_worker_exception,
+    local_connect_and_auth,
+    with_faulthandler,
+    start_faulthandler_periodic_traceback,
+)
 from pyspark.worker_util import (
     check_python_version,
     read_command,
@@ -45,6 +49,7 @@
 )


+@with_faulthandler
 def main(infile: IO, outfile: IO) -> None:
     """
     Main method for creating a Python data source instance.
@@ -62,18 +67,10 @@ def main(infile: IO, outfile: IO) -> None:
     This process then creates a `DataSource` instance using the above information and
     sends the pickled instance as well as the schema back to the JVM.
     """
-    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
-    tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
     try:
-        if faulthandler_log_path:
-            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
-            faulthandler_log_file = open(faulthandler_log_path, "w")
-            faulthandler.enable(file=faulthandler_log_file)
-
         check_python_version(infile)

-        if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
-            faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+        start_faulthandler_periodic_traceback()

         memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
         setup_memory_limits(memory_limit_mb)
@@ -172,11 +169,6 @@ def main(infile: IO, outfile: IO) -> None:
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
-    finally:
-        if faulthandler_log_path:
-            faulthandler.disable()
-            faulthandler_log_file.close()
-            os.remove(faulthandler_log_path)

     send_accumulator_updates(outfile)

@@ -188,9 +180,6 @@ def main(infile: IO, outfile: IO) -> None:
         write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
         sys.exit(-1)

-    # Force to cancel dump_traceback_later
-    faulthandler.cancel_dump_traceback_later()
-

 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the environment.

python/pyspark/sql/worker/data_source_pushdown_filters.py

Lines changed: 8 additions & 19 deletions
@@ -16,7 +16,6 @@
 #

 import base64
-import faulthandler
 import json
 import os
 import sys
@@ -49,7 +48,12 @@
 )
 from pyspark.sql.types import StructType, VariantVal, _parse_datatype_json_string
 from pyspark.sql.worker.plan_data_source_read import write_read_func_and_partitions
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+    handle_worker_exception,
+    local_connect_and_auth,
+    with_faulthandler,
+    start_faulthandler_periodic_traceback,
+)
 from pyspark.worker_util import (
     check_python_version,
     pickleSer,
@@ -119,6 +123,7 @@ def deserializeFilter(jsonDict: dict) -> Filter:
     return filter


+@with_faulthandler
 def main(infile: IO, outfile: IO) -> None:
     """
     Main method for planning a data source read with filter pushdown.
@@ -140,18 +145,10 @@ def main(infile: IO, outfile: IO) -> None:
     on the reader and determines which filters are supported. The indices of the supported
     filters are sent back to the JVM, along with the list of partitions and the read function.
     """
-    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
-    tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
     try:
-        if faulthandler_log_path:
-            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
-            faulthandler_log_file = open(faulthandler_log_path, "w")
-            faulthandler.enable(file=faulthandler_log_file)
-
         check_python_version(infile)

-        if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
-            faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+        start_faulthandler_periodic_traceback()

         memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
         setup_memory_limits(memory_limit_mb)
@@ -258,11 +255,6 @@ def main(infile: IO, outfile: IO) -> None:
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
-    finally:
-        if faulthandler_log_path:
-            faulthandler.disable()
-            faulthandler_log_file.close()
-            os.remove(faulthandler_log_path)

     send_accumulator_updates(outfile)

@@ -274,9 +266,6 @@ def main(infile: IO, outfile: IO) -> None:
         write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
         sys.exit(-1)

-    # Force to cancel dump_traceback_later
-    faulthandler.cancel_dump_traceback_later()
-

 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the environment.

python/pyspark/sql/worker/lookup_data_sources.py

Lines changed: 8 additions & 19 deletions
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import faulthandler
 from importlib import import_module
 from pkgutil import iter_modules
 import os
@@ -29,7 +28,12 @@
     SpecialLengths,
 )
 from pyspark.sql.datasource import DataSource
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+    handle_worker_exception,
+    local_connect_and_auth,
+    with_faulthandler,
+    start_faulthandler_periodic_traceback,
+)
 from pyspark.worker_util import (
     check_python_version,
     pickleSer,
@@ -40,6 +44,7 @@
 )


+@with_faulthandler
 def main(infile: IO, outfile: IO) -> None:
     """
     Main method for looking up the available Python Data Sources in Python path.
@@ -51,18 +56,10 @@ def main(infile: IO, outfile: IO) -> None:
     This is responsible for searching the available Python Data Sources so they can be
     statically registered automatically.
     """
-    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
-    tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
     try:
-        if faulthandler_log_path:
-            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
-            faulthandler_log_file = open(faulthandler_log_path, "w")
-            faulthandler.enable(file=faulthandler_log_file)
-
         check_python_version(infile)

-        if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
-            faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+        start_faulthandler_periodic_traceback()

         memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
         setup_memory_limits(memory_limit_mb)
@@ -89,11 +86,6 @@ def main(infile: IO, outfile: IO) -> None:
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
-    finally:
-        if faulthandler_log_path:
-            faulthandler.disable()
-            faulthandler_log_file.close()
-            os.remove(faulthandler_log_path)

     send_accumulator_updates(outfile)

@@ -105,9 +97,6 @@ def main(infile: IO, outfile: IO) -> None:
         write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
         sys.exit(-1)

-    # Force to cancel dump_traceback_later
-    faulthandler.cancel_dump_traceback_later()
-

 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the environment.
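Since the stated goal is to make the integration easy to extend, a new worker entry point would only need the decorator plus one call. The module below is hypothetical (not part of this commit) and just illustrates the call-site pattern the diffs above converge on.

```python
# Hypothetical new worker module adopting the consolidated pattern.
from typing import IO

from pyspark.util import start_faulthandler_periodic_traceback, with_faulthandler
from pyspark.worker_util import check_python_version


@with_faulthandler
def main(infile: IO, outfile: IO) -> None:
    check_python_version(infile)
    start_faulthandler_periodic_traceback()
    # ... worker-specific logic goes here ...
```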
