Skip to content

Commit 37c3e3f

Browse files
authored
[Autotuner] Fix fork-based autotuner to avoid re-initializing CUDA context in subprocess (#981)
1 parent 1ab4208 commit 37c3e3f

File tree

2 files changed

+98
-25
lines changed

2 files changed

+98
-25
lines changed

helion/autotuner/base_search.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -365,13 +365,18 @@ def start_precompile_and_check_for_hangs(
365365
)
366366
process.daemon = True
367367
else:
368+
precompiler = _prepare_precompiler_for_fork(
369+
fn, device_args, config, self.kernel, decorator
370+
)
371+
if precompiler is None:
372+
return PrecompileFuture.skip(self, config, True)
368373
ctx = mp.get_context("fork")
369374
parent_conn, child_conn = ctx.Pipe()
370375
process = cast(
371376
"mp.Process",
372377
ctx.Process(
373378
target=_run_kernel_in_subprocess_fork,
374-
args=(fn, device_args, config, self.kernel, child_conn, decorator),
379+
args=(precompiler, config, self.kernel, child_conn, decorator),
375380
),
376381
)
377382
process.daemon = True
@@ -1209,37 +1214,54 @@ def _run_kernel_in_subprocess_spawn(
12091214
os._exit(status)
12101215

12111216

1212-
def _run_kernel_in_subprocess_fork(
1217+
def _prepare_precompiler_for_fork(
12131218
fn: CompiledConfig,
12141219
args: Sequence[object],
12151220
config: Config,
12161221
kernel: BoundKernel,
1222+
decorator: str,
1223+
) -> Callable[[], None] | None:
1224+
def extract_launcher(
1225+
triton_kernel: object,
1226+
grid: tuple[int, ...],
1227+
*launch_args: object,
1228+
**launch_kwargs: object,
1229+
) -> NoReturn:
1230+
raise _ExtractedLaunchArgs(triton_kernel, grid, launch_args, launch_kwargs)
1231+
1232+
try:
1233+
fn(*tuple(args), _launcher=extract_launcher)
1234+
raise RuntimeError("Expected _ExtractedLaunchArgs to be raised")
1235+
except _ExtractedLaunchArgs as extracted:
1236+
precompiler_factory = make_precompiler(
1237+
cast("Any", extracted.kernel),
1238+
config,
1239+
kernel,
1240+
)
1241+
precompiler = precompiler_factory(*extracted.args, **extracted.kwargs)
1242+
if precompiler is already_compiled:
1243+
return None
1244+
return precompiler
1245+
except Exception:
1246+
log.warning(
1247+
"Helion autotuner precompile error for %s\n\nGenerated Triton code:\n%s",
1248+
decorator,
1249+
kernel.to_triton_code(config),
1250+
exc_info=True,
1251+
)
1252+
raise
1253+
1254+
1255+
def _run_kernel_in_subprocess_fork(
1256+
precompiler: Callable[[], None],
1257+
config: Config,
1258+
kernel: BoundKernel,
12171259
conn: connection.Connection,
12181260
decorator: str,
12191261
) -> None:
12201262
status = 0
12211263
try:
1222-
1223-
def extract_launcher(
1224-
triton_kernel: object,
1225-
grid: tuple[int, ...],
1226-
*launch_args: object,
1227-
**launch_kwargs: object,
1228-
) -> NoReturn:
1229-
raise _ExtractedLaunchArgs(triton_kernel, grid, launch_args, launch_kwargs)
1230-
1231-
try:
1232-
fn(*tuple(args), _launcher=extract_launcher)
1233-
raise RuntimeError("Expected _ExtractedLaunchArgs to be raised")
1234-
except _ExtractedLaunchArgs as extracted:
1235-
precompiler_factory = make_precompiler(
1236-
cast("Any", extracted.kernel),
1237-
config,
1238-
kernel,
1239-
)
1240-
precompiler = precompiler_factory(*extracted.args, **extracted.kwargs)
1241-
if precompiler is not already_compiled:
1242-
precompiler()
1264+
precompiler()
12431265
conn.send({"status": "ok"})
12441266
except Exception as exc:
12451267
status = 1

test/test_autotuner.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
from contextlib import nullcontext
66
import logging
77
import math
8+
import multiprocessing as mp
89
import os
910
from pathlib import Path
1011
import pickle
1112
import random
1213
import tempfile
1314
from types import SimpleNamespace
15+
from typing import Callable
1416
import unittest
1517
from unittest import skip
1618
from unittest.mock import patch
@@ -70,14 +72,16 @@ def _autotune(self):
7072

7173

7274
class TestAutotuneIgnoreErrors(TestCase):
73-
def _make_search(self, settings: Settings) -> BaseSearch:
75+
def _make_search(
76+
self, settings: Settings, *, args: tuple[object, ...] = ()
77+
) -> BaseSearch:
7478
search = BaseSearch.__new__(BaseSearch)
7579
search.settings = settings
7680
search.kernel = SimpleNamespace(
7781
format_kernel_decorator=lambda config, s: "decorator",
7882
to_triton_code=lambda config: "code",
7983
)
80-
search.args = ()
84+
search.args = args
8185
search.counters = collections.Counter()
8286
search.log = LambdaLogger(logging.CRITICAL)
8387
search._kernel_mutates_args = False
@@ -126,6 +130,53 @@ def bad_fn(*_args):
126130
self.assertEqual(result, float("inf"))
127131
warn.assert_not_called()
128132

133+
@pytest.mark.skipif(
134+
"fork" not in mp.get_all_start_methods(),
135+
reason="fork start method is unavailable on this platform",
136+
)
137+
def test_fork_precompile_avoids_cuda_reinit(self):
138+
settings = Settings(
139+
autotune_precompile="fork",
140+
autotune_log_level=logging.CRITICAL,
141+
autotune_compile_timeout=5,
142+
)
143+
search = self._make_search(settings, args=("arg0",))
144+
145+
parent_pid = os.getpid()
146+
lazy_calls: list[int] = []
147+
148+
def fake_lazy_init() -> None:
149+
lazy_calls.append(os.getpid())
150+
151+
def fake_make_precompiler(_kernel_obj, _config, _bound_kernel):
152+
def binder(*_args: object, **_kwargs: object):
153+
def run() -> None:
154+
return None
155+
156+
return run
157+
158+
return binder
159+
160+
def fake_compiled_fn(
161+
*fn_args: object, _launcher: Callable[..., object]
162+
) -> None:
163+
torch.cuda._lazy_init()
164+
_launcher("fake_kernel", (1,), *fn_args)
165+
166+
with (
167+
patch(
168+
"helion.autotuner.base_search.make_precompiler",
169+
side_effect=fake_make_precompiler,
170+
),
171+
patch("torch.cuda._lazy_init", side_effect=fake_lazy_init),
172+
):
173+
future = search.start_precompile_and_check_for_hangs(
174+
"cfg", fake_compiled_fn
175+
)
176+
self.assertTrue(future())
177+
178+
self.assertEqual(set(lazy_calls), {parent_pid})
179+
129180

130181
class TestAutotuner(RefEagerTestDisabled, TestCase):
131182
def setUp(self):

0 commit comments

Comments
 (0)