From 542bc80f5947ff8458d88a8231dde6877ca31921 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Sat, 25 Oct 2025 01:20:48 -0700 Subject: [PATCH] Disallow both interpret modes active --- helion/exc.py | 4 ++++ helion/runtime/settings.py | 10 ++++++++-- test/test_breakpoint.py | 22 ++++++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/helion/exc.py b/helion/exc.py index 6f5c02358..de002c026 100644 --- a/helion/exc.py +++ b/helion/exc.py @@ -269,6 +269,10 @@ class BreakpointInDeviceLoopRequiresInterpret(BaseError): message = "breakpoint() inside an `hl.tile` or `hl.grid` loop requires TRITON_INTERPRET=1 or HELION_INTERPRET=1." +class BothInterpretModesActive(BaseError): + message = "Cannot have both TRITON_INTERPRET=1 and HELION_INTERPRET=1 active simultaneously. Please use only one interpret mode." + + class UndefinedVariable(BaseError): message = "{} is not defined." diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 48aa1e97c..583f4de3e 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -232,8 +232,14 @@ def _get_autotune_random_seed() -> int: def _get_ref_mode() -> RefMode: - interpret = _env_get_bool("HELION_INTERPRET", False) - return RefMode.EAGER if interpret else RefMode.OFF + triton_interpret = os.environ.get("TRITON_INTERPRET") == "1" + helion_interpret = _env_get_bool("HELION_INTERPRET", False) + + # Ban having both interpret modes active + if triton_interpret and helion_interpret: + raise exc.BothInterpretModesActive + + return RefMode.EAGER if helion_interpret else RefMode.OFF @dataclasses.dataclass diff --git a/test/test_breakpoint.py b/test/test_breakpoint.py index c0e4e5a65..8dc5a527c 100644 --- a/test/test_breakpoint.py +++ b/test/test_breakpoint.py @@ -89,6 +89,9 @@ def _run_breakpoint_in_subprocess( ) env = os.environ.copy() + # Set the env vars in the subprocess environment (needed for the ban check) + env["TRITON_INTERPRET"] = str(triton_interpret) + env["HELION_INTERPRET"] = str(helion_interpret) result = subprocess.run( [sys.executable, "-c", script], env=env, @@ -128,6 +131,16 @@ def _run_device_breakpoint_test( out = bound(x) torch.testing.assert_close(out, x) + def _run_device_breakpoint_both_interpret_test( + self, triton_interpret: int, helion_interpret: int + ) -> None: + """Test that having both interpret modes active is banned.""" + # Environment variables are already set by subprocess, just verify they're both 1 + assert triton_interpret == 1 and helion_interpret == 1 + # When both are set to 1, creating a kernel should raise an error + with self.assertRaises(exc.BothInterpretModesActive): + self._make_device_breakpoint_kernel() + def test_device_breakpoint_no_interpret(self) -> None: self._run_breakpoint_in_subprocess( test_name=self._testMethodName, @@ -152,6 +165,15 @@ def test_device_breakpoint_helion_interpret(self) -> None: helion_interpret=1, ) + def test_device_breakpoint_both_interpret_banned(self) -> None: + """Test that having both TRITON_INTERPRET and HELION_INTERPRET active is banned.""" + self._run_breakpoint_in_subprocess( + test_name=self._testMethodName, + runner_method="_run_device_breakpoint_both_interpret_test", + triton_interpret=1, + helion_interpret=1, + ) + def _run_host_breakpoint_test( self, triton_interpret: int, helion_interpret: int ) -> None: