diff --git a/user_tools/pyproject.toml b/user_tools/pyproject.toml index 8772df3d2..efd7c56aa 100644 --- a/user_tools/pyproject.toml +++ b/user_tools/pyproject.toml @@ -85,6 +85,8 @@ dependencies = [ "scikit-learn==1.7.0", # used for retrieving available memory on the host "psutil==7.0.0", + # used to read zstd-compressed spark event logs + "zstandard==0.25.0", # pyspark for distributed computing "pyspark>=3.5.7,<4.0.0", # Jproperties used to handle Java properties file (added for the Tools API) diff --git a/user_tools/src/spark_rapids_tools/tools/eventlog_detector/README.md b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/README.md new file mode 100644 index 000000000..3a17b61c3 --- /dev/null +++ b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/README.md @@ -0,0 +1,57 @@ +# Event Log Runtime Detector + +This package provides a lightweight Python detector for deciding which full +tools flow should handle a single Spark application event log. + +The detector is an early routing check. It scans a bounded prefix of an event +log, stops as soon as it has enough information, and returns one of: + +- `PROFILING`: a RAPIDS runtime signal was found. +- `QUALIFICATION`: startup properties indicate standard OSS Spark with no + RAPIDS markers. +- `UNKNOWN`: the scan did not reach enough information within the event budget. + +## Detection Flow + +1. `resolver.py` resolves the input into ordered event-log files. + Supported inputs are a single event-log file or an Apache Spark rolling + event-log directory using the `eventlog_v2_*` / `events_*` layout. +2. `stream.py` opens each file through `CspPath.open_input_stream()` and yields + one decoded event-log line at a time. The full log is not loaded into memory. +3. `scanner.py` parses events until a decision is available or the + `max_events_scanned` budget is reached. +4. `classifier.py` classifies the accumulated Spark properties as `SPARK` or + `SPARK_RAPIDS`. +5. `detector.py` maps the scan result to `ToolExecution`. + +## RAPIDS Detection + +RAPIDS logs are detected from either of these signals: + +- `SparkRapidsBuildInfoEvent`, emitted by RAPIDS plugin event logs. +- Spark properties showing `spark.plugins` contains + `com.nvidia.spark.SQLPlugin` and `spark.rapids.sql.enabled` is not `false`. + +The `spark.rapids.sql.enabled` parse matches the Scala tools behavior: +missing or unparseable values default to `true`. + +## CPU Fast Path + +When `SparkListenerEnvironmentUpdate` is reached and startup Spark properties +contain no RAPIDS-related configuration, the detector can return +`QUALIFICATION` immediately. This applies to both single-file and OSS rolling +event logs. + +If RAPIDS-related configuration is present but not decisive, the scanner keeps +reading within the configured event budget. This avoids treating a log as CPU +when later `modifiedConfigs` may make the RAPIDS configuration active. + +## Streaming And Memory + +The detector streams one line at a time. Memory is bounded to: + +- a small set of runtime metadata fields, +- the accumulated Spark properties map, +- the current decoded event record. + +It does not retain raw events or read entire event-log files into memory. 
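+
+## Usage
+
+A minimal sketch of the intended call pattern (the event-log path below is a
+hypothetical example; `max_events_scanned` defaults to 500):
+
+```python
+from spark_rapids_tools.tools.eventlog_detector import (
+    ToolExecution,
+    detect_spark_runtime,
+)
+
+# Scan at most 500 events before giving up with UNKNOWN.
+result = detect_spark_runtime("/tmp/spark-events/app-20240101-0001",
+                              max_events_scanned=500)
+if result.tool_execution is ToolExecution.PROFILING:
+    print(f"RAPIDS log detected: {result.event_log_path}")
+elif result.tool_execution is ToolExecution.QUALIFICATION:
+    print(f"CPU log detected: {result.reason}")
+else:
+    print(f"No decision: {result.reason}")
+```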
diff --git a/user_tools/src/spark_rapids_tools/tools/eventlog_detector/__init__.py b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/__init__.py new file mode 100644 index 000000000..014a8ed64 --- /dev/null +++ b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lightweight event log runtime detector. + +Scans a bounded prefix of a Spark event log and returns a tool execution +decision (``QUALIFICATION`` / ``PROFILING`` / ``UNKNOWN``) plus best-effort +runtime metadata, without invoking the full tools pipeline. + +Public entry point: :func:`detect_spark_runtime`. +""" + +from .detector import detect_spark_runtime +from .types import ( + DetectionResult, + EventLogDetectionError, + EventLogReadError, + SparkRuntime, + ToolExecution, + UnsupportedCompressionError, + UnsupportedInputError, +) + +__all__ = [ + "DetectionResult", + "EventLogDetectionError", + "EventLogReadError", + "SparkRuntime", + "ToolExecution", + "UnsupportedCompressionError", + "UnsupportedInputError", + "detect_spark_runtime", +] diff --git a/user_tools/src/spark_rapids_tools/tools/eventlog_detector/classifier.py b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/classifier.py new file mode 100644 index 000000000..d27a831ab --- /dev/null +++ b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/classifier.py @@ -0,0 +1,72 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classify Spark runtime from accumulated Spark properties. + +The scanner extracts Spark properties from event-log records and passes the +merged map to this module. This module only answers whether those properties +represent standard Spark or a RAPIDS-enabled application. +""" + +from typing import Mapping + +from spark_rapids_tools.tools.eventlog_detector import markers as m +from spark_rapids_tools.tools.eventlog_detector.types import SparkRuntime + + +def _parse_bool(raw: str, default: bool) -> bool: + """Parse Spark boolean strings with Scala-compatible fallback behavior. + + Scala's ``String.toBoolean`` accepts only ``true`` and ``false``. The + Scala tools wrap that parse in ``Try(...).getOrElse(default)``, so values + such as ``yes``, ``1``, or an empty string must return ``default`` rather + than using Python truthiness. 
+ """ + stripped = raw.strip().lower() + if stripped == "true": + return True + if stripped == "false": + return False + return default + + +def _is_spark_rapids(props: Mapping[str, str]) -> bool: + """Return true when Spark properties show the RAPIDS SQL plugin is active.""" + plugins = props.get(m.GPU_PLUGIN_KEY, "") + if m.GPU_PLUGIN_CLASS_SUBSTRING not in plugins: + return False + raw = props.get(m.GPU_ENABLED_KEY) + if raw is None: + return m.GPU_ENABLED_DEFAULT + return _parse_bool(raw, default=m.GPU_ENABLED_DEFAULT) + + +def has_rapids_conf_markers(props: Mapping[str, str]) -> bool: + """Return true when properties contain any RAPIDS-related configuration. + + This is intentionally broader than ``_is_spark_rapids``. A disabled or + incomplete RAPIDS configuration is not classified as RAPIDS, but its + presence should prevent early CPU routing because later events may update + the effective configuration. + """ + if m.GPU_PLUGIN_CLASS_SUBSTRING in props.get(m.GPU_PLUGIN_KEY, ""): + return True + return m.GPU_ENABLED_KEY in props + + +def classify_runtime(props: Mapping[str, str]) -> SparkRuntime: + """Classify accumulated Spark properties into the supported runtime enum.""" + if _is_spark_rapids(props): + return SparkRuntime.SPARK_RAPIDS + return SparkRuntime.SPARK diff --git a/user_tools/src/spark_rapids_tools/tools/eventlog_detector/detector.py b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/detector.py new file mode 100644 index 000000000..bcf4de1ae --- /dev/null +++ b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/detector.py @@ -0,0 +1,97 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Top-level event log runtime detector.""" + +from typing import Optional, Union + +from spark_rapids_tools.storagelib import CspPath +from spark_rapids_tools.tools.eventlog_detector.classifier import classify_runtime +from spark_rapids_tools.tools.eventlog_detector.resolver import resolve_event_log_files +from spark_rapids_tools.tools.eventlog_detector.scanner import scan_events_across +from spark_rapids_tools.tools.eventlog_detector.types import ( + DetectionResult, + SparkRuntime, + Termination, + ToolExecution, +) + + +def detect_spark_runtime( + event_log: Union[str, CspPath], + *, + max_events_scanned: int = 500, + allow_cpu_fast_path: bool = True, +) -> DetectionResult: + """Classify a single-app event log into a tool execution decision. + + Returns ``PROFILING`` when a RAPIDS marker is found, ``QUALIFICATION`` when + the log appears to be OSS Spark/CPU, and ``UNKNOWN`` when the bounded scan + cannot make a decision. + + ``max_events_scanned`` caps CPU/IO cost. Logs that do not expose a RAPIDS + marker or ``SparkListenerEnvironmentUpdate`` within the cap remain + ``UNKNOWN``. + + ``allow_cpu_fast_path`` enables early CPU routing when startup properties + contain no RAPIDS markers. Disable it to require EOF before returning + ``QUALIFICATION``. 
+ """ + # Keep the caller's input verbatim in source_path (cloud URI schemes + # would otherwise be stripped by CspPath normalisation). + source_path = event_log if isinstance(event_log, str) else str(event_log) + path = event_log if isinstance(event_log, CspPath) else CspPath(str(event_log)) + _, files = resolve_event_log_files(path) + + scan = scan_events_across( + files, + budget=max_events_scanned, + allow_cpu_fast_path=allow_cpu_fast_path, + ) + + runtime: Optional[SparkRuntime] + if scan.rapids_build_info_seen: + runtime = SparkRuntime.SPARK_RAPIDS + elif scan.env_update_seen: + runtime = classify_runtime(scan.spark_properties) + else: + runtime = None + + if runtime is SparkRuntime.SPARK_RAPIDS: + tool_execution = ToolExecution.PROFILING + reason = f"decisive: classified as {runtime.value}" + elif scan.termination is Termination.CPU_FAST_PATH and runtime is SparkRuntime.SPARK: + tool_execution = ToolExecution.QUALIFICATION + reason = "startup properties classify as SPARK with no RAPIDS markers" + elif scan.termination is Termination.EXHAUSTED and scan.env_update_seen: + tool_execution = ToolExecution.QUALIFICATION + reason = "walked full log, no RAPIDS signal" + else: + tool_execution = ToolExecution.UNKNOWN + reason = ( + "no decisive signal within bounded scan" + if scan.env_update_seen + else "no SparkListenerEnvironmentUpdate reached" + ) + + resolved_path = scan.last_scanned_path or (str(files[0]) if files else source_path) + return DetectionResult( + tool_execution=tool_execution, + spark_runtime=runtime, + app_id=scan.app_id, + spark_version=scan.spark_version, + event_log_path=resolved_path, + source_path=source_path, + reason=reason, + ) diff --git a/user_tools/src/spark_rapids_tools/tools/eventlog_detector/markers.py b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/markers.py new file mode 100644 index 000000000..176e73a68 --- /dev/null +++ b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/markers.py @@ -0,0 +1,45 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Property keys and constants used by the runtime detector. + +Each block carries a Scala source reference so the two implementations +can be kept in sync when the Scala detection rules change. +""" + +# GPU (SPARK_RAPIDS) markers. +# Scala: org/apache/spark/sql/rapids/tool/ToolUtils.scala :: isPluginEnabled +GPU_PLUGIN_KEY: str = "spark.plugins" +GPU_PLUGIN_CLASS_SUBSTRING: str = "com.nvidia.spark.SQLPlugin" +GPU_ENABLED_KEY: str = "spark.rapids.sql.enabled" +# Defaults to true when missing or unparseable. +GPU_ENABLED_DEFAULT: bool = True + +# RAPIDS 24.06+ plugin marker. +# Scala: com/nvidia/spark/rapids/SparkRapidsBuildInfoEvent.scala +EVENT_SPARK_RAPIDS_BUILD_INFO: str = "com.nvidia.spark.rapids.SparkRapidsBuildInfoEvent" +EVENT_SPARK_RAPIDS_BUILD_INFO_SHORTNAME: str = "SparkRapidsBuildInfoEvent" + +# Apache Spark rolling event-log directory layout. 
+# Scala: com/nvidia/spark/rapids/tool/EventLogPathProcessor.scala :: isEventLogDir +OSS_EVENT_LOG_DIR_PREFIX: str = "eventlog_v2_" +OSS_EVENT_LOG_FILE_PREFIX: str = "events_" + +# Spark listener event names consumed by the scanner. +EVENT_LOG_START: str = "SparkListenerLogStart" +EVENT_APPLICATION_START: str = "SparkListenerApplicationStart" +EVENT_ENVIRONMENT_UPDATE: str = "SparkListenerEnvironmentUpdate" +EVENT_SQL_EXECUTION_START: str = "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart" +# Unqualified event name accepted by the scanner for compatibility. +EVENT_SQL_EXECUTION_START_SHORTNAME: str = "SparkListenerSQLExecutionStart" diff --git a/user_tools/src/spark_rapids_tools/tools/eventlog_detector/resolver.py b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/resolver.py new file mode 100644 index 000000000..6b45de710 --- /dev/null +++ b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/resolver.py @@ -0,0 +1,80 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Input-path resolver for single files and Apache Spark rolling event logs.""" + +import re +from typing import List, Optional, Tuple + +from spark_rapids_tools.storagelib import CspFs, CspPath +from spark_rapids_tools.tools.eventlog_detector import markers as m +from spark_rapids_tools.tools.eventlog_detector.types import UnsupportedInputError + + +_OSS_EVENT_FILE_PATTERN = re.compile(r"^events_(\d+)_.*") + + +def parse_oss_event_file_index(name: str) -> Optional[int]: + """Return the numeric chunk index from ``events_<n>_*`` files.""" + match = _OSS_EVENT_FILE_PATTERN.match(name) + if match is None: + return None + return int(match.group(1)) + + +def _is_oss_event_log_file(path: CspPath) -> bool: + return parse_oss_event_file_index(path.base_name()) is not None + + +def _base_name_from_source(source: str) -> str: + """Return the final path component, ignoring trailing separators.""" + return source.rstrip("/").rsplit("/", 1)[-1] + + +def resolve_event_log_files(path: CspPath) -> Tuple[str, List[CspPath]]: + """Resolve ``path`` to an ordered list of files to scan. + + Supported inputs are a single concrete file or an Apache Spark rolling + event-log directory named ``eventlog_v2_*``. Other directory layouts are + rejected so callers can use the full tools pipeline. + """ + source = path.no_scheme + + if path.is_file(): + return source, [path] + + if not path.is_dir(): + raise UnsupportedInputError( + f"Path is neither a file nor a directory: {source}" + ) + + if not _base_name_from_source(source).startswith(m.OSS_EVENT_LOG_DIR_PREFIX): + raise UnsupportedInputError( + f"Directory {source} is not a supported input shape. Only single " + f"files and Apache Spark rolling event-log directories named " + f"{m.OSS_EVENT_LOG_DIR_PREFIX}* are handled " + "here; use the full pipeline for other shapes."
+ ) + + event_files = [c for c in CspFs.list_all_files(path) if _is_oss_event_log_file(c)] + if not event_files: + raise UnsupportedInputError(f"Directory {source} does not contain Spark event chunks") + + event_files.sort( + key=lambda f: ( + parse_oss_event_file_index(f.base_name()) or 0, + f.base_name(), + ) + ) + return source, event_files diff --git a/user_tools/src/spark_rapids_tools/tools/eventlog_detector/scanner.py b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/scanner.py new file mode 100644 index 000000000..113e681a4 --- /dev/null +++ b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/scanner.py @@ -0,0 +1,154 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bounded streaming event scanner. + +Walks JSON-per-line event logs under a shared event budget and accumulates the +startup and per-SQL properties required for runtime classification. +""" + +import json +from dataclasses import dataclass, field +from typing import Dict, Iterable, List, Optional + +from spark_rapids_tools.storagelib import CspPath +from spark_rapids_tools.tools.eventlog_detector import markers as m +from spark_rapids_tools.tools.eventlog_detector.classifier import ( + classify_runtime, + has_rapids_conf_markers, +) +from spark_rapids_tools.tools.eventlog_detector.stream import open_event_log_stream +from spark_rapids_tools.tools.eventlog_detector.types import SparkRuntime, Termination + + +@dataclass +class _ScanResult: + spark_properties: Dict[str, str] = field(default_factory=dict) + app_id: Optional[str] = None + spark_version: Optional[str] = None + env_update_seen: bool = False + rapids_build_info_seen: bool = False + events_scanned: int = 0 + termination: Termination = Termination.EXHAUSTED + last_scanned_path: Optional[str] = None + + +def scan_events( + lines: Iterable[str], + *, + budget: int, + allow_cpu_fast_path: bool = True, + state: Optional[_ScanResult] = None, +) -> _ScanResult: + """Scan one stream of lines, optionally continuing from a prior state. + + Terminates as ``DECISIVE`` on the first non-SPARK classification, + ``CPU_FAST_PATH`` after plain Spark startup properties, ``CAP_HIT`` when + ``budget`` is exhausted, or ``EXHAUSTED`` when the iterator runs out. + """ + result = state if state is not None else _ScanResult() + + for raw in lines: + if result.events_scanned >= budget: + result.termination = Termination.CAP_HIT + return result + + if not raw: + continue + + try: + event = json.loads(raw) + except (json.JSONDecodeError, ValueError): + # Tolerate trailing partial lines in live logs; count them so + # a pathological log can't keep us scanning forever. 
+ result.events_scanned += 1 + continue + + result.events_scanned += 1 + name = event.get("Event") + if name in ( + m.EVENT_SPARK_RAPIDS_BUILD_INFO, + m.EVENT_SPARK_RAPIDS_BUILD_INFO_SHORTNAME, + ): + result.rapids_build_info_seen = True + result.termination = Termination.DECISIVE + return result + if name == m.EVENT_LOG_START: + version = event.get("Spark Version") + if isinstance(version, str): + result.spark_version = version + elif name == m.EVENT_APPLICATION_START: + app_id = event.get("App ID") + if isinstance(app_id, str): + result.app_id = app_id + elif name == m.EVENT_ENVIRONMENT_UPDATE: + props = event.get("Spark Properties") or {} + if isinstance(props, dict): + for k, v in props.items(): + if isinstance(k, str) and isinstance(v, str): + result.spark_properties[k] = v + result.env_update_seen = True + runtime = classify_runtime(result.spark_properties) + if runtime is not SparkRuntime.SPARK: + result.termination = Termination.DECISIVE + return result + if allow_cpu_fast_path and not has_rapids_conf_markers(result.spark_properties): + result.termination = Termination.CPU_FAST_PATH + return result + elif name in (m.EVENT_SQL_EXECUTION_START, m.EVENT_SQL_EXECUTION_START_SHORTNAME): + modified = event.get("modifiedConfigs") or {} + if isinstance(modified, dict) and modified: + for k, v in modified.items(): + if isinstance(k, str) and isinstance(v, str): + result.spark_properties[k] = v + # Per-query configs refine startup properties; without env-update + # context they are not enough to classify the whole event log. + if result.env_update_seen and ( + classify_runtime(result.spark_properties) is not SparkRuntime.SPARK + ): + result.termination = Termination.DECISIVE + return result + + result.termination = Termination.EXHAUSTED + return result + + +def scan_events_across( + files: List[CspPath], + *, + budget: int, + allow_cpu_fast_path: bool = True, +) -> _ScanResult: + """Walk ``files`` in order under a single shared ``budget``.""" + state = _ScanResult() + for path in files: + if state.events_scanned >= budget: + state.termination = Termination.CAP_HIT + return state + state.last_scanned_path = str(path) + with open_event_log_stream(path) as lines: + state = scan_events( + lines, + budget=budget, + allow_cpu_fast_path=allow_cpu_fast_path, + state=state, + ) + if state.termination in ( + Termination.DECISIVE, + Termination.CPU_FAST_PATH, + Termination.CAP_HIT, + ): + return state + state.termination = Termination.EXHAUSTED + return state diff --git a/user_tools/src/spark_rapids_tools/tools/eventlog_detector/stream.py b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/stream.py new file mode 100644 index 000000000..719e994dc --- /dev/null +++ b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/stream.py @@ -0,0 +1,102 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Codec-aware context-managed line streamer for Spark event logs. 
+ +Opens the file through ``CspPath.open_input_stream()``, wraps it with +the right decompression and text layers, and yields an +``Iterator[str]``. Streaming only — the full file is never buffered. + +``CspPath.open_input_stream()`` delegates to PyArrow, which auto-detects +and decompresses ``.gz`` and ``.zst`` transparently. ``.zstd`` is not +recognised by PyArrow, so this module decompresses it via ``zstandard``. +""" + +import contextlib +import io +from typing import Iterator + +import zstandard as zstd + +from spark_rapids_tools.storagelib import CspPath +from spark_rapids_tools.tools.eventlog_detector.types import ( + EventLogReadError, + UnsupportedCompressionError, +) + + +# Suffixes PyArrow already decompresses for us. +_PYARROW_AUTO_DECOMP_SUFFIXES = {".gz", ".zst"} +# Suffixes we must decompress manually via zstandard. +_ZSTD_MANUAL_SUFFIXES = {".zstd"} +# Suffixes treated as plain text. +_PLAIN_SUFFIXES = {"", ".inprogress"} +# Whitelist of accepted suffixes; anything else raises +# ``UnsupportedCompressionError``. +_SUPPORTED_SUFFIXES = ( + _PYARROW_AUTO_DECOMP_SUFFIXES | _ZSTD_MANUAL_SUFFIXES | _PLAIN_SUFFIXES +) + + +def _classify_suffix(path: CspPath) -> str: + name = path.base_name().lower() + dot = name.rfind(".") + if dot < 0: + return "" + return name[dot:] + + +@contextlib.contextmanager +def open_event_log_stream(path: CspPath) -> Iterator[Iterator[str]]: + suffix = _classify_suffix(path) + if suffix not in _SUPPORTED_SUFFIXES: + raise UnsupportedCompressionError( + f"File suffix '{suffix}' is not supported. " + "Supported: plain, .inprogress, .gz, .zstd, .zst." + ) + + try: + byte_stream = path.open_input_stream() + except Exception as exc: + raise EventLogReadError(f"Failed to open event log {path}: {exc}") from exc + + close_stack = contextlib.ExitStack() + close_stack.callback(byte_stream.close) + try: + if suffix in _ZSTD_MANUAL_SUFFIXES: + # Decompress ``.zstd`` ourselves; PyArrow does not handle it. + dctx = zstd.ZstdDecompressor() + decompressed: io.RawIOBase = dctx.stream_reader(byte_stream) + close_stack.callback(decompressed.close) + else: + # Plain text, or already decompressed by PyArrow. + decompressed = byte_stream + + text = io.TextIOWrapper(decompressed, encoding="utf-8", errors="replace", newline="") + close_stack.callback(text.close) + + def line_iter() -> Iterator[str]: + # One event per line; strip the trailing newline and leave + # empty lines for the caller to skip. + for raw in text: + yield raw.rstrip("\r\n") + + try: + yield line_iter() + except OSError as exc: + # Only reclassify real I/O failures; let caller-side logic + # errors bubble up untouched. + raise EventLogReadError(f"Error reading event log {path}: {exc}") from exc + finally: + close_stack.close() diff --git a/user_tools/src/spark_rapids_tools/tools/eventlog_detector/types.py b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/types.py new file mode 100644 index 000000000..01d72eaf5 --- /dev/null +++ b/user_tools/src/spark_rapids_tools/tools/eventlog_detector/types.py @@ -0,0 +1,81 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Types, enums, and exceptions for the event log runtime detector.""" + +from dataclasses import dataclass +from enum import Enum +from typing import Optional + + +class ToolExecution(str, Enum): + """Tool execution decision returned to the caller.""" + + QUALIFICATION = "QUALIFICATION" + PROFILING = "PROFILING" + UNKNOWN = "UNKNOWN" + + +class SparkRuntime(str, Enum): + """Runtime taxonomy. + + Values mirror ``org.apache.spark.sql.rapids.tool.util.SparkRuntime`` + in the Scala core so string comparisons against existing pipelines + keep working. + """ + + SPARK = "SPARK" + SPARK_RAPIDS = "SPARK_RAPIDS" + + +class Termination(Enum): + """How the scanner stopped.""" + + DECISIVE = "DECISIVE" # classification returned non-SPARK + CPU_FAST_PATH = "CPU_FAST_PATH" # stopped after plain-SPARK startup props + EXHAUSTED = "EXHAUSTED" # walked every file to EOF under the budget + CAP_HIT = "CAP_HIT" # hit max_events_scanned before exhausting files + + +@dataclass(frozen=True) +class DetectionResult: + """Result returned by :func:`detect_spark_runtime`. + + ``spark_runtime`` is best-effort metadata and may be ``None`` when + ``tool_execution`` is ``UNKNOWN``. + """ + + tool_execution: ToolExecution + spark_runtime: Optional[SparkRuntime] + app_id: Optional[str] + spark_version: Optional[str] + event_log_path: str + source_path: str + reason: str + + +class EventLogDetectionError(Exception): + """Base class for detector errors.""" + + +class UnsupportedInputError(EventLogDetectionError): + """Input shape is not handled (multi-app dir, wildcard, comma list, ...).""" + + +class UnsupportedCompressionError(EventLogDetectionError): + """File uses a compression codec the detector does not handle.""" + + +class EventLogReadError(EventLogDetectionError): + """Wraps an I/O failure while reading the event log.""" diff --git a/user_tools/tests/spark_rapids_tools_ut/tools/__init__.py b/user_tools/tests/spark_rapids_tools_ut/tools/__init__.py new file mode 100644 index 000000000..51b351e5a --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/tools/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""init file of the tools unit-tests package""" diff --git a/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/__init__.py b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/__init__.py new file mode 100644 index 000000000..7088b6bc7 --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""init file of the eventlog_detector unit-tests package""" diff --git a/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_classifier.py b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_classifier.py new file mode 100644 index 000000000..61ca65891 --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_classifier.py @@ -0,0 +1,88 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``eventlog_detector.classifier``.""" +# pylint: disable=too-few-public-methods # test classes naturally have few methods + +import pytest + +from spark_rapids_tools.tools.eventlog_detector.classifier import ( + classify_runtime, + has_rapids_conf_markers, +) +from spark_rapids_tools.tools.eventlog_detector.types import SparkRuntime + + +class TestEmptyProperties: + """Test classification with an empty properties dict.""" + + def test_empty_props_is_spark(self): + assert classify_runtime({}) is SparkRuntime.SPARK + + +class TestRapidsConfigMarkers: + """Test marker presence checks used by the CPU fast path.""" + + @pytest.mark.parametrize( + "props,expected", + [ + ({"spark.master": "local"}, False), + ({"spark.rapids.sql.enabled": "true"}, True), + ({"spark.plugins": "com.nvidia.spark.SQLPlugin"}, True), + ({"spark.plugins": "foo,com.nvidia.spark.SQLPlugin"}, True), + ({ + "spark.plugins": "com.nvidia.spark.SQLPlugin", + "spark.rapids.sql.enabled": "true", + }, True), + ({"spark.plugins": "foo.nvidia.spark.SQLPlugin"}, False), + ], + ) + def test_detects_rapids_marker_configs(self, props, expected): + assert has_rapids_conf_markers(props) is expected + + +class TestSparkRapids: + """Test SPARK_RAPIDS classification logic.""" + + def test_plugin_and_default_enabled(self): + props = {"spark.plugins": "foo,com.nvidia.spark.SQLPlugin,bar"} + assert classify_runtime(props) is SparkRuntime.SPARK_RAPIDS + + def test_plugin_with_enabled_true(self): + props = { + "spark.plugins": "com.nvidia.spark.SQLPlugin", + "spark.rapids.sql.enabled": "true", + } + assert classify_runtime(props) is SparkRuntime.SPARK_RAPIDS + + def test_plugin_with_enabled_false_demotes_to_spark(self): + props = { + "spark.plugins": "com.nvidia.spark.SQLPlugin", + "spark.rapids.sql.enabled": "false", + } + assert classify_runtime(props) is SparkRuntime.SPARK + + def test_enabled_true_without_plugin_is_still_spark(self): + props = {"spark.rapids.sql.enabled": "true"} + assert classify_runtime(props) is SparkRuntime.SPARK + + @pytest.mark.parametrize("bogus_value", ["no", "0", "yes", "1", "", "maybe", 
"not-a-bool"]) + def test_non_toboolean_values_default_to_true_matching_scala(self, bogus_value): + # Scala: Try { "no".toBoolean }.getOrElse(true) == true because + # "no" is not parseable. The Python classifier must do the same. + props = { + "spark.plugins": "com.nvidia.spark.SQLPlugin", + "spark.rapids.sql.enabled": bogus_value, + } + assert classify_runtime(props) is SparkRuntime.SPARK_RAPIDS diff --git a/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_detector.py b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_detector.py new file mode 100644 index 000000000..54400ee8a --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_detector.py @@ -0,0 +1,221 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for ``eventlog_detector.detect_spark_runtime``.""" +# pylint: disable=too-few-public-methods # test classes naturally have few methods + +import json +from pathlib import Path + +import pytest + +from spark_rapids_tools.storagelib import CspPath +from spark_rapids_tools.tools.eventlog_detector import detect_spark_runtime +from spark_rapids_tools.tools.eventlog_detector.types import ( + SparkRuntime, + ToolExecution, + UnsupportedInputError, +) + + +def env_update(props: dict) -> dict: + return { + "Event": "SparkListenerEnvironmentUpdate", + "Spark Properties": props, + "System Properties": {}, + "Classpath Entries": {}, + "JVM Information": {}, + } + + +def build_info() -> dict: + return { + "Event": "com.nvidia.spark.rapids.SparkRapidsBuildInfoEvent", + "sparkRapidsBuildInfo": {"version": "24.06.0"}, + "sparkRapidsJniBuildInfo": {}, + "cudfBuildInfo": {}, + "sparkRapidsPrivateBuildInfo": {}, + } + + +def sql_exec_start(modified_configs: dict) -> dict: + return { + "Event": "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart", + "executionId": 0, + "description": "", + "details": "", + "physicalPlanDescription": "", + "sparkPlanInfo": {}, + "time": 0, + "modifiedConfigs": modified_configs, + } + + +def _write_plain_log(path: Path, events: list) -> None: + path.write_text( + "\n".join(json.dumps(e) for e in events) + "\n", encoding="utf-8" + ) + + +class TestAcceptsStringPath: + """Test that detect_spark_runtime accepts plain string paths.""" + + def test_str_input_resolves(self, tmp_path): + log = tmp_path / "eventlog" + _write_plain_log( + log, + [ + {"Event": "SparkListenerLogStart", "Spark Version": "3.5.1"}, + {"Event": "SparkListenerApplicationStart", "App ID": "a", "App Name": "A"}, + env_update({"spark.master": "local"}), + ], + ) + result = detect_spark_runtime(str(log)) + assert result.tool_execution is ToolExecution.QUALIFICATION + assert result.spark_runtime is SparkRuntime.SPARK + + +class TestRapidsLog: + """Test detection on RAPIDS event logs.""" + + def test_build_info_event_classifies_as_profiling(self, tmp_path): + log = tmp_path / "eventlog" + _write_plain_log( + log, + [ + {"Event": "SparkListenerLogStart", "Spark 
Version": "3.5.1"}, + build_info(), + ], + ) + result = detect_spark_runtime(CspPath(str(log))) + assert result.tool_execution is ToolExecution.PROFILING + assert result.spark_runtime is SparkRuntime.SPARK_RAPIDS + assert result.spark_version == "3.5.1" + + def test_env_update_with_plugin_classifies_as_profiling(self, tmp_path): + log = tmp_path / "eventlog" + _write_plain_log( + log, + [ + {"Event": "SparkListenerLogStart", "Spark Version": "3.5.1"}, + {"Event": "SparkListenerApplicationStart", "App ID": "g", "App Name": "G"}, + env_update({"spark.plugins": "com.nvidia.spark.SQLPlugin"}), + ], + ) + result = detect_spark_runtime(CspPath(str(log))) + assert result.tool_execution is ToolExecution.PROFILING + assert result.spark_runtime is SparkRuntime.SPARK_RAPIDS + assert result.app_id == "g" + + +class TestCpuFastPath: + """Test the default fast path for startup properties that look like plain Spark.""" + + def test_env_update_without_rapids_markers_returns_qualification(self, tmp_path): + log = tmp_path / "eventlog" + _write_plain_log( + log, + [ + {"Event": "SparkListenerLogStart", "Spark Version": "3.5.1"}, + {"Event": "SparkListenerApplicationStart", "App ID": "c", "App Name": "C"}, + env_update({"spark.master": "local"}), + sql_exec_start({"spark.plugins": "com.nvidia.spark.SQLPlugin"}), + ], + ) + result = detect_spark_runtime(str(log)) + assert result.tool_execution is ToolExecution.QUALIFICATION + assert result.spark_runtime is SparkRuntime.SPARK + assert "startup properties" in result.reason.lower() + + def test_fast_path_can_be_disabled(self, tmp_path): + log = tmp_path / "eventlog" + _write_plain_log( + log, + [ + env_update({"spark.master": "local"}), + sql_exec_start({"spark.plugins": "com.nvidia.spark.SQLPlugin"}), + ], + ) + result = detect_spark_runtime(str(log), allow_cpu_fast_path=False) + assert result.tool_execution is ToolExecution.PROFILING + assert result.spark_runtime is SparkRuntime.SPARK_RAPIDS + + +class TestCapHit: + """Test detection when the event budget is exhausted before env-update.""" + + def test_no_env_update_before_cap_is_unknown(self, tmp_path): + log = tmp_path / "eventlog" + _write_plain_log( + log, + [{"Event": "SparkListenerLogStart", "Spark Version": "3.5.1"}] * 10, + ) + result = detect_spark_runtime(str(log), max_events_scanned=5) + assert result.tool_execution is ToolExecution.UNKNOWN + assert result.spark_runtime is None + + +class TestOssRolling: + """Test detection on Apache Spark rolling event-log directories.""" + + def test_rapids_signal_in_later_rolled_file(self, tmp_path): + d = tmp_path / "eventlog_v2_app-1" + d.mkdir() + _write_plain_log( + d / "events_1_app-1", + [ + {"Event": "SparkListenerLogStart", "Spark Version": "3.5.1"}, + {"Event": "SparkListenerApplicationStart", "App ID": "d", "App Name": "D"}, + env_update({"spark.rapids.sql.enabled": "false"}), + ], + ) + _write_plain_log(d / "events_2_app-1", [build_info()]) + result = detect_spark_runtime(CspPath(str(d))) + assert result.tool_execution is ToolExecution.PROFILING + assert result.spark_runtime is SparkRuntime.SPARK_RAPIDS + assert result.event_log_path.endswith("/events_2_app-1") + + def test_cpu_fast_path_applies_to_rolling_dir(self, tmp_path): + d = tmp_path / "eventlog_v2_app-1" + d.mkdir() + _write_plain_log(d / "events_1_app-1", [env_update({"spark.master": "local"})]) + _write_plain_log(d / "events_2_app-1", [sql_exec_start({"spark.plugins": "com.nvidia.spark.SQLPlugin"})]) + result = detect_spark_runtime(CspPath(str(d))) + assert result.tool_execution is 
ToolExecution.QUALIFICATION + assert result.spark_runtime is SparkRuntime.SPARK + assert result.event_log_path.endswith("/events_1_app-1") + assert "startup properties" in result.reason.lower() + + +class TestUnsupportedInput: + """Test that unsupported input shapes raise the expected error.""" + + def test_non_oss_rolling_dir_raises(self, tmp_path): + d = tmp_path / "non_oss_rolling" + d.mkdir() + (d / "eventlog").write_bytes(b"") + (d / "eventlog-2021-06-14--18-00.gz").write_bytes(b"") + with pytest.raises(UnsupportedInputError): + detect_spark_runtime(CspPath(str(d))) + + +class TestSourcePathPreserved: + """Test that source_path echoes the original input string.""" + + def test_source_path_equals_input_string(self, tmp_path): + log = tmp_path / "eventlog" + _write_plain_log(log, [env_update({"spark.master": "local"})]) + input_str = str(log) + result = detect_spark_runtime(input_str) + assert result.source_path == input_str diff --git a/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_detector_fixtures.py b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_detector_fixtures.py new file mode 100644 index 000000000..97e90afd6 --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_detector_fixtures.py @@ -0,0 +1,64 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Anchor tests against fixtures already shipped in the Scala core. + +These are not a full parity sweep. They catch regressions on a small +curated set covering each decisive execution decision. +""" + +from pathlib import Path + +import pytest + +from spark_rapids_tools.tools.eventlog_detector import detect_spark_runtime +from spark_rapids_tools.tools.eventlog_detector.types import SparkRuntime, ToolExecution + + +REPO_ROOT = Path(__file__).resolve().parents[5] +CORE_FIXTURES = REPO_ROOT / "core" / "src" / "test" / "resources" + + +@pytest.mark.parametrize( + "relative_path,expected_execution,expected_runtime", + [ + ( + "spark-events-profiling/eventlog-gpu-dsv2.zstd", + ToolExecution.PROFILING, + SparkRuntime.SPARK_RAPIDS, + ), + ( + # The Profiling tool can process CPU logs; this fixture lives under + # profiling resources but has no RAPIDS runtime markers. + "spark-events-profiling/eventlog_dsv2.zstd", + ToolExecution.QUALIFICATION, + SparkRuntime.SPARK, + ), + ( + "spark-events-qualification/eventlog_same_app_id_1.zstd", + ToolExecution.QUALIFICATION, + SparkRuntime.SPARK, + ), + ], +) +def test_detector_matches_expected_execution_on_scala_fixture( + relative_path: str, expected_execution: ToolExecution, expected_runtime: SparkRuntime +) -> None: + fixture = CORE_FIXTURES / relative_path + if not fixture.exists(): + pytest.skip(f"fixture not available: {fixture}") + # Fixtures are small; a generous budget keeps this test decisive.
+ result = detect_spark_runtime(str(fixture), max_events_scanned=5000) + assert result.tool_execution is expected_execution, result.reason + assert result.spark_runtime is expected_runtime, result.reason diff --git a/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_resolver.py b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_resolver.py new file mode 100644 index 000000000..b09c70378 --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_resolver.py @@ -0,0 +1,96 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``eventlog_detector.resolver``.""" +# pylint: disable=too-few-public-methods # test classes naturally have few methods + +from pathlib import Path + +import pytest + +from spark_rapids_tools.storagelib import CspPath +from spark_rapids_tools.tools.eventlog_detector.resolver import ( + parse_oss_event_file_index, + resolve_event_log_files, +) +from spark_rapids_tools.tools.eventlog_detector.types import UnsupportedInputError + + +class TestOssEventFileIndex: + """Test Apache Spark rolling event file index parsing.""" + + def test_events_file_index_parses(self): + assert parse_oss_event_file_index("events_1_app-1.zstd") == 1 + assert parse_oss_event_file_index("events_10_app-1") == 10 + + def test_non_events_file_returns_none(self): + assert parse_oss_event_file_index("appstatus_app-1.inprogress") is None + assert parse_oss_event_file_index("eventlog") is None + + +class TestResolveSingleFile: + """Test resolving a single event log file.""" + + def test_single_file_returns_single_element_list(self, tmp_path: Path): + f = tmp_path / "eventlog.zstd" + f.write_bytes(b"x") + source, files = resolve_event_log_files(CspPath(str(f))) + assert source == str(f) + assert [p.base_name() for p in files] == ["eventlog.zstd"] + + +class TestResolveOssRollingDir: + """Test resolving an Apache Spark rolling event-log directory.""" + + def test_orders_event_chunks_by_numeric_index(self, tmp_path: Path): + d = tmp_path / "eventlog_v2_app-1" + d.mkdir() + (d / "events_10_app-1.zstd").write_bytes(b"") + (d / "events_2_app-1.zstd").write_bytes(b"") + (d / "events_1_app-1.zstd").write_bytes(b"") + (d / "appstatus_app-1.inprogress").write_bytes(b"") + source, files = resolve_event_log_files(CspPath(str(d))) + assert source == str(d) + assert [p.base_name() for p in files] == [ + "events_1_app-1.zstd", + "events_2_app-1.zstd", + "events_10_app-1.zstd", + ] + + def test_accepts_trailing_slash_on_rolling_dir(self, tmp_path: Path): + d = tmp_path / "eventlog_v2_app-1" + d.mkdir() + (d / "events_1_app-1.zstd").write_bytes(b"") + source, files = resolve_event_log_files(CspPath(f"{d}/")) + assert source.rstrip("/") == str(d) + assert [p.base_name() for p in files] == ["events_1_app-1.zstd"] + + def test_empty_oss_rolling_dir_raises(self, tmp_path: Path): + d = tmp_path / "eventlog_v2_app-1" + d.mkdir() + (d / "appstatus_app-1.inprogress").write_bytes(b"") + with 
pytest.raises(UnsupportedInputError): + resolve_event_log_files(CspPath(str(d))) + + +class TestResolveUnsupportedShapes: + """Test that unsupported directory shapes raise UnsupportedInputError.""" + + def test_non_oss_rolling_dir_raises(self, tmp_path: Path): + d = tmp_path / "non_oss_rolling" + d.mkdir() + (d / "eventlog-2021-06-14--18-00.gz").write_bytes(b"") + (d / "eventlog").write_bytes(b"") + with pytest.raises(UnsupportedInputError, match="eventlog_v2_\\*"): + resolve_event_log_files(CspPath(str(d))) diff --git a/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_scanner.py b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_scanner.py new file mode 100644 index 000000000..a30e0d8f7 --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_scanner.py @@ -0,0 +1,254 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``eventlog_detector.scanner``.""" + +import json +from pathlib import Path +from typing import List + +from spark_rapids_tools.storagelib import CspPath +from spark_rapids_tools.tools.eventlog_detector.scanner import ( + scan_events, + scan_events_across, +) +from spark_rapids_tools.tools.eventlog_detector.types import Termination + + +def env_update(props: dict) -> str: + return json.dumps( + { + "Event": "SparkListenerEnvironmentUpdate", + "Spark Properties": props, + "System Properties": {}, + "Classpath Entries": {}, + "JVM Information": {}, + } + ) + + +def log_start(version: str = "3.5.1") -> str: + return json.dumps({"Event": "SparkListenerLogStart", "Spark Version": version}) + + +def app_start(app_id: str = "app-1", app_name: str = "App") -> str: + return json.dumps( + { + "Event": "SparkListenerApplicationStart", + "App ID": app_id, + "App Name": app_name, + } + ) + + +def rapids_build_info() -> str: + return json.dumps( + { + "Event": "com.nvidia.spark.rapids.SparkRapidsBuildInfoEvent", + "sparkRapidsBuildInfo": {"version": "24.06.0"}, + "sparkRapidsJniBuildInfo": {}, + "cudfBuildInfo": {}, + "sparkRapidsPrivateBuildInfo": {}, + } + ) + + +def sql_exec_start(modified_configs: dict) -> str: + return json.dumps( + { + "Event": "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart", + "executionId": 0, + "description": "x", + "details": "", + "physicalPlanDescription": "", + "sparkPlanInfo": {}, + "time": 0, + "modifiedConfigs": modified_configs, + } + ) + + +class TestScanEvents: + """Tests for scan_events scanning a single event stream.""" + + def test_build_info_event_is_decisive_before_env_update(self): + lines = iter([log_start(), rapids_build_info(), app_start()]) + result = scan_events(lines, budget=100) + assert result.rapids_build_info_seen is True + assert result.termination is Termination.DECISIVE + assert result.events_scanned == 2 + + def test_env_update_with_gpu_is_decisive(self): + # Scala tools and the RAPIDS plugin both default spark.rapids.sql.enabled to true, + # so the plugin marker alone is enough to 
classify the runtime as RAPIDS. + lines = iter( + [ + log_start(), + app_start(), + env_update({"spark.plugins": "com.nvidia.spark.SQLPlugin"}), + ] + ) + result = scan_events(lines, budget=100) + assert result.env_update_seen is True + assert result.app_id == "app-1" + assert result.spark_version == "3.5.1" + assert result.termination is Termination.DECISIVE + + def test_cpu_fast_path_stops_at_env_update_by_default(self): + lines = iter( + [ + log_start(), + app_start(), + env_update({"spark.master": "local"}), + sql_exec_start({"spark.plugins": "com.nvidia.spark.SQLPlugin"}), + ] + ) + result = scan_events(lines, budget=100) + assert result.env_update_seen is True + assert result.termination is Termination.CPU_FAST_PATH + assert result.events_scanned == 3 + + def test_cpu_fast_path_can_be_disabled(self): + lines = iter( + [ + log_start(), + app_start(), + env_update({"spark.master": "local"}), + sql_exec_start({"spark.plugins": "com.nvidia.spark.SQLPlugin"}), + ] + ) + result = scan_events(lines, budget=100, allow_cpu_fast_path=False) + assert result.env_update_seen is True + assert result.termination is Termination.DECISIVE + assert result.events_scanned == 4 + + def test_fast_path_ignored_when_rapids_marker_present(self): + lines = iter( + [ + env_update({ + "spark.plugins": "com.nvidia.spark.SQLPlugin", + "spark.rapids.sql.enabled": "false", + }), + sql_exec_start({"spark.rapids.sql.enabled": "true"}), + ] + ) + result = scan_events(lines, budget=100) + assert result.termination is Termination.DECISIVE + assert result.spark_properties["spark.rapids.sql.enabled"] == "true" + + def test_no_env_update_within_budget_is_cap_hit(self): + # Budget less than the number of events, none of them env-update. + lines = iter([log_start()] * 5) + result = scan_events(lines, budget=2) + assert result.env_update_seen is False + assert result.termination is Termination.CAP_HIT + + def test_no_env_update_to_eof_is_exhausted_without_env(self): + lines = iter([log_start(), app_start()]) + result = scan_events(lines, budget=100) + assert result.env_update_seen is False + assert result.termination is Termination.EXHAUSTED + + def test_malformed_json_lines_are_skipped(self): + lines = iter( + [ + "not-json-at-all", + log_start(), + "", + app_start(), + env_update({"spark.master": "local"}), + ] + ) + result = scan_events(lines, budget=100) + assert result.env_update_seen is True + assert result.app_id == "app-1" + + def test_sql_start_classifies_after_full_modified_config_merge(self): + lines = iter( + [ + env_update({"spark.rapids.sql.enabled": "false"}), + sql_exec_start({ + "spark.plugins": "com.nvidia.spark.SQLPlugin", + "spark.rapids.sql.enabled": "false", + }), + ] + ) + result = scan_events(lines, budget=100) + assert result.termination is Termination.EXHAUSTED + assert result.spark_properties["spark.plugins"] == "com.nvidia.spark.SQLPlugin" + assert result.spark_properties["spark.rapids.sql.enabled"] == "false" + + +def _write(path: Path, lines: List[str]) -> CspPath: + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return CspPath(str(path)) + + +class TestScanEventsAcross: + """Tests for scan_events_across scanning across multiple files.""" + + def test_gpu_signal_in_second_file_is_decisive(self, tmp_path): + f1 = _write( + tmp_path / "events_1_app-1", + [log_start(), app_start(), env_update({"spark.rapids.sql.enabled": "false"})], + ) + f2 = _write(tmp_path / "events_2_app-1", [rapids_build_info()]) + result = scan_events_across([f1, f2], budget=100) + assert result.termination is 
Termination.DECISIVE + assert result.last_scanned_path == str(f2) + + def test_cpu_fast_path_applies_across_files_when_no_rapids_markers(self, tmp_path): + f1 = _write(tmp_path / "events_1_app-1", [env_update({"spark.master": "local"})]) + f2 = _write(tmp_path / "events_2_app-1", [sql_exec_start({"spark.plugins": "com.nvidia.spark.SQLPlugin"})]) + result = scan_events_across([f1, f2], budget=100) + assert result.termination is Termination.CPU_FAST_PATH + assert result.last_scanned_path == str(f1) + + def test_cpu_fast_path_skips_when_rapids_marker_present_across_files(self, tmp_path): + f1 = _write( + tmp_path / "events_1_app-1", + [env_update({"spark.rapids.sql.enabled": "false"})], + ) + f2 = _write( + tmp_path / "events_2_app-1", + [sql_exec_start({ + "spark.plugins": "com.nvidia.spark.SQLPlugin", + "spark.rapids.sql.enabled": "true", + })], + ) + result = scan_events_across([f1, f2], budget=100) + assert result.termination is Termination.DECISIVE + assert result.last_scanned_path == str(f2) + + def test_shared_budget_applied_across_files(self, tmp_path): + # 3 events in first file, 3 in second. Budget = 4. Second file stops + # after one event, before any GPU signal. + f1 = _write(tmp_path / "events_1_app-1", [log_start(), app_start(), env_update({ + "spark.rapids.sql.enabled": "false", + })]) + f2 = _write( + tmp_path / "events_2_app-1", + [ + sql_exec_start({"spark.master": "still-cpu"}), + sql_exec_start({"spark.plugins": "com.nvidia.spark.SQLPlugin"}), + sql_exec_start({"x": "y"}), + ], + ) + result = scan_events_across([f1, f2], budget=4) + assert result.termination is Termination.CAP_HIT + + def test_all_files_exhausted_returns_exhausted(self, tmp_path): + f1 = _write(tmp_path / "events_1_app-1", [env_update({"spark.master": "local"})]) + result = scan_events_across([f1], budget=100, allow_cpu_fast_path=False) + assert result.termination is Termination.EXHAUSTED diff --git a/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_stream.py b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_stream.py new file mode 100644 index 000000000..e43543ded --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_stream.py @@ -0,0 +1,138 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for ``eventlog_detector.stream``.""" +# pylint: disable=too-few-public-methods # test classes naturally have few methods + +import gzip +from pathlib import Path + +import pytest +import zstandard as zstd + +from spark_rapids_tools.storagelib import CspPath +from spark_rapids_tools.tools.eventlog_detector.stream import open_event_log_stream +from spark_rapids_tools.tools.eventlog_detector.types import ( + EventLogReadError, + UnsupportedCompressionError, +) + + +SAMPLE_LINES = [ + '{"Event":"SparkListenerLogStart","Spark Version":"3.5.1"}', + '{"Event":"SparkListenerApplicationStart","App ID":"app-1"}', + '{"Event":"SparkListenerEnvironmentUpdate","Spark Properties":{}}', +] + + +def _write_plain(path: Path) -> None: + path.write_text("\n".join(SAMPLE_LINES) + "\n", encoding="utf-8") + + +def _write_gz(path: Path) -> None: + with gzip.open(path, "wt", encoding="utf-8") as fh: + fh.write("\n".join(SAMPLE_LINES) + "\n") + + +def _write_zstd(path: Path) -> None: + cctx = zstd.ZstdCompressor() + raw = ("\n".join(SAMPLE_LINES) + "\n").encode("utf-8") + path.write_bytes(cctx.compress(raw)) + + +@pytest.fixture +def plain_file(tmp_path: Path) -> CspPath: + p = tmp_path / "eventlog.inprogress" + _write_plain(p) + return CspPath(str(p)) + + +@pytest.fixture +def gz_file(tmp_path: Path) -> CspPath: + p = tmp_path / "eventlog.gz" + _write_gz(p) + return CspPath(str(p)) + + +@pytest.fixture +def zstd_file(tmp_path: Path) -> CspPath: + p = tmp_path / "eventlog.zstd" + _write_zstd(p) + return CspPath(str(p)) + + +class TestPlainStream: + """Test streaming plain-text event logs.""" + + def test_yields_all_lines(self, plain_file): # pylint: disable=redefined-outer-name + with open_event_log_stream(plain_file) as lines: + collected = list(lines) + assert collected == SAMPLE_LINES + + +class TestGzipStream: + """Test streaming gzip-compressed event logs.""" + + def test_yields_all_lines(self, gz_file): # pylint: disable=redefined-outer-name + with open_event_log_stream(gz_file) as lines: + collected = list(lines) + assert collected == SAMPLE_LINES + + +class TestZstdStream: + """Test streaming zstd-compressed event logs.""" + + def test_yields_all_lines(self, zstd_file): # pylint: disable=redefined-outer-name + with open_event_log_stream(zstd_file) as lines: + collected = list(lines) + assert collected == SAMPLE_LINES + + def test_zst_short_suffix_also_works(self, tmp_path): + p = tmp_path / "eventlog.zst" + _write_zstd(p) + with open_event_log_stream(CspPath(str(p))) as lines: + collected = list(lines) + assert collected == SAMPLE_LINES + + +class TestUnsupportedCompression: + """Test that unsupported compression formats raise UnsupportedCompressionError.""" + + @pytest.mark.parametrize("suffix", [".lz4", ".snappy", ".lzf", ".weirdcodec"]) + def test_unsupported_suffix_raises(self, tmp_path, suffix): + p = tmp_path / f"eventlog{suffix}" + p.write_bytes(b"some-bytes") + with pytest.raises(UnsupportedCompressionError): + with open_event_log_stream(CspPath(str(p))) as _: + pass + + +class TestIoFailure: + """Test that I/O errors raise EventLogReadError.""" + + def test_missing_file_raises_read_error(self, tmp_path): + p = tmp_path / "does-not-exist" + with pytest.raises(EventLogReadError): + with open_event_log_stream(CspPath(str(p))) as lines: + next(iter(lines)) + + def test_caller_side_exception_is_not_reclassified(self, plain_file): # pylint: disable=redefined-outer-name + # Caller-raised exceptions must propagate untouched, not be + # reclassified as EventLogReadError. 
+ class _MarkerError(RuntimeError): + pass + + with pytest.raises(_MarkerError): + with open_event_log_stream(plain_file): + raise _MarkerError("not an I/O failure") diff --git a/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_types.py b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_types.py new file mode 100644 index 000000000..0df975e63 --- /dev/null +++ b/user_tools/tests/spark_rapids_tools_ut/tools/eventlog_detector/test_types.py @@ -0,0 +1,65 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``eventlog_detector.types``.""" +# pylint: disable=too-few-public-methods # test classes naturally have few methods + +from spark_rapids_tools.tools.eventlog_detector.types import ( + EventLogDetectionError, + EventLogReadError, + SparkRuntime, + ToolExecution, + UnsupportedCompressionError, + UnsupportedInputError, +) + + +class TestToolExecution: + """Test the ToolExecution string enum.""" + + def test_has_expected_values(self): + assert {r.value for r in ToolExecution} == { + "QUALIFICATION", + "PROFILING", + "UNKNOWN", + } + + def test_is_string_enum(self): + # str subclass means callers can compare against plain strings. + assert ToolExecution.PROFILING == "PROFILING" + + +class TestSparkRuntime: + """Test the reduced SparkRuntime string enum.""" + + def test_values_cover_spark_and_rapids_only(self): + assert {r.value for r in SparkRuntime} == { + "SPARK", + "SPARK_RAPIDS", + } + + def test_is_string_enum(self): + assert SparkRuntime.SPARK_RAPIDS == "SPARK_RAPIDS" + + +class TestExceptionHierarchy: + """Test that all detector exceptions form a coherent hierarchy.""" + + def test_all_errors_subclass_base(self): + for cls in ( + UnsupportedInputError, + UnsupportedCompressionError, + EventLogReadError, + ): + assert issubclass(cls, EventLogDetectionError)