From 6067be7fa3551467f0ce44ca0dc033619cba97a4 Mon Sep 17 00:00:00 2001 From: Jongsok Choi Date: Wed, 15 Oct 2025 20:31:28 +0000 Subject: [PATCH 1/5] Adding validation for user provided settings. Some invalid settings could make the compiler hang. --- helion/runtime/settings.py | 44 +++++++++++++++++++++++++++++++++++++- test/test_settings.py | 35 ++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 test/test_settings.py diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 55cd02dca..6f069691e 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -6,6 +6,7 @@ import threading import time from typing import TYPE_CHECKING +from typing import Callable from typing import Literal from typing import Protocol from typing import Sequence @@ -15,6 +16,7 @@ from torch._environment import is_fbcode from .. import exc +from ..autotuner.effort_profile import _PROFILES from ..autotuner.effort_profile import AutotuneEffort from ..autotuner.effort_profile import get_effort_profile from .ref_mode import RefMode @@ -127,8 +129,16 @@ def _get_autotune_rebenchmark_threshold() -> float | None: return None # Will use effort profile default +def _normalize_autotune_effort(value: object) -> AutotuneEffort: + if isinstance(value, str): + normalized = value.lower() + if normalized in _PROFILES: + return cast("AutotuneEffort", normalized) + raise ValueError("autotune_effort must be one of 'none', 'quick', or 'full'") + + def _get_autotune_effort() -> AutotuneEffort: - return cast("AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full")) + return _normalize_autotune_effort(os.environ.get("HELION_AUTOTUNE_EFFORT", "full")) def _get_autotune_precompile() -> str | None: @@ -209,6 +219,38 @@ class _Settings: ) autotuner_fn: AutotunerFunction = default_autotuner_fn + def __post_init__(self) -> None: + def _is_bool(val: object) -> bool: + return isinstance(val, bool) + + def _is_non_negative_int(val: 
object) -> bool: + return isinstance(val, int) and val >= 0 + + # Validate user settings + validators: dict[str, Callable[[object], bool]] = { + "autotune_log_level": _is_non_negative_int, + "autotune_compile_timeout": _is_non_negative_int, + "autotune_precompile": _is_bool, + "autotune_precompile_jobs": lambda v: v is None or _is_non_negative_int(v), + "autotune_accuracy_check": _is_bool, + "autotune_progress_bar": _is_bool, + "autotune_max_generations": lambda v: v is None or _is_non_negative_int(v), + "print_output_code": _is_bool, + "force_autotune": _is_bool, + "allow_warp_specialize": _is_bool, + "debug_dtype_asserts": _is_bool, + "autotune_rebenchmark_threshold": lambda v: v is None + or (isinstance(v, (int, float)) and v >= 0), + } + + normalized_effort = _normalize_autotune_effort(self.autotune_effort) + object.__setattr__(self, "autotune_effort", normalized_effort) + + for field_name, checker in validators.items(): + value = getattr(self, field_name) + if not checker(value): + raise ValueError(f"Invalid value for {field_name}: {value!r}") + class Settings(_Settings): """ diff --git a/test/test_settings.py b/test/test_settings.py new file mode 100644 index 000000000..b506cd126 --- /dev/null +++ b/test/test_settings.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import unittest + +import helion + + +class TestSettingsValidation(unittest.TestCase): + def test_autotune_effort_none_raises(self) -> None: + with self.assertRaisesRegex( + ValueError, "autotune_effort must be one of 'none', 'quick', or 'full'" + ): + helion.Settings(autotune_effort=None) + + def test_autotune_effort_quick_normalized(self) -> None: + settings = helion.Settings(autotune_effort="Quick") + self.assertEqual(settings.autotune_effort, "quick") + + def test_negative_compile_timeout_raises(self) -> None: + with self.assertRaisesRegex( + ValueError, r"Invalid value for autotune_compile_timeout: -1" + ): + helion.Settings(autotune_compile_timeout=-1) + + def 
test_autotune_precompile_jobs_negative_raises(self) -> None: + with self.assertRaisesRegex( + ValueError, r"Invalid value for autotune_precompile_jobs: -1" + ): + helion.Settings(autotune_precompile_jobs=-1) + + def test_autotune_max_generations_negative_raises(self) -> None: + with self.assertRaisesRegex( + ValueError, r"Invalid value for autotune_max_generations: -1" + ): + helion.Settings(autotune_max_generations=-1) From b102caa4903988f9dcf62d5b137005e646f835db Mon Sep 17 00:00:00 2001 From: Jongsok Choi Date: Wed, 15 Oct 2025 20:34:04 +0000 Subject: [PATCH 2/5] Let user know which autotune_effort is being used. --- helion/autotuner/base_cache.py | 5 ++++- helion/runtime/kernel.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/helion/autotuner/base_cache.py b/helion/autotuner/base_cache.py index d6f35f8ae..0aeca2c71 100644 --- a/helion/autotuner/base_cache.py +++ b/helion/autotuner/base_cache.py @@ -178,7 +178,10 @@ def autotune(self, *, skip_cache: bool = False) -> Config: counters["autotune"]["cache_miss"] += 1 log.debug("cache miss") - self.autotuner.log("Starting autotuning process, this may take a while...") + effort = self.kernel.settings.autotune_effort + self.autotuner.log( + f"Starting autotuning process with effort={effort}, this may take a while..." 
+ ) config = self.autotuner.autotune() diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index b875caebf..2c8260cdd 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -578,7 +578,7 @@ def _implicit_config(self) -> Config | None: if not is_ref_mode_enabled(self.kernel.settings): kernel_decorator = self.format_kernel_decorator(config, self.settings) print( - f"Using default config: {kernel_decorator}", + f"Using default config (autotune_effort=none): {kernel_decorator}", file=sys.stderr, ) return config From b0e978aededd9bac611e3e69f612003a075bcf9a Mon Sep 17 00:00:00 2001 From: Jongsok Choi Date: Wed, 15 Oct 2025 22:08:49 +0000 Subject: [PATCH 3/5] Add assert to ensure the loop doesn't get stuck. --- helion/autotuner/base_search.py | 1 + 1 file changed, 1 insertion(+) diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index 69ebdbc89..44c79bca3 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -952,6 +952,7 @@ def _wait_for_all_step( ) -> list[PrecompileFuture]: """Start up to the concurrency cap, wait for progress, and return remaining futures.""" cap = futures[0].search._jobs if futures else 1 + assert cap > 0, "autotune_precompile_jobs must be positive" running = [f for f in futures if f.started and f.ok is None and f.is_alive()] # Start queued futures up to the cap From 79d46ec1d474c35b1de1df5d74069d4aef70160c Mon Sep 17 00:00:00 2001 From: Jongsok Choi Date: Wed, 15 Oct 2025 22:09:59 +0000 Subject: [PATCH 4/5] Add the new options for autotune_precompile to be allowed. 
--- helion/runtime/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 6f069691e..2655b8139 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -230,7 +230,7 @@ def _is_non_negative_int(val: object) -> bool: validators: dict[str, Callable[[object], bool]] = { "autotune_log_level": _is_non_negative_int, "autotune_compile_timeout": _is_non_negative_int, - "autotune_precompile": _is_bool, + "autotune_precompile": lambda v: v in (None, "spawn", "fork"), "autotune_precompile_jobs": lambda v: v is None or _is_non_negative_int(v), "autotune_accuracy_check": _is_bool, "autotune_progress_bar": _is_bool, From 388fa824b4ab25b938ab20ae1460c997484303a0 Mon Sep 17 00:00:00 2001 From: Jongsok Choi Date: Thu, 16 Oct 2025 05:28:26 +0000 Subject: [PATCH 5/5] Code simplifications for getting and validating user settings. --- helion/runtime/settings.py | 173 +++++++++++++++++++------------------ test/test_settings.py | 10 ++- 2 files changed, 94 insertions(+), 89 deletions(-) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 2655b8139..d8a21a447 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -7,6 +7,7 @@ import time from typing import TYPE_CHECKING from typing import Callable +from typing import Collection from typing import Literal from typing import Protocol from typing import Sequence @@ -36,6 +37,45 @@ def __call__( ) -> BaseAutotuner: ... +def _validate_enum_setting( + value: object, + *, + name: str, + valid: Collection[str], + allow_none: bool = True, +) -> str | None: + """Normalize and validate an enum setting. + + Args: + value: The value to normalize and validate + name: Name of the setting + valid: Collection of valid settings + allow_none: If True, None and _NONE_VALUES strings return None. If False, they raise an error. 
+ """ + # String values that should be treated as None + _NONE_VALUES = frozenset({"", "0", "false", "none"}) + + # Normalize values + normalized: str | None + if isinstance(value, str): + normalized = value.strip().lower() + else: + normalized = None + + is_none_value = normalized is None or normalized in _NONE_VALUES + is_valid = normalized in valid if normalized else False + + # Valid value (none or valid setting) + if is_none_value and allow_none: + return None + if is_valid: + return normalized + + # Invalid value, raise error + valid_list = "', '".join(sorted(valid)) + raise ValueError(f"{name} must be one of '{valid_list}', got {value!r}") + + _tls: _TLS = cast("_TLS", threading.local()) @@ -108,63 +148,6 @@ def default_autotuner_fn( return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType] -def _get_autotune_random_seed() -> int: - value = os.environ.get("HELION_AUTOTUNE_RANDOM_SEED") - if value is not None: - return int(value) - return int(time.time() * 1000) % 2**32 - - -def _get_autotune_max_generations() -> int | None: - value = os.environ.get("HELION_AUTOTUNE_MAX_GENERATIONS") - if value is not None: - return int(value) - return None - - -def _get_autotune_rebenchmark_threshold() -> float | None: - value = os.environ.get("HELION_REBENCHMARK_THRESHOLD") - if value is not None: - return float(value) - return None # Will use effort profile default - - -def _normalize_autotune_effort(value: object) -> AutotuneEffort: - if isinstance(value, str): - normalized = value.lower() - if normalized in _PROFILES: - return cast("AutotuneEffort", normalized) - raise ValueError("autotune_effort must be one of 'none', 'quick', or 'full'") - - -def _get_autotune_effort() -> AutotuneEffort: - return _normalize_autotune_effort(os.environ.get("HELION_AUTOTUNE_EFFORT", "full")) - - -def _get_autotune_precompile() -> str | None: - value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE") - if value is None: - return "spawn" - mode 
= value.strip().lower() - if mode in {"", "0", "false", "none"}: - return None - if mode in {"spawn", "fork"}: - return mode - raise ValueError( - "HELION_AUTOTUNE_PRECOMPILE must be 'spawn', 'fork', or empty to disable precompile" - ) - - -def _get_autotune_precompile_jobs() -> int | None: - value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS") - if value is None or value.strip() == "": - return None - jobs = int(value) - if jobs <= 0: - raise ValueError("HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer") - return jobs - - @dataclasses.dataclass class _Settings: # see __slots__ below for the doc strings that show up in help(Settings) @@ -182,25 +165,35 @@ class _Settings: os.environ.get("HELION_AUTOTUNE_COMPILE_TIMEOUT", "60") ) autotune_precompile: str | None = dataclasses.field( - default_factory=_get_autotune_precompile + default_factory=lambda: os.environ.get("HELION_AUTOTUNE_PRECOMPILE", "spawn") ) autotune_precompile_jobs: int | None = dataclasses.field( - default_factory=_get_autotune_precompile_jobs + default_factory=lambda: int(v) + if (v := os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS")) + else None ) autotune_random_seed: int = dataclasses.field( - default_factory=_get_autotune_random_seed + default_factory=lambda: ( + int(v) + if (v := os.environ.get("HELION_AUTOTUNE_RANDOM_SEED")) + else int(time.time() * 1000) % 2**32 + ) ) autotune_accuracy_check: bool = ( os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1" ) autotune_rebenchmark_threshold: float | None = dataclasses.field( - default_factory=_get_autotune_rebenchmark_threshold + default_factory=lambda: float(v) + if (v := os.environ.get("HELION_REBENCHMARK_THRESHOLD")) + else None ) autotune_progress_bar: bool = ( os.environ.get("HELION_AUTOTUNE_PROGRESS_BAR", "1") == "1" ) autotune_max_generations: int | None = dataclasses.field( - default_factory=_get_autotune_max_generations + default_factory=lambda: int(v) + if (v := os.environ.get("HELION_AUTOTUNE_MAX_GENERATIONS")) + 
else None ) print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1" force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1" @@ -208,7 +201,9 @@ class _Settings: default_factory=dict ) autotune_effort: AutotuneEffort = dataclasses.field( - default_factory=_get_autotune_effort + default_factory=lambda: cast( + "AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full") + ) ) allow_warp_specialize: bool = ( os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1" @@ -220,35 +215,43 @@ class _Settings: autotuner_fn: AutotunerFunction = default_autotuner_fn def __post_init__(self) -> None: - def _is_bool(val: object) -> bool: - return isinstance(val, bool) - - def _is_non_negative_int(val: object) -> bool: - return isinstance(val, int) and val >= 0 + # Validate all user settings + + self.autotune_effort = cast( + "AutotuneEffort", + _validate_enum_setting( + self.autotune_effort, + name="autotune_effort", + valid=_PROFILES.keys(), + allow_none=False, # do not allow None as "none" is a non-default setting + ), + ) + self.autotune_precompile = _validate_enum_setting( + self.autotune_precompile, + name="autotune_precompile", + valid={"spawn", "fork"}, + ) - # Validate user settings validators: dict[str, Callable[[object], bool]] = { - "autotune_log_level": _is_non_negative_int, - "autotune_compile_timeout": _is_non_negative_int, - "autotune_precompile": lambda v: v in (None, "spawn", "fork"), - "autotune_precompile_jobs": lambda v: v is None or _is_non_negative_int(v), - "autotune_accuracy_check": _is_bool, - "autotune_progress_bar": _is_bool, - "autotune_max_generations": lambda v: v is None or _is_non_negative_int(v), - "print_output_code": _is_bool, - "force_autotune": _is_bool, - "allow_warp_specialize": _is_bool, - "debug_dtype_asserts": _is_bool, + "autotune_log_level": lambda v: isinstance(v, int) and v >= 0, + "autotune_compile_timeout": lambda v: isinstance(v, int) and v > 0, + "autotune_precompile_jobs": lambda 
v: v is None + or (isinstance(v, int) and v > 0), + "autotune_accuracy_check": lambda v: isinstance(v, bool), + "autotune_progress_bar": lambda v: isinstance(v, bool), + "autotune_max_generations": lambda v: v is None + or (isinstance(v, int) and v >= 0), + "print_output_code": lambda v: isinstance(v, bool), + "force_autotune": lambda v: isinstance(v, bool), + "allow_warp_specialize": lambda v: isinstance(v, bool), + "debug_dtype_asserts": lambda v: isinstance(v, bool), "autotune_rebenchmark_threshold": lambda v: v is None or (isinstance(v, (int, float)) and v >= 0), } - normalized_effort = _normalize_autotune_effort(self.autotune_effort) - object.__setattr__(self, "autotune_effort", normalized_effort) - - for field_name, checker in validators.items(): + for field_name, validator in validators.items(): value = getattr(self, field_name) - if not checker(value): + if not validator(value): raise ValueError(f"Invalid value for {field_name}: {value!r}") diff --git a/test/test_settings.py b/test/test_settings.py index b506cd126..670775488 100644 --- a/test/test_settings.py +++ b/test/test_settings.py @@ -7,12 +7,10 @@ class TestSettingsValidation(unittest.TestCase): def test_autotune_effort_none_raises(self) -> None: - with self.assertRaisesRegex( - ValueError, "autotune_effort must be one of 'none', 'quick', or 'full'" - ): + with self.assertRaisesRegex(ValueError, "autotune_effort must be one of"): helion.Settings(autotune_effort=None) - def test_autotune_effort_quick_normalized(self) -> None: + def test_autotune_effort_case_insensitive(self) -> None: settings = helion.Settings(autotune_effort="Quick") self.assertEqual(settings.autotune_effort, "quick") @@ -33,3 +31,7 @@ def test_autotune_max_generations_negative_raises(self) -> None: ValueError, r"Invalid value for autotune_max_generations: -1" ): helion.Settings(autotune_max_generations=-1) + + def test_autotune_effort_invalid_raises(self) -> None: + with self.assertRaisesRegex(ValueError, "autotune_effort must be 
one of"): + helion.Settings(autotune_effort="super-fast")