Commit 0bafd91

Add settings.autotune_baseline_fn to allow passing a custom baseline function to the autotuner (#1054)

1 parent 2cb6d17 commit 0bafd91
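
The new setting lets the autotuner check each candidate config for accuracy against a user-supplied reference function rather than a run of the kernel under its default config. A minimal usage sketch, adapted from the tests added in this commit (the configs, tensor sizes, and device string are illustrative, and `hl` is assumed to be the usual `helion.language` alias):

    import torch
    import helion
    import helion.language as hl

    def baseline_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Eager PyTorch reference; must accept the same arguments as the kernel.
        return a + b

    @helion.kernel(
        configs=[
            helion.Config(block_sizes=[32], num_warps=4),
            helion.Config(block_sizes=[64], num_warps=8),
        ],
        autotune_baseline_fn=baseline_add,
    )
    def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(a)
        for tile in hl.tile(out.size()):
            out[tile] = a[tile] + b[tile]
        return out

    # Autotuning validates each candidate config against baseline_add(a, b).
    a = torch.randn([128], device="cuda")
    b = torch.randn([128], device="cuda")
    result = add(a, b)

If the baseline function itself raises, autotuning aborts with helion.exc.AutotuneError (see the base_search.py change below).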

File tree

helion/autotuner/base_search.py
helion/runtime/settings.py
test/test_autotuner.py

3 files changed: +186 -22 lines changed

helion/autotuner/base_search.py

Lines changed: 40 additions & 22 deletions
@@ -151,31 +151,49 @@ def _clone_leaf(leaf: object) -> object:
 
     def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
         """
-        Return output and post-run input arguments of the default-config kernel.
+        Compute baseline output for accuracy validation during autotuning.
         Also detect if the kernel mutates any of its input arguments.
+
+        The baseline is computed in one of two ways:
+        - If settings.autotune_baseline_fn is provided, use that custom function
+        - Otherwise, run the kernel with the default config
         """
         new_args = self._clone_args(self._original_args)
-        baseline_config = self.config_spec.default_config()
-        try:
-            baseline_output = self.kernel.compile_config(
-                baseline_config, allow_print=False
-            )(*new_args)
-            torch.accelerator.synchronize()
-        except Exception as e:
-            decorator = self.kernel.format_kernel_decorator(
-                baseline_config, self.settings
-            )
-            log_generated_triton_code_debug(
-                self.log,
-                self.kernel,
-                baseline_config,
-                prefix=f"Generated Triton code for {decorator}:",
-            )
-            raise exc.InvalidConfig(
-                "Default config failed while computing baseline.\n"
-                f"Default config: {decorator}\n"
-                f"{SUPPRESSED_TRITON_CODE_MSG}\n"
-            ) from e
+
+        # Use custom baseline function if provided
+        if self.settings.autotune_baseline_fn is not None:
+            try:
+                baseline_output = self.settings.autotune_baseline_fn(*new_args)
+                torch.accelerator.synchronize()
+            except Exception as e:
+                raise exc.AutotuneError(
+                    "Custom baseline function failed while computing baseline.\n"
+                    f"Baseline function: {self.settings.autotune_baseline_fn}\n"
+                ) from e
+        else:
+            # Use default config
+            baseline_config = self.config_spec.default_config()
+            try:
+                baseline_output = self.kernel.compile_config(
+                    baseline_config, allow_print=False
+                )(*new_args)
+                torch.accelerator.synchronize()
+            except Exception as e:
+                decorator = self.kernel.format_kernel_decorator(
+                    baseline_config, self.settings
+                )
+                log_generated_triton_code_debug(
+                    self.log,
+                    self.kernel,
+                    baseline_config,
+                    prefix=f"Generated Triton code for {decorator}:",
+                )
+                raise exc.InvalidConfig(
+                    "Default config failed while computing baseline.\n"
+                    f"Default config: {decorator}\n"
+                    f"{SUPPRESSED_TRITON_CODE_MSG}\n"
+                ) from e
+
         original_args_flat, _ = tree_flatten(self._original_args)
         new_args_flat, _ = tree_flatten(new_args)
         mutated = False
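
For readers skimming the hunk above, the new logic reduces to the following dispatch (a condensed sketch, not the verbatim method body; the Triton-code logging, error formatting, and mutation detection are elided, and the function parameters stand in for the corresponding attributes of the search object):

    import torch

    def compute_baseline_sketch(settings, kernel, config_spec, new_args):
        if settings.autotune_baseline_fn is not None:
            # Custom reference: any exception is re-raised as exc.AutotuneError.
            baseline_output = settings.autotune_baseline_fn(*new_args)
        else:
            # Fallback: compile and run the kernel under its default config;
            # failures are re-raised as exc.InvalidConfig.
            baseline_config = config_spec.default_config()
            baseline_output = kernel.compile_config(
                baseline_config, allow_print=False
            )(*new_args)
        torch.accelerator.synchronize()
        return baseline_output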

helion/runtime/settings.py

Lines changed: 8 additions & 0 deletions
@@ -7,6 +7,7 @@
 import os
 import time
 from typing import TYPE_CHECKING
+from typing import Callable
 from typing import Literal
 from typing import Protocol
 from typing import Sequence
@@ -345,6 +346,7 @@ class _Settings:
     )
     ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode)
     autotuner_fn: AutotunerFunction = default_autotuner_fn
+    autotune_baseline_fn: Callable[..., object] | None = None
 
 
 class Settings(_Settings):
@@ -401,6 +403,12 @@ class Settings(_Settings):
             "Override by passing a callable to @helion.kernel(..., autotuner_fn=...)."
         ),
         "autotune_effort": "Autotuning effort preset. One of 'none', 'quick', 'full'.",
+        "autotune_baseline_fn": (
+            "Custom baseline function for computing baseline output during autotuning. "
+            "If provided, this function will be called instead of running the default config. "
+            "Should have the same signature as the kernel function. "
+            "Pass as @helion.kernel(..., autotune_baseline_fn=my_baseline_fn)."
+        ),
     }
 
     def __init__(self, **settings: object) -> None:
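
The help text's "same signature" requirement in concrete terms (a hypothetical kernel taking (x, scale); the names here are illustrative, not from this commit):

    import torch

    # For a kernel defined as:
    #     @helion.kernel(autotune_baseline_fn=my_baseline_fn)
    #     def scale_kernel(x: torch.Tensor, scale: float) -> torch.Tensor: ...
    # the baseline must accept the same (x, scale) arguments:
    def my_baseline_fn(x: torch.Tensor, scale: float) -> torch.Tensor:
        # Eager PyTorch reference used as the accuracy baseline.
        return x * scale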

test/test_autotuner.py

Lines changed: 138 additions & 0 deletions
@@ -591,6 +591,144 @@ def wrong_fn(*fn_args, **fn_kwargs):
         run_mode("fork", expect_error=False)
         run_mode("spawn", expect_error=True)
 
+    def test_autotune_baseline_fn(self) -> None:
+        """Test that custom baseline function is used for accuracy checking."""
+        config1 = helion.Config(block_sizes=[32], num_warps=4)
+        config2 = helion.Config(block_sizes=[64], num_warps=8)
+
+        # Track whether the baseline function was called
+        baseline_calls = []
+
+        def custom_baseline(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            baseline_calls.append(True)
+            # Return the expected result using PyTorch operations
+            return a + b
+
+        @helion.kernel(
+            configs=[config1, config2],
+            autotune_baseline_fn=custom_baseline,
+            autotune_log_level=0,
+        )
+        def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(a)
+            for tile in hl.tile(out.size()):
+                out[tile] = a[tile] + b[tile]
+            return out
+
+        args = (
+            torch.randn([128], device=DEVICE),
+            torch.randn([128], device=DEVICE),
+        )
+
+        # Run autotuning
+        result = add(*args)
+
+        # Verify the custom baseline function was called during autotuning
+        self.assertGreater(
+            len(baseline_calls), 0, "Custom baseline function should be called"
+        )
+
+        # Verify the result is correct
+        torch.testing.assert_close(result, args[0] + args[1])
+
+    def test_autotune_baseline_fn_filters_bad_config(self) -> None:
+        """Test that custom baseline function correctly filters incorrect configs."""
+        bad_config = helion.Config(block_sizes=[1], num_warps=8)
+        good_config = helion.Config(block_sizes=[1], num_warps=4)
+
+        def custom_baseline(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:  # noqa: FURB118
+            # Return the correct expected result
+            return a + b
+
+        @helion.kernel(
+            configs=[bad_config, good_config],
+            autotune_baseline_fn=custom_baseline,
+            autotune_log_level=0,
+        )
+        def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(a)
+            for tile in hl.tile(out.size()):
+                out[tile] = a[tile] + b[tile]
+            return out
+
+        a = torch.randn([32], device=DEVICE)
+        b = torch.randn([32], device=DEVICE)
+        bound_kernel = add.bind((a, b))
+        original_compile = bound_kernel.compile_config
+        bound_kernel.settings.autotune_precompile = "fork"
+
+        # Make bad_config produce wrong output
+        def make_bad_config_produce_wrong_output(
+            config: helion.Config, *, allow_print: bool = True
+        ):
+            fn = original_compile(config, allow_print=allow_print)
+            if config == bad_config:
+                return lambda *fn_args, **fn_kwargs: fn(*fn_args, **fn_kwargs) + 1
+            return fn
+
+        import helion.autotuner.base_search as base_search_module
+
+        with patch.object(
+            bound_kernel,
+            "compile_config",
+            side_effect=make_bad_config_produce_wrong_output,
+        ):
+            search = FiniteSearch(
+                bound_kernel, (a, b), configs=[bad_config, good_config]
+            )
+            with patch.object(
+                search,
+                "start_precompile_and_check_for_hangs",
+                side_effect=lambda config, fn: base_search_module.PrecompileFuture.skip(
+                    search, config, True
+                ),
+            ):
+                # Bad config should be filtered out by accuracy check
+                _, bad_time = search.benchmark(bad_config)
+                self.assertTrue(math.isinf(bad_time))
+                self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1)
+
+                # Good config should pass accuracy check
+                search.counters["accuracy_mismatch"] = 0
+                _, good_time = search.benchmark(good_config)
+                self.assertFalse(math.isinf(good_time))
+                self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0)
+
+                # Autotuning should select the good config
+                best = search.autotune()
+                self.assertEqual(best, good_config)
+
+    def test_autotune_baseline_fn_raises_on_failure(self) -> None:
+        """Test that AutotuneError is raised when custom baseline function fails."""
+        config1 = helion.Config(block_sizes=[32], num_warps=4)
+        config2 = helion.Config(block_sizes=[64], num_warps=8)
+
+        def failing_baseline(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            raise RuntimeError("Baseline computation failed!")
+
+        @helion.kernel(
+            configs=[config1, config2],
+            autotune_baseline_fn=failing_baseline,
+            autotune_log_level=0,
+        )
+        def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(a)
+            for tile in hl.tile(out.size()):
+                out[tile] = a[tile] + b[tile]
+            return out
+
+        args = (
+            torch.randn([128], device=DEVICE),
+            torch.randn([128], device=DEVICE),
+        )
+
+        # Attempting to run should raise AutotuneError
+        with self.assertRaisesRegex(
+            helion.exc.AutotuneError,
+            "Custom baseline function failed while computing baseline",
+        ):
+            add(*args)
+
     def test_max_generations(self):
         """Autotuner max generation respects explicit kwargs then setting override."""
0 commit comments