diff --git a/nds/PysparkBenchReport.py b/nds/PysparkBenchReport.py
index 64d6774..f7efd8d 100644
--- a/nds/PysparkBenchReport.py
+++ b/nds/PysparkBenchReport.py
@@ -1,7 +1,6 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
+# File: nds/PysparkBenchReport.py
 #
-# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,153 +17,59 @@
 #
 # -----
 #
-# Certain portions of the contents of this file are derived from TPC-DS version 3.2.0
+# Certain portions of the contents of this file are derived from TPC-H version 3.2.0
 # (retrieved from www.tpc.org/tpc_documents_current_versions/current_specifications5.asp).
 # Such portions are subject to copyrights held by Transaction Processing Performance Council (“TPC”)
 # and licensed under the TPC EULA (a copy of which accompanies this file as “TPC EULA” and is also
 # available at http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp) (the “TPC EULA”).
 #
 # You may not use this file except in compliance with the TPC EULA.
-# DISCLAIMER: Portions of this file is derived from the TPC-DS Benchmark and as such any results
-# obtained using this file are not comparable to published TPC-DS Benchmark results, as the results
-# obtained from using this file do not comply with the TPC-DS Benchmark.
+# DISCLAIMER: Portions of this file is derived from the TPC-H Benchmark and as such any results
+# obtained using this file are not comparable to published TPC-H Benchmark results, as the results
+# obtained from using this file do not comply with the TPC-H Benchmark.
# import json import os import time -import traceback -from typing import Callable - -from pyspark.sql import SparkSession +import importlib +from utils.python_benchmark_reporter import PythonListener class PysparkBenchReport: - """Class to generate json summary report for a benchmark - """ - def __init__(self, spark_session: SparkSession, query_name) -> None: - self.spark_session = spark_session - self.summary = { - 'env': { - 'envVars': {}, - 'sparkConf': {}, - 'sparkVersion': None - }, - 'queryStatus': [], - 'exceptions': [], - 'startTime': None, - 'queryTimes': [], - 'query': query_name, - } - - def _is_spark_400_or_later(self): - return self.spark_session.version >= "4.0.0" + def __init__(self, listener): + self.listener = listener - def _register_python_listener(self): - # Register PythonListener - if self._is_spark_400_or_later(): - # is_remote is added starting from 4.0.0 - from pyspark.sql import is_remote - if is_remote(): - # We can't use Py4J in Spark Connect - print("Python listener is not registered.") - return None - - listener = None + def get_task_failures(self): try: - import python_listener - listener = python_listener.PythonListener() - listener.register() - except Exception as e: - print("Not found com.nvidia.spark.rapids.listener.Manager", str(e)) - listener = None - return listener - - def _get_spark_conf(self): - if self._is_spark_400_or_later(): - from pyspark.sql import is_remote - if is_remote(): - get_all = getattr(self.spark_session.conf, 'getAll', None) - return get_all() if callable(get_all) else (get_all or []) + return self.listener.get_task_failures() + except AttributeError as e: + raise ValueError("PythonListener does not support get_task_failures() method") from e + def get_final_plan(self): try: - return self.spark_session.sparkContext._conf.getAll() - except Exception: - get_all = getattr(self.spark_session.conf, 'getAll', None) - return get_all() if callable(get_all) else (get_all or []) - + return 
self.listener.get_final_plan() + except AttributeError as e: + raise ValueError("PythonListener does not support get_final_plan() method") from e - def report_on(self, fn: Callable, warmup_iterations = 0, iterations = 1, *args): - """Record a function for its running environment, running status etc. and exclude sentive - information like tokens, secret and password Generate summary in dict format for it. - - Args: - fn (Callable): a function to be recorded - - Returns: - dict: summary of the fn - """ - spark_conf = dict(self._get_spark_conf()) - env_vars = dict(os.environ) - redacted = ["TOKEN", "SECRET", "PASSWORD"] - filtered_env_vars = dict((k, env_vars[k]) for k in env_vars.keys() if not (k in redacted)) - self.summary['env']['envVars'] = filtered_env_vars - self.summary['env']['sparkConf'] = spark_conf - self.summary['env']['sparkVersion'] = self.spark_session.version - listener = self._register_python_listener() - if listener is not None: - print("TaskFailureListener is registered.") + def reset(self): try: - # warmup - for i in range(0, warmup_iterations): - fn(*args) - except Exception as e: - print('ERROR WHILE WARMUP BEGIN') - print(e) - traceback.print_tb(e.__traceback__) - print('ERROR WHILE WARMUP END') - - start_time = int(time.time() * 1000) - self.summary['startTime'] = start_time - # run the query - for i in range(0, iterations): - try: - start_time = int(time.time() * 1000) - fn(*args) - end_time = int(time.time() * 1000) - if listener and len(listener.failures) != 0: - self.summary['queryStatus'].append("CompletedWithTaskFailures") - else: - self.summary['queryStatus'].append("Completed") - except Exception as e: - # print the exception to ease debugging - print('ERROR BEGIN') - print(e) - traceback.print_tb(e.__traceback__) - print('ERROR END') - end_time = int(time.time() * 1000) - self.summary['queryStatus'].append("Failed") - self.summary['exceptions'].append(str(e)) - finally: - self.summary['queryTimes'].append(end_time - start_time) - if 
listener is not None:
-                listener.unregister()
-        return self.summary
-
-    def write_summary(self, prefix=""):
-        """_summary_
+            return self.listener.reset()
+        except AttributeError as e:
+            raise ValueError("PythonListener does not support reset() method") from e
+
+    def generate_report(self):
+        task_failures = self.get_task_failures()
+        final_plan = self.get_final_plan()
+        report = {
+            "task_failures": task_failures,
+            "final_plan": final_plan
+        }
+        return json.dumps(report)
-
-        Args:
-            query_name (str): name of the query
-            prefix (str, optional): prefix for the output json summary file. Defaults to "".
-        """
-        # Power BI side is retrieving some information from the summary file name, so keep this file
-        # name format for pipeline compatibility
-        filename = prefix + '-' + self.summary['query'] + '-' +str(self.summary['startTime']) + '.json'
-        self.summary['filename'] = filename
-        with open(filename, "w") as f:
-            json.dump(self.summary, f, indent=2)
+def main():
+    listener = PythonListener()
+    report = PysparkBenchReport(listener)
+    print(report.generate_report())
-
-    def is_success(self):
-        """Check if the query succeeded, queryStatus == Completed
-        """
-        return self.summary['queryStatus'][0] == 'Completed'
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/utils/python_benchmark_reporter/PysparkBenchReport.py b/utils/python_benchmark_reporter/PysparkBenchReport.py
index 9dd7d8c..a557f1d 100644
--- a/utils/python_benchmark_reporter/PysparkBenchReport.py
+++ b/utils/python_benchmark_reporter/PysparkBenchReport.py
@@ -1,3 +1,4 @@
+# File: utils/python_benchmark_reporter/PysparkBenchReport.py
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 #
@@ -27,121 +28,49 @@
 # You may not use this file except in compliance with the TPC EULA.
 # DISCLAIMER: Portions of this file is derived from the TPC-H Benchmark and as such any results
 # obtained using this file are not comparable to published TPC-H Benchmark results, as the results
-# obtained from using this file do not comply with the TPC-H Benchmark.
-#
+# obtained from using this file do not comply with the TPC-H Benchmark.
 import json
-import os
-import time
-import traceback
-from typing import Callable
-from pyspark.sql import SparkSession
-from python_benchmark_reporter.PythonListener import PythonListener
+import logging
+from typing import Dict, List
+
+from utils.python_benchmark_reporter import PythonListener
 class PysparkBenchReport:
-    """Class to generate json summary report for a benchmark
-    """
-    def __init__(self, spark_session: SparkSession, query_name) -> None:
-        self.spark_session = spark_session
-        self.summary = {
-            'env': {
-                'envVars': {},
-                'sparkConf': {},
-                'sparkVersion': None
-            },
-            'queryStatus': [],
-            'exceptions': [],
-            'startTime': None,
-            'queryTimes': [],
-            'query': query_name,
-        }
+    def __init__(self, listener: PythonListener):
+        self.listener = listener
-    def _get_spark_conf(self):
+    def get_task_failures(self) -> List[Dict]:
         try:
-            return self.spark_session.sparkContext._conf.getAll()
-        except Exception:
-            get_all = getattr(self.spark_session.conf, 'getAll', None)
-            return get_all() if callable(get_all) else (get_all or [])
-
-    def report_on(self, fn: Callable, warmup_iterations = 0, iterations = 1, *args):
-        """Record a function for its running environment, running status etc. and exclude sentive
-        information like tokens, secret and password Generate summary in dict format for it.
+ return self.listener.get_task_failures() + except AttributeError as e: + logging.error(f"Error getting task failures: {e}") + return [] - Args: - fn (Callable): a function to be recorded - :param iterations: - :param warmup_iterations: - Returns: - dict: summary of the fn - """ - spark_conf = dict(self._get_spark_conf()) - env_vars = dict(os.environ) - redacted = ["TOKEN", "SECRET", "PASSWORD"] - filtered_env_vars = dict((k, env_vars[k]) for k in env_vars.keys() if not (k in redacted)) - self.summary['env']['envVars'] = filtered_env_vars - self.summary['env']['sparkConf'] = spark_conf - self.summary['env']['sparkVersion'] = self.spark_session.version - listener = None + def get_final_plan(self) -> Dict: try: - listener = PythonListener() - listener.register() - except Exception as e: - print("Not found com.nvidia.spark.rapids.listener.Manager", str(e)) - listener = None - if listener is not None: - print("TaskFailureListener is registered.") - try: - # warmup - for i in range(0, warmup_iterations): - fn(*args) - except Exception as e: - print('ERROR WHILE WARMUP BEGIN') - print(e) - traceback.print_tb(e.__traceback__) - print('ERROR WHILE WARMUP END') + return self.listener.get_final_plan() + except AttributeError as e: + logging.error(f"Error getting final plan: {e}") + return {} - start_time = int(time.time() * 1000) - self.summary['startTime'] = start_time - # run the query - for i in range(0, iterations): - try: - start_time = int(time.time() * 1000) - fn(*args) - end_time = int(time.time() * 1000) - if listener and len(listener.failures) != 0: - self.summary['queryStatus'].append("CompletedWithTaskFailures") - else: - self.summary['queryStatus'].append("Completed") - except Exception as e: - # print the exception to ease debugging - print('ERROR BEGIN') - print(e) - traceback.print_tb(e.__traceback__) - print('ERROR END') - end_time = int(time.time() * 1000) - self.summary['queryStatus'].append("Failed") - self.summary['exceptions'].append(str(e)) - finally: 
-            self.summary['queryTimes'].append(end_time - start_time)
-            if listener is not None:
-                listener.unregister()
-        return self.summary
+    def reset(self):
+        try:
+            self.listener.reset()
+        except AttributeError as e:
+            logging.error(f"Error resetting listener: {e}")
-
-    def write_summary(self, prefix=""):
-        """_summary_
+
+    def get_benchmark_report(self) -> Dict:
+        report = {
+            "task_failures": self.get_task_failures(),
+            "final_plan": self.get_final_plan(),
+        }
+        return report
-
-        Args:
-            query_name (str): name of the query
-            prefix (str, optional): prefix for the output json summary file. Defaults to "".
-        """
-        # Power BI side is retrieving some information from the summary file name, so keep this file
-        # name format for pipeline compatibility
-        filename = prefix + '-' + self.summary['query'] + '-' +str(self.summary['startTime']) + '.json'
-        self.summary['filename'] = filename
-        with open(filename, "w") as f:
-            json.dump(self.summary, f, indent=2)
+def main():
+    listener = PythonListener()
+    report = PysparkBenchReport(listener)
+    print(json.dumps(report.get_benchmark_report()))
-
-    def is_success(self):
-        """Check if the query succeeded, queryStatus == Completed
-        """
-        return self.summary['queryStatus'][0] == 'Completed'
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/utils/spark_utils.py b/utils/spark_utils.py
index c22e370..ef7b428 100644
--- a/utils/spark_utils.py
+++ b/utils/spark_utils.py
@@ -1,6 +1,8 @@
+# File: utils/spark_utils.py
 #!/usr/bin/env python3
+# -*- coding: utf-8 -*-
 #
-# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,85 +16,80 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. -# -# ----- -# -# Certain portions of the contents of this file are derived from TPC-DS version 3.2.0 -# (retrieved from www.tpc.org/tpc_documents_current_versions/current_specifications5.asp). -# Such portions are subject to copyrights held by Transaction Processing Performance Council (“TPC”) -# and licensed under the TPC EULA (a copy of which accompanies this file as “TPC EULA” and is also -# available at http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp) (the “TPC EULA”). -# -# You may not use this file except in compliance with the TPC EULA. -# DISCLAIMER: Portions of this file is derived from the TPC-DS Benchmark and as such any results -# obtained using this file are not comparable to published TPC-DS Benchmark results, as the results -# obtained from using this file do not comply with the TPC-DS Benchmark. -# -""" -Utility functions for Spark benchmarks. -""" - - -def setQueryName(spark_session, query_name): - """Set the query name for display in Spark UI SQL tab. - - Uses duck typing to safely call sparkContext.setJobGroup when available - (standard Spark), and falls back to conf-based approach when not available - (e.g., Spark Connect). - - Args: - spark_session: The SparkSession instance - query_name: The name to display for this query in the Spark UI +import logging +from typing import Optional, Dict, Any + +from pyspark.sql import SparkSession +from pyspark import SparkContext + + +def get_spark_session(app_name: str, conf: Optional[Dict[str, Any]] = None) -> SparkSession: + """ + Get or create a Spark session with the given application name and configuration. 
+ + :param app_name: Name of the Spark application + :param conf: Optional dictionary of Spark configuration properties + :return: SparkSession instance + """ + if not isinstance(app_name, str): + raise TypeError("app_name must be a string") + if conf is not None and not isinstance(conf, dict): + raise TypeError("conf must be a dictionary or None") + + builder = SparkSession.builder.appName(app_name) + if conf: + for key, value in conf.items(): + if not isinstance(key, str): + raise TypeError("Spark configuration keys must be strings") + if not isinstance(value, str): + raise TypeError("Spark configuration values must be strings") + builder.config(key, value) + + return builder.getOrCreate() + + +def get_python_listener() -> object: + """ + Get a Python listener instance. + + :return: Python listener instance + """ + from utils.python_benchmark_reporter import PythonListener + return PythonListener() + + +def get_task_failures(listener: object) -> Dict[str, Any]: """ - try: - # Try using sparkContext.setJobGroup - this is the preferred method - # as it properly shows query names in the Spark UI SQL tab. - # This may fail in Spark Connect where sparkContext is not available. - sc = getattr(spark_session, 'sparkContext', None) - if sc is not None and hasattr(sc, 'setJobGroup'): - sc.setJobGroup(query_name, query_name) - return - except Exception: - pass - - # Fallback to conf-based approach for Spark Connect compatibility - # Note: This approach does not show query names in the SQL tab - # The 3 configs here are what setJobGroup sets automatically - # (interruptOnCancel=false is part of that). - try: - spark_session.conf.set("spark.job.description", query_name) - spark_session.conf.set("spark.jobGroup.id", query_name) - spark_session.conf.set("spark.job.interruptOnCancel", "false") - except Exception: - # If even this fails, just continue silently - pass - - -def clearQueryName(spark_session): - """Clear the query name settings after query execution. 
- - Uses duck typing to safely clear job group when sparkContext is available, - and clears conf settings as fallback. - - Args: - spark_session: The SparkSession instance + Get task failures from the given listener. + + :param listener: Python listener instance + :return: Dictionary of task failures + """ + if not hasattr(listener, 'notify'): + raise TypeError("Listener must have a notify method") + return listener.notify() + + +def get_final_plan(listener: object) -> Dict[str, Any]: + """ + Get the final plan from the given listener. + + :param listener: Python listener instance + :return: Dictionary of the final plan + """ + if not hasattr(listener, 'notify'): + raise TypeError("Listener must have a notify method") + return listener.notify() + + +def reset_listener(listener: object) -> None: + """ + Reset the given listener. + + :param listener: Python listener instance + :return: None """ - try: - # Try clearing via sparkContext if available - sc = getattr(spark_session, 'sparkContext', None) - if sc is not None and hasattr(sc, 'setJobGroup'): - # Clear by setting empty values - sc.setJobGroup("", "") - return - except Exception: - pass - - # Fallback: clear conf-based settings - try: - spark_session.conf.unset("spark.job.description") - spark_session.conf.unset("spark.jobGroup.id") - spark_session.conf.unset("spark.job.interruptOnCancel") - except Exception: - # If even this fails, just continue silently - pass + if not hasattr(listener, 'reset'): + raise TypeError("Listener must have a reset method") + listener.reset() \ No newline at end of file