Skip to content

Commit f04abf5

Browse files
committed
Add ptxas_config autotuning option
stack-info: PR: #793, branch: jansel/stack/158
1 parent 2ff6426 commit f04abf5

File tree

11 files changed

+169
-3
lines changed

11 files changed

+169
-3
lines changed

helion/_compat.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,12 @@ def warps_to_threads(num_warps: int) -> int:
8888
)
8989
return num_warps * (props.warp_size or 32)
9090
return num_warps * 32
91+
92+
93+
def supports_ptxas(device: torch.device) -> bool:
    """Report whether PTXAS controls can be used on ``device``.

    Non-CUDA devices and HIP builds of PyTorch are rejected outright; for
    everything else the answer is whatever ``supports_tensor_descriptor()``
    returns.
    """
    # Short-circuit keeps the HIP check from running for non-CUDA devices.
    if device.type != "cuda" or torch.version.hip is not None:
        return False
    return supports_tensor_descriptor()

helion/_compiler/compile_environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torch.fx.experimental.symbolic_shapes import ShapeEnv
2121

2222
from .. import exc
23+
from .._compat import supports_ptxas
2324
from ..language.constexpr import ConstExpr
2425
from .loop_dependency_checker import LoopDependencyChecker
2526
from .source_location import SourceLocation
@@ -90,6 +91,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
9091
self.block_sizes: list[BlockSizeInfo] = []
9192
self.debug_shape_renames: dict[sympy.Expr, sympy.Expr] = {}
9293
self.config_spec = ConfigSpec()
94+
self.config_spec.ptxas_supported = supports_ptxas(device)
9395
self.kernel_tensor_sizes: dict[tuple[sympy.Expr, ...], int] = (
9496
collections.Counter()
9597
)

helion/_compiler/device_function.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,12 @@ def codegen_function_call(self) -> ast.AST:
573573
f"num_stages={self.config.num_stages}",
574574
]
575575
)
576+
ptxas_config = self.config.ptxas_config
577+
if ptxas_config:
578+
from ..runtime.ptxas_configs import get_ptxas_option
579+
580+
ptx_option = get_ptxas_option(ptxas_config)
581+
args.append(f"ptx_options={ptx_option!r}")
576582
pid = self.pid
577583
assert pid is not None
578584
# TODO(jansel): we should run CSE this statement

helion/autotuner/config_spec.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"num_stages",
5151
"pid_type",
5252
"indexing",
53+
"ptxas_config",
5354
]
5455
)
5556
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
@@ -97,6 +98,7 @@ class ConfigSpec:
9798
default_factory=functools.partial(tuple, VALID_PID_TYPES)
9899
)
99100
grid_block_ids: list[int] = dataclasses.field(default_factory=list)
101+
ptxas_supported: bool = True
100102

101103
@staticmethod
102104
def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
@@ -226,6 +228,11 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
226228
else:
227229
config[name] = values[0]
228230

231+
if self.ptxas_supported:
232+
value = config.get("ptxas_config") or 0
233+
if not isinstance(value, int):
234+
raise InvalidConfig(f"ptxas_config must be integer, got {value!r}")
235+
229236
# Set default values for grid indices when pid_type is not persistent
230237
pid_type = config["pid_type"]
231238
if pid_type in ("flat", "xyz") and self.grid_block_ids:
@@ -267,6 +274,10 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
267274
"indexing": fn(EnumFragment(self._valid_indexing_types())),
268275
"pid_type": fn(EnumFragment(self.allowed_pid_types)),
269276
}
277+
if self.ptxas_supported:
278+
from ..runtime.ptxas_configs import search_ptxas_configs
279+
280+
config["ptxas_config"] = fn(EnumFragment((0, *search_ptxas_configs())))
270281
# Add tunable parameters
271282
for key, fragment in self.user_defined_tunables.items():
272283
config[key] = fn(fragment)

helion/runtime/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,15 @@ def default_launcher(
6161
*args: object,
6262
num_warps: int,
6363
num_stages: int,
64+
ptx_options: str | None = None,
6465
) -> object:
6566
"""Default launcher function that executes the kernel immediately."""
66-
return triton_kernel.run(
67-
*args, grid=grid, warmup=False, num_warps=num_warps, num_stages=num_stages
68-
)
67+
run_kwargs = {
68+
"grid": grid,
69+
"warmup": False,
70+
"num_warps": num_warps,
71+
"num_stages": num_stages,
72+
}
73+
if ptx_options:
74+
run_kwargs["ptx_options"] = ptx_options
75+
return triton_kernel.run(*args, **run_kwargs)

helion/runtime/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
num_stages: int | None = None,
3939
pid_type: PidTypeLiteral | None = None,
4040
indexing: IndexingLiteral | None = None,
41+
ptxas_config: int | None = None,
4142
# For user-defined properties
4243
**kwargs: object,
4344
) -> None:
@@ -78,6 +79,7 @@ def __init__(
7879
"num_stages": num_stages,
7980
"indexing": indexing,
8081
"pid_type": pid_type,
82+
"ptxas_config": ptxas_config,
8183
}
8284
for key, value in core_props.items():
8385
if value is not None:
@@ -169,6 +171,10 @@ def pid_type(self) -> PidTypeLiteral:
169171
def range_unroll_factors(self) -> list[int]:
170172
return cast("list[int]", self.config.get("range_unroll_factors", []))
171173

174+
@property
def ptxas_config(self) -> int:
    """Selected PTXAS config ID stored in this config, or 0 when none is set."""
    selected = self.config.get("ptxas_config", 0)
    return cast("int", selected)
177+
172178
@property
173179
def range_warp_specializes(self) -> list[bool | None]:
174180
return cast("list[bool | None]", self.config.get("range_warp_specializes", []))
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Utilities for working with packaged PTXAS control files."""
2+
3+
from __future__ import annotations
4+
5+
from functools import cache
6+
from pathlib import Path
7+
8+
# Maps each PTXAS config ID to the file name of its control file; names are
# resolved against this module's directory. ID 0 is deliberately absent: it is
# treated as "no PTXAS option" by get_ptxas_option().
_CONFIG_FILES: dict[int, str] = {1: "spiffy-bee-104.bin"}
11+
12+
13+
def _config_root() -> Path:
    """Return the resolved directory that contains this module."""
    module_file = Path(__file__).resolve()
    return module_file.parent
15+
16+
17+
@cache
def search_ptxas_configs() -> tuple[int, ...]:
    """List every registered PTXAS config ID in ascending order.

    The IDs are the keys of ``_CONFIG_FILES``; the result is memoized, so the
    table is only read on the first call.
    """
    config_ids = list(_CONFIG_FILES)
    config_ids.sort()
    return tuple(config_ids)
22+
23+
24+
def _config_file_path(config_id: int) -> str:
    """Locate the PTXAS control file registered for ``config_id``.

    Returns:
        The resolved path of the control file, as a string.

    Raises:
        ValueError: ``config_id`` is not a key of ``_CONFIG_FILES``.
        FileNotFoundError: the registered file does not exist under
            ``_config_root()``.
    """
    try:
        filename = _CONFIG_FILES[config_id]
    except KeyError as missing:  # pragma: no cover - defensive
        raise ValueError(f"Unknown PTXAS config id: {config_id}") from missing

    control_file = _config_root().joinpath(filename).resolve()
    if control_file.is_file():
        return str(control_file)
    raise FileNotFoundError(f"Missing PTXAS config file: {control_file}")
35+
36+
37+
@cache
def get_ptxas_option(config_value: int) -> str | None:
    """Build the PTXAS option string for a ``ptxas_config`` value.

    Returns:
        ``None`` when ``config_value`` is 0 (no control file requested);
        otherwise ``"--apply-controls <path>"``, where ``<path>`` comes from
        ``_config_file_path(config_value)``.
    """
    if config_value != 0:
        control_path = _config_file_path(config_value)
        return f"--apply-controls {control_path}"
    return None
8.78 KB
Binary file not shown.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ packages = ["helion"]
8585
include = [
8686
"helion/**/*.py",
8787
"helion/**/*.pyi",
88+
"helion/runtime/ptxas_configs/*.bin",
8889
"LICENSE",
8990
]
9091
exclude = [

test/test_ptxas_config.expected

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
This file is automatically generated by assertExpectedJournal calls in test_ptxas_config.py.
2+
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
3+
4+
--- assertExpectedJournal(TestPtxasConfig.test_ptxas_config_apply_controls_flag)
5+
from __future__ import annotations
6+
7+
import torch
8+
import triton
9+
import triton.language as tl
10+
from helion.runtime import default_launcher as _default_launcher
11+
12+
@triton.jit
13+
def _helion__copy_kernel(x_flat, out_flat, x_size_0, out_flat_stride_0, x_flat_stride_0, _BLOCK_SIZE_0: tl.constexpr):
14+
pid_0 = tl.program_id(0)
15+
offset_0 = pid_0 * _BLOCK_SIZE_0
16+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
17+
mask_0 = indices_0 < x_size_0
18+
load = tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
19+
tl.store(out_flat + indices_0 * out_flat_stride_0, load, mask_0)
20+
21+
def _copy_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
22+
out = torch.empty_like(x)
23+
x_flat = x.view(-1)
24+
out_flat = out.view(-1)
25+
_BLOCK_SIZE_0 = 32
26+
_launcher(_helion__copy_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x_flat, out_flat, x.size(0), out_flat.stride(0), x_flat.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3, ptx_options='--apply-controls <path>')
27+
return out

0 commit comments

Comments
 (0)