|
@@ -7,6 +7,7 @@
 import threading
 import time
 from typing import TYPE_CHECKING
+from typing import Callable
 from typing import Literal
 from typing import Protocol
 from typing import Sequence
@@ -16,6 +17,7 @@
 from torch._environment import is_fbcode

 from helion import exc
+from helion.autotuner.effort_profile import _PROFILES
 from helion.autotuner.effort_profile import AutotuneEffort
 from helion.autotuner.effort_profile import get_effort_profile
 from helion.runtime.ref_mode import RefMode
@@ -128,8 +130,16 @@ def _get_autotune_rebenchmark_threshold() -> float | None:
     return None  # Will use effort profile default


+def _normalize_autotune_effort(value: object) -> AutotuneEffort:
+    if isinstance(value, str):
+        normalized = value.lower()
+        if normalized in _PROFILES:
+            return cast("AutotuneEffort", normalized)
+    raise ValueError("autotune_effort must be one of 'none', 'quick', or 'full'")
+
+
 def _get_autotune_effort() -> AutotuneEffort:
-    return cast("AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))
+    return _normalize_autotune_effort(os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))


 @dataclasses.dataclass
@@ -182,6 +192,38 @@ class _Settings:
     )
     autotuner_fn: AutotunerFunction = default_autotuner_fn

+    def __post_init__(self) -> None:
+        def _is_bool(val: object) -> bool:
+            return isinstance(val, bool)
+
+        def _is_non_negative_int(val: object) -> bool:
+            return isinstance(val, int) and val >= 0
+
+        # Validate user settings
+        validators: dict[str, Callable[[object], bool]] = {
+            "autotune_log_level": _is_non_negative_int,
+            "autotune_compile_timeout": _is_non_negative_int,
+            "autotune_precompile": _is_bool,
+            "autotune_precompile_jobs": lambda v: v is None or _is_non_negative_int(v),
+            "autotune_accuracy_check": _is_bool,
+            "autotune_progress_bar": _is_bool,
+            "autotune_max_generations": lambda v: v is None or _is_non_negative_int(v),
+            "print_output_code": _is_bool,
+            "force_autotune": _is_bool,
+            "allow_warp_specialize": _is_bool,
+            "debug_dtype_asserts": _is_bool,
+            "autotune_rebenchmark_threshold": lambda v: v is None
+            or (isinstance(v, (int, float)) and v >= 0),
+        }
+
+        normalized_effort = _normalize_autotune_effort(self.autotune_effort)
+        object.__setattr__(self, "autotune_effort", normalized_effort)
+
+        for field_name, checker in validators.items():
+            value = getattr(self, field_name)
+            if not checker(value):
+                raise ValueError(f"Invalid value for {field_name}: {value!r}")
+

 class Settings(_Settings):
     """
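A minimal standalone sketch of the normalization behavior the first hunk adds, assuming `_PROFILES` (from `helion.autotuner.effort_profile`) is a dict keyed by the profile names `"none"`, `"quick"`, and `"full"`; the stub dict below stands in for it so the snippet runs on its own:

```python
# Standalone sketch; this _PROFILES stub stands in for
# helion.autotuner.effort_profile._PROFILES, assumed to be keyed by
# the profile names "none", "quick", and "full".
_PROFILES = {"none": object(), "quick": object(), "full": object()}


def normalize_autotune_effort(value: object) -> str:
    # Accept any case, but only the known profile names.
    if isinstance(value, str):
        normalized = value.lower()
        if normalized in _PROFILES:
            return normalized
    raise ValueError("autotune_effort must be one of 'none', 'quick', or 'full'")


assert normalize_autotune_effort("QUICK") == "quick"  # case-insensitive
assert normalize_autotune_effort("full") == "full"
try:
    normalize_autotune_effort("medium")  # unknown profile name
except ValueError as err:
    print(err)  # autotune_effort must be one of 'none', 'quick', or 'full'
```

The behavioral change in `_get_autotune_effort` is that a value like `HELION_AUTOTUNE_EFFORT=QUICK` now resolves to the `quick` profile instead of being cast and passed through unchecked.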
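Similarly, a sketch of what the new `__post_init__` hook enforces at construction time; `MiniSettings` below is a hypothetical two-field subset of the real `Settings` class, purely for illustration:

```python
import dataclasses
from typing import Callable


@dataclasses.dataclass
class MiniSettings:
    # Hypothetical subset of the real Settings fields, for illustration only.
    autotune_compile_timeout: int = 60
    autotune_progress_bar: bool = True

    def __post_init__(self) -> None:
        # Same pattern as the diff: map field names to predicates and
        # reject the first value that fails its check.
        validators: dict[str, Callable[[object], bool]] = {
            "autotune_compile_timeout": lambda v: isinstance(v, int) and v >= 0,
            "autotune_progress_bar": lambda v: isinstance(v, bool),
        }
        for field_name, checker in validators.items():
            value = getattr(self, field_name)
            if not checker(value):
                raise ValueError(f"Invalid value for {field_name}: {value!r}")


MiniSettings(autotune_compile_timeout=120)  # constructs fine
try:
    MiniSettings(autotune_compile_timeout=-5)  # rejected eagerly
except ValueError as err:
    print(err)  # Invalid value for autotune_compile_timeout: -5
```

Note that the real diff writes the normalized effort back with `object.__setattr__`, presumably to bypass a frozen dataclass or a custom `__setattr__`; the sketch above skips that step since `MiniSettings` only validates.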