Skip to content

Commit 4eef7a9

Browse files
committed
Adding validation for user provided settings.
Some invalid settings could make the compiler hang.
1 parent 4d41dcc commit 4eef7a9

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

helion/runtime/settings.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import threading
88
import time
99
from typing import TYPE_CHECKING
10+
from typing import Callable
1011
from typing import Literal
1112
from typing import Protocol
1213
from typing import Sequence
@@ -16,6 +17,7 @@
1617
from torch._environment import is_fbcode
1718

1819
from helion import exc
20+
from helion.autotuner.effort_profile import _PROFILES
1921
from helion.autotuner.effort_profile import AutotuneEffort
2022
from helion.autotuner.effort_profile import get_effort_profile
2123
from helion.runtime.ref_mode import RefMode
@@ -128,8 +130,16 @@ def _get_autotune_rebenchmark_threshold() -> float | None:
128130
return None # Will use effort profile default
129131

130132

133+
def _normalize_autotune_effort(value: object) -> AutotuneEffort:
134+
if isinstance(value, str):
135+
normalized = value.lower()
136+
if normalized in _PROFILES:
137+
return cast("AutotuneEffort", normalized)
138+
raise ValueError("autotune_effort must be one of 'none', 'quick', or 'full'")
139+
140+
131141
def _get_autotune_effort() -> AutotuneEffort:
132-
return cast("AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))
142+
return _normalize_autotune_effort(os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))
133143

134144

135145
@dataclasses.dataclass
@@ -182,6 +192,38 @@ class _Settings:
182192
)
183193
autotuner_fn: AutotunerFunction = default_autotuner_fn
184194

195+
def __post_init__(self) -> None:
196+
def _is_bool(val: object) -> bool:
197+
return isinstance(val, bool)
198+
199+
def _is_non_negative_int(val: object) -> bool:
200+
return isinstance(val, int) and val >= 0
201+
202+
# Validate user settings
203+
validators: dict[str, Callable[[object], bool]] = {
204+
"autotune_log_level": _is_non_negative_int,
205+
"autotune_compile_timeout": _is_non_negative_int,
206+
"autotune_precompile": _is_bool,
207+
"autotune_precompile_jobs": lambda v: v is None or _is_non_negative_int(v),
208+
"autotune_accuracy_check": _is_bool,
209+
"autotune_progress_bar": _is_bool,
210+
"autotune_max_generations": lambda v: v is None or _is_non_negative_int(v),
211+
"print_output_code": _is_bool,
212+
"force_autotune": _is_bool,
213+
"allow_warp_specialize": _is_bool,
214+
"debug_dtype_asserts": _is_bool,
215+
"autotune_rebenchmark_threshold": lambda v: v is None
216+
or (isinstance(v, (int, float)) and v >= 0),
217+
}
218+
219+
normalized_effort = _normalize_autotune_effort(self.autotune_effort)
220+
object.__setattr__(self, "autotune_effort", normalized_effort)
221+
222+
for field_name, checker in validators.items():
223+
value = getattr(self, field_name)
224+
if not checker(value):
225+
raise ValueError(f"Invalid value for {field_name}: {value!r}")
226+
185227

186228
class Settings(_Settings):
187229
"""

test/test_settings.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
5+
import helion
6+
7+
8+
class TestSettingsValidation(unittest.TestCase):
9+
def test_autotune_effort_none_raises(self) -> None:
10+
with self.assertRaisesRegex(
11+
ValueError, "autotune_effort must be one of 'none', 'quick', or 'full'"
12+
):
13+
helion.Settings(autotune_effort=None)
14+
15+
def test_autotune_effort_quick_normalized(self) -> None:
16+
settings = helion.Settings(autotune_effort="Quick")
17+
self.assertEqual(settings.autotune_effort, "quick")
18+
19+
def test_negative_compile_timeout_raises(self) -> None:
20+
with self.assertRaisesRegex(
21+
ValueError, r"Invalid value for autotune_compile_timeout: -1"
22+
):
23+
helion.Settings(autotune_compile_timeout=-1)
24+
25+
def test_autotune_precompile_jobs_negative_raises(self) -> None:
26+
with self.assertRaisesRegex(
27+
ValueError, r"Invalid value for autotune_precompile_jobs: -1"
28+
):
29+
helion.Settings(autotune_precompile_jobs=-1)
30+
31+
def test_autotune_max_generations_negative_raises(self) -> None:
32+
with self.assertRaisesRegex(
33+
ValueError, r"Invalid value for autotune_max_generations: -1"
34+
):
35+
helion.Settings(autotune_max_generations=-1)

0 commit comments

Comments
 (0)