@@ -591,6 +591,144 @@ def wrong_fn(*fn_args, **fn_kwargs):
         run_mode("fork", expect_error=False)
         run_mode("spawn", expect_error=True)
 
+    def test_autotune_baseline_fn(self) -> None:
+        """Test that a custom baseline function is used for accuracy checking."""
+        config1 = helion.Config(block_sizes=[32], num_warps=4)
+        config2 = helion.Config(block_sizes=[64], num_warps=8)
+
+        # Track whether the baseline function was called
+        baseline_calls = []
+
+        def custom_baseline(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            baseline_calls.append(True)
+            # Return the expected result using plain PyTorch operations
+            return a + b
+
+        @helion.kernel(
+            configs=[config1, config2],
+            autotune_baseline_fn=custom_baseline,
+            autotune_log_level=0,
+        )
+        def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(a)
+            for tile in hl.tile(out.size()):
+                out[tile] = a[tile] + b[tile]
+            return out
+
+        args = (
+            torch.randn([128], device=DEVICE),
+            torch.randn([128], device=DEVICE),
+        )
+
+        # Run autotuning
+        result = add(*args)
+
+        # Verify the custom baseline function was called during autotuning
+        self.assertGreater(
+            len(baseline_calls), 0, "Custom baseline function should be called"
+        )
+
+        # Verify the result is correct
+        torch.testing.assert_close(result, args[0] + args[1])
+
+    def test_autotune_baseline_fn_filters_bad_config(self) -> None:
+        """Test that the custom baseline function filters out incorrect configs."""
+        bad_config = helion.Config(block_sizes=[1], num_warps=8)
+        good_config = helion.Config(block_sizes=[1], num_warps=4)
+
+        def custom_baseline(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:  # noqa: FURB118
+            # Return the correct expected result
+            return a + b
+
+        @helion.kernel(
+            configs=[bad_config, good_config],
+            autotune_baseline_fn=custom_baseline,
+            autotune_log_level=0,
+        )
+        def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(a)
+            for tile in hl.tile(out.size()):
+                out[tile] = a[tile] + b[tile]
+            return out
+
+        a = torch.randn([32], device=DEVICE)
+        b = torch.randn([32], device=DEVICE)
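+        # Bind the kernel so its compile_config can be intercepted below;
+        # the original is saved so each config still compiles for real.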
+        bound_kernel = add.bind((a, b))
+        original_compile = bound_kernel.compile_config
+        bound_kernel.settings.autotune_precompile = "fork"
+
+        # Make bad_config produce wrong output so the accuracy check can catch it
+        def make_bad_config_produce_wrong_output(
+            config: helion.Config, *, allow_print: bool = True
+        ):
+            fn = original_compile(config, allow_print=allow_print)
+            if config == bad_config:
+                return lambda *fn_args, **fn_kwargs: fn(*fn_args, **fn_kwargs) + 1
+            return fn
+
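+        # PrecompileFuture comes from base_search; it is used below to skip
+        # the real precompile step enabled by autotune_precompile="fork".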
+        import helion.autotuner.base_search as base_search_module
+
+        with patch.object(
+            bound_kernel,
+            "compile_config",
+            side_effect=make_bad_config_produce_wrong_output,
+        ):
+            search = FiniteSearch(
+                bound_kernel, (a, b), configs=[bad_config, good_config]
+            )
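+            # Pretend precompilation already succeeded so benchmark() moves
+            # straight on to the accuracy check against the custom baseline.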
+            with patch.object(
+                search,
+                "start_precompile_and_check_for_hangs",
+                side_effect=lambda config, fn: base_search_module.PrecompileFuture.skip(
+                    search, config, True
+                ),
+            ):
+                # Bad config should be filtered out by accuracy check
+                _, bad_time = search.benchmark(bad_config)
+                self.assertTrue(math.isinf(bad_time))
+                self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1)
+
+                # Good config should pass accuracy check
+                search.counters["accuracy_mismatch"] = 0
+                _, good_time = search.benchmark(good_config)
+                self.assertFalse(math.isinf(good_time))
+                self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0)
+
+                # Autotuning should select the good config
+                best = search.autotune()
+                self.assertEqual(best, good_config)
+
+    def test_autotune_baseline_fn_raises_on_failure(self) -> None:
+        """Test that AutotuneError is raised when the custom baseline function fails."""
+        config1 = helion.Config(block_sizes=[32], num_warps=4)
+        config2 = helion.Config(block_sizes=[64], num_warps=8)
+
+        def failing_baseline(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            raise RuntimeError("Baseline computation failed!")
+
+        @helion.kernel(
+            configs=[config1, config2],
+            autotune_baseline_fn=failing_baseline,
+            autotune_log_level=0,
+        )
+        def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(a)
+            for tile in hl.tile(out.size()):
+                out[tile] = a[tile] + b[tile]
+            return out
+
+        args = (
+            torch.randn([128], device=DEVICE),
+            torch.randn([128], device=DEVICE),
+        )
+
+        # Attempting to run should raise AutotuneError
+        with self.assertRaisesRegex(
+            helion.exc.AutotuneError,
+            "Custom baseline function failed while computing baseline",
+        ):
+            add(*args)
+
     def test_max_generations(self):
         """Autotuner max generation respects explicit kwargs then setting override."""
 