diff --git a/helion/autotuner/base_cache.py b/helion/autotuner/base_cache.py
index d6f35f8ae..0aeca2c71 100644
--- a/helion/autotuner/base_cache.py
+++ b/helion/autotuner/base_cache.py
@@ -178,7 +178,10 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
         counters["autotune"]["cache_miss"] += 1
         log.debug("cache miss")
-        self.autotuner.log("Starting autotuning process, this may take a while...")
+        effort = self.kernel.settings.autotune_effort
+        self.autotuner.log(
+            f"Starting autotuning process with effort={effort}, this may take a while..."
+        )
         config = self.autotuner.autotune()
diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index 69ebdbc89..44c79bca3 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -952,6 +952,7 @@ def _wait_for_all_step(
     ) -> list[PrecompileFuture]:
         """Start up to the concurrency cap, wait for progress, and return remaining futures."""
         cap = futures[0].search._jobs if futures else 1
+        assert cap > 0, "autotune_precompile_jobs must be positive"
         running = [f for f in futures if f.started and f.ok is None and f.is_alive()]
 
         # Start queued futures up to the cap
diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py
index b875caebf..2c8260cdd 100644
--- a/helion/runtime/kernel.py
+++ b/helion/runtime/kernel.py
@@ -578,7 +578,7 @@ def _implicit_config(self) -> Config | None:
         if not is_ref_mode_enabled(self.kernel.settings):
             kernel_decorator = self.format_kernel_decorator(config, self.settings)
             print(
-                f"Using default config: {kernel_decorator}",
+                f"Using default config (autotune_effort=none): {kernel_decorator}",
                 file=sys.stderr,
             )
         return config
diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py
index 55cd02dca..d8a21a447 100644
--- a/helion/runtime/settings.py
+++ b/helion/runtime/settings.py
@@ -6,6 +6,8 @@
 import threading
 import time
 from typing import TYPE_CHECKING
+from typing import Callable
+from typing import Collection
 from typing import Literal
 from typing import Protocol
 from typing import Sequence
@@ -15,6 +17,7 @@
 from torch._environment import is_fbcode
 
 from .. import exc
+from ..autotuner.effort_profile import _PROFILES
 from ..autotuner.effort_profile import AutotuneEffort
 from ..autotuner.effort_profile import get_effort_profile
 from .ref_mode import RefMode
@@ -34,6 +37,45 @@ def __call__(
     ) -> BaseAutotuner: ...
 
 
+def _validate_enum_setting(
+    value: object,
+    *,
+    name: str,
+    valid: Collection[str],
+    allow_none: bool = True,
+) -> str | None:
+    """Normalize and validate an enum setting.
+
+    Args:
+        value: The value to normalize and validate
+        name: Name of the setting
+        valid: Collection of valid settings
+        allow_none: If True, None and _NONE_VALUES strings return None. If False, they raise an error.
+    """
+    # String values that should be treated as None
+    _NONE_VALUES = frozenset({"", "0", "false", "none"})
+
+    # Normalize values
+    normalized: str | None
+    if isinstance(value, str):
+        normalized = value.strip().lower()
+    else:
+        normalized = None
+
+    is_none_value = normalized is None or normalized in _NONE_VALUES
+    is_valid = normalized in valid if normalized else False
+
+    # Valid value (none or valid setting)
+    if is_none_value and allow_none:
+        return None
+    if is_valid:
+        return normalized
+
+    # Invalid value, raise error
+    valid_list = "', '".join(sorted(valid))
+    raise ValueError(f"{name} must be one of '{valid_list}', got {value!r}")
+
+
 _tls: _TLS = cast("_TLS", threading.local())
@@ -106,55 +148,6 @@ def default_autotuner_fn(
     return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs))  # pyright: ignore[reportArgumentType]
 
 
-def _get_autotune_random_seed() -> int:
-    value = os.environ.get("HELION_AUTOTUNE_RANDOM_SEED")
-    if value is not None:
-        return int(value)
-    return int(time.time() * 1000) % 2**32
-
-
-def _get_autotune_max_generations() -> int | None:
-    value = os.environ.get("HELION_AUTOTUNE_MAX_GENERATIONS")
-    if value is not None:
-        return int(value)
-    return None
-
-
-def _get_autotune_rebenchmark_threshold() -> float | None:
-    value = os.environ.get("HELION_REBENCHMARK_THRESHOLD")
-    if value is not None:
-        return float(value)
-    return None  # Will use effort profile default
-
-
-def _get_autotune_effort() -> AutotuneEffort:
-    return cast("AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))
-
-
-def _get_autotune_precompile() -> str | None:
-    value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE")
-    if value is None:
-        return "spawn"
-    mode = value.strip().lower()
-    if mode in {"", "0", "false", "none"}:
-        return None
-    if mode in {"spawn", "fork"}:
-        return mode
-    raise ValueError(
-        "HELION_AUTOTUNE_PRECOMPILE must be 'spawn', 'fork', or empty to disable precompile"
-    )
-
-
-def _get_autotune_precompile_jobs() -> int | None:
-    value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS")
-    if value is None or value.strip() == "":
-        return None
-    jobs = int(value)
-    if jobs <= 0:
-        raise ValueError("HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer")
-    return jobs
-
-
 @dataclasses.dataclass
 class _Settings:
     # see __slots__ below for the doc strings that show up in help(Settings)
@@ -172,25 +165,35 @@ class _Settings:
         os.environ.get("HELION_AUTOTUNE_COMPILE_TIMEOUT", "60")
     )
     autotune_precompile: str | None = dataclasses.field(
-        default_factory=_get_autotune_precompile
+        default_factory=lambda: os.environ.get("HELION_AUTOTUNE_PRECOMPILE", "spawn")
     )
     autotune_precompile_jobs: int | None = dataclasses.field(
-        default_factory=_get_autotune_precompile_jobs
+        default_factory=lambda: int(v)
+        if (v := os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS"))
+        else None
     )
     autotune_random_seed: int = dataclasses.field(
-        default_factory=_get_autotune_random_seed
+        default_factory=lambda: (
+            int(v)
+            if (v := os.environ.get("HELION_AUTOTUNE_RANDOM_SEED"))
+            else int(time.time() * 1000) % 2**32
+        )
     )
     autotune_accuracy_check: bool = (
         os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1"
     )
     autotune_rebenchmark_threshold: float | None = dataclasses.field(
-        default_factory=_get_autotune_rebenchmark_threshold
+        default_factory=lambda: float(v)
+        if (v := os.environ.get("HELION_REBENCHMARK_THRESHOLD"))
+        else None
     )
     autotune_progress_bar: bool = (
         os.environ.get("HELION_AUTOTUNE_PROGRESS_BAR", "1") == "1"
     )
     autotune_max_generations: int | None = dataclasses.field(
-        default_factory=_get_autotune_max_generations
+        default_factory=lambda: int(v)
+        if (v := os.environ.get("HELION_AUTOTUNE_MAX_GENERATIONS"))
+        else None
     )
     print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
     force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
@@ -198,7 +201,9 @@ class _Settings:
         default_factory=dict
     )
     autotune_effort: AutotuneEffort = dataclasses.field(
-        default_factory=_get_autotune_effort
+        default_factory=lambda: cast(
+            "AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full")
+        )
     )
     allow_warp_specialize: bool = (
         os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1"
     )
@@ -209,6 +214,46 @@ class _Settings:
     )
     autotuner_fn: AutotunerFunction = default_autotuner_fn
+    def __post_init__(self) -> None:
+        # Validate all user settings
+
+        self.autotune_effort = cast(
+            "AutotuneEffort",
+            _validate_enum_setting(
+                self.autotune_effort,
+                name="autotune_effort",
+                valid=_PROFILES.keys(),
+                allow_none=False,  # do not allow None as "none" is a non-default setting
+            ),
+        )
+        self.autotune_precompile = _validate_enum_setting(
+            self.autotune_precompile,
+            name="autotune_precompile",
+            valid={"spawn", "fork"},
+        )
+
+        validators: dict[str, Callable[[object], bool]] = {
+            "autotune_log_level": lambda v: isinstance(v, int) and v >= 0,
+            "autotune_compile_timeout": lambda v: isinstance(v, int) and v > 0,
+            "autotune_precompile_jobs": lambda v: v is None
+            or (isinstance(v, int) and v > 0),
+            "autotune_accuracy_check": lambda v: isinstance(v, bool),
+            "autotune_progress_bar": lambda v: isinstance(v, bool),
+            "autotune_max_generations": lambda v: v is None
+            or (isinstance(v, int) and v >= 0),
+            "print_output_code": lambda v: isinstance(v, bool),
+            "force_autotune": lambda v: isinstance(v, bool),
+            "allow_warp_specialize": lambda v: isinstance(v, bool),
+            "debug_dtype_asserts": lambda v: isinstance(v, bool),
+            "autotune_rebenchmark_threshold": lambda v: v is None
+            or (isinstance(v, (int, float)) and v >= 0),
+        }
+
+        for field_name, validator in validators.items():
+            value = getattr(self, field_name)
+            if not validator(value):
+                raise ValueError(f"Invalid value for {field_name}: {value!r}")
+
 
 class Settings(_Settings):
     """
diff --git a/test/test_settings.py b/test/test_settings.py
new file mode 100644
index 000000000..670775488
--- /dev/null
+++ b/test/test_settings.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+import unittest
+
+import helion
+
+
+class TestSettingsValidation(unittest.TestCase):
+    def test_autotune_effort_none_raises(self) -> None:
+        with self.assertRaisesRegex(ValueError, "autotune_effort must be one of"):
+            helion.Settings(autotune_effort=None)
+
+    def test_autotune_effort_case_insensitive(self) -> None:
+        settings = helion.Settings(autotune_effort="Quick")
+        self.assertEqual(settings.autotune_effort, "quick")
+
+    def test_negative_compile_timeout_raises(self) -> None:
+        with self.assertRaisesRegex(
+            ValueError, r"Invalid value for autotune_compile_timeout: -1"
+        ):
+            helion.Settings(autotune_compile_timeout=-1)
+
+    def test_autotune_precompile_jobs_negative_raises(self) -> None:
+        with self.assertRaisesRegex(
+            ValueError, r"Invalid value for autotune_precompile_jobs: -1"
+        ):
+            helion.Settings(autotune_precompile_jobs=-1)
+
+    def test_autotune_max_generations_negative_raises(self) -> None:
+        with self.assertRaisesRegex(
+            ValueError, r"Invalid value for autotune_max_generations: -1"
+        ):
+            helion.Settings(autotune_max_generations=-1)
+
+    def test_autotune_effort_invalid_raises(self) -> None:
+        with self.assertRaisesRegex(ValueError, "autotune_effort must be one of"):
+            helion.Settings(autotune_effort="super-fast")
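
Illustrative usage sketch (not part of the patch): the snippet below exercises the validation semantics that _validate_enum_setting and _Settings.__post_init__ introduce above, mirroring the new tests. It assumes this branch is installed and that the effort profiles in _PROFILES include a "none" entry, as the allow_none=False comment implies.

# Sketch only -- illustrates the validation added in this diff; not part of the patch.
import helion
from helion.runtime.settings import _validate_enum_setting

# allow_none=True (the default): "disable" strings and None all normalize to None.
assert _validate_enum_setting("Fork ", name="autotune_precompile", valid={"spawn", "fork"}) == "fork"
assert _validate_enum_setting("false", name="autotune_precompile", valid={"spawn", "fork"}) is None
assert _validate_enum_setting(None, name="autotune_precompile", valid={"spawn", "fork"}) is None

# allow_none=False (used for autotune_effort): the string "none" passes through as a
# real profile name (assumed to exist in _PROFILES), while Python None is rejected.
settings = helion.Settings(autotune_precompile="FALSE", autotune_effort="NONE")
assert settings.autotune_precompile is None
assert settings.autotune_effort == "none"

The Settings() calls mirror the new tests: invalid values such as autotune_precompile_jobs=-1 are now rejected by the shared __post_init__ validators instead of ad-hoc checks inside the removed _get_* helpers.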
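
A second sketch, with the same caveats, showing how the inlined environment-variable default_factory lambdas feed the same validation path. It assumes no default-settings context is active and that "quick" is a valid effort profile (as the case-insensitivity test suggests).

# Sketch only -- environment-driven defaults flow through __post_init__ as well.
import os
import helion

os.environ["HELION_AUTOTUNE_PRECOMPILE"] = "0"   # a "disable" string, normalized to None
os.environ["HELION_AUTOTUNE_EFFORT"] = "QUICK"   # normalized to "quick"
settings = helion.Settings()
assert settings.autotune_precompile is None
assert settings.autotune_effort == "quick"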