Commit 1c1e282

Add epilogue subtiling

stack-info: PR: #948, branch: PaulZhang12/stack/14
1 parent 72fbdca commit 1c1e282

File tree: 8 files changed, +168 −17 lines
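
For context: epilogue subtiling rewrites a kernel's final full-tile store into two half-width tensor-descriptor stores, which can lower register pressure when writing out a large accumulator tile. A sketch of the Triton shape this codegen targets (BLOCK_M, BLOCK_N, off_m, off_n, out_desc, and acc are illustrative names, not taken from generated code):

import triton.language as tl

# Instead of a single full-tile store in the epilogue:
#     out_desc.store([off_m, off_n], acc)
# the subtiled version splits the [BLOCK_M, BLOCK_N] accumulator in half:
acc = tl.reshape(acc, [BLOCK_M, 2, BLOCK_N // 2])
acc = tl.permute(acc, [0, 2, 1])
acc0, acc1 = tl.split(acc)  # two [BLOCK_M, BLOCK_N // 2] tiles
out_desc.store([off_m, off_n], acc0)
out_desc.store([off_m, off_n + BLOCK_N // 2], acc1)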

examples/matmul.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 @helion.kernel(
     # static_shapes=True gives a performance boost for matmuls
     static_shapes=True,
+    autotune_config_overrides={"indexing": "tensor_descriptor"},
 )
 def matmul(
     x: Tensor,

helion/_compiler/compile_environment.py

Lines changed: 3 additions & 0 deletions
@@ -99,6 +99,9 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
         self.device_load_count = (
             0  # Track number of loads in all device code for eviction policy tuning
         )
+        self.device_store_count = (
+            0  # Track number of stores for subtiling
+        )

     def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
         from .device_function import contains_only_block_size_symbols

helion/_compiler/device_function.py

Lines changed: 6 additions & 0 deletions
@@ -250,6 +250,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         self.rng_seed_count = 0
         self.device_load_index = 0  # Track which load in device code we're generating (for eviction policy tuning)
+        self.device_store_index = 0  # Track which store in device code we're generating (for subtiling)
         # Name of the RNG seed buffer parameter in kernel signature
         self.rng_seed_buffer_param_name = None

     def has_rng_ops(self) -> bool:
@@ -420,9 +421,14 @@ def tensor_arg(
     def tensor_descriptor_arg(
         self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
     ) -> TensorDescriptorArg:
+        import re
+
         host_function = HostFunction.current()
         block_size_expr = ", ".join(map(self.literal_expr, block_size))
+        # Rewrite triton_helpers.div_floor_integer(x, N) as x // N so halved
+        # block sizes stay valid constexpr expressions in descriptor shapes
+        pattern = r"triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)"
+        replacement = r"\1 // \2"
+        block_size_expr = re.sub(pattern, replacement, block_size_expr)
         key = (fake_value, block_size_expr)
+
         if key not in self._tensor_descriptor_args:
             origin = host_function.tensor_to_origin[fake_value]
             desc_name = self.new_var(origin.suggest_var_name() + "_desc")
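
To see what that rewrite does, a runnable snippet (the input string is a contrived example of a literal_expr result, not captured output):

import re

pattern = r"triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)"
expr = "_BLOCK_SIZE_0, triton_helpers.div_floor_integer(_BLOCK_SIZE_1, 2)"
# Floor-division helper calls become plain // inside the block-size expression
print(re.sub(pattern, r"\1 // \2", expr))
# -> "_BLOCK_SIZE_0, _BLOCK_SIZE_1 // 2"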

helion/_compiler/device_ir.py

Lines changed: 33 additions & 14 deletions
@@ -1076,7 +1076,7 @@ def visit_For(self, node: ast.For) -> None:
         self.generic_visit(node)


-def _count_device_loads(device_ir: DeviceIR) -> int:
-    """Count the number of load operations in all device code for eviction policy tuning."""
+def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int]:
+    """Count load and store operations in all device code for tuning."""
     from ..language import memory_ops

@@ -1087,26 +1087,29 @@ def _count_device_loads(device_ir: DeviceIR) -> int:
         if info.new_graph_id is not None
     }

-    load_count = 0
+    load_count, store_count = 0, 0
     # Walk all graphs except rolled duplicates
     for graph_info in device_ir.graphs:
         if graph_info.graph_id in rolled_graph_ids:
             continue

         for node in graph_info.graph.nodes:
-            # Check if this is a load operation
-            if node.op == "call_function" and node.target is memory_ops.load:
-                # Only count loads without explicit eviction policy
-                # (user can still specify eviction_policy to override tuning)
-                # Check kwargs first, then check if 4th arg (eviction_policy) is None
-                eviction_policy_arg = node.kwargs.get("eviction_policy")
-                if eviction_policy_arg is None:
-                    # Check if eviction_policy was passed as positional arg (index 3)
-                    if len(node.args) >= 4:
-                        eviction_policy_arg = node.args[3]
-                    if eviction_policy_arg is None:
-                        load_count += 1
-    return load_count
+            # Check if this is a load or store operation
+            if node.op == "call_function":
+                if node.target is memory_ops.load:
+                    # Only count loads without explicit eviction policy
+                    # (user can still specify eviction_policy to override tuning)
+                    # Check kwargs first, then check if 4th arg (eviction_policy) is None
+                    eviction_policy_arg = node.kwargs.get("eviction_policy")
+                    if eviction_policy_arg is None:
+                        # Check if eviction_policy was passed as positional arg (index 3)
+                        if len(node.args) >= 4:
+                            eviction_policy_arg = node.args[3]
+                        if eviction_policy_arg is None:
+                            load_count += 1
+                elif node.target is memory_ops.store:
+                    store_count += 1
+    return load_count, store_count


@@ -1124,6 +1127,21 @@ def _register_eviction_policy_tunable(load_count: int) -> None:
     env.config_spec.load_eviction_policies = fragment
     env.device_load_count = load_count

+
+def _register_epilogue_subtile_tunable(store_count: int) -> None:
+    """Register the epilogue subtile tunable for all device stores."""
+    if store_count == 0:
+        return
+
+    from ..autotuner.config_fragment import EnumFragment
+    from ..autotuner.config_fragment import ListOf
+    from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES
+
+    env = CompileEnvironment.current()
+    # One tunable entry per device store; 0 disables subtiling for that store
+    fragment = ListOf(EnumFragment(VALID_EPILOGUE_SUBTILE_SIZES), length=store_count)
+    env.config_spec.epilogue_subtiling = fragment
+    env.device_store_count = store_count
+

 def lower_to_device_ir(func: HostFunction) -> DeviceIR:
     device_ir = DeviceIR()
@@ -1148,8 +1166,9 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
     CompileEnvironment.current().config_spec.disallow_pid_type("xyz")

-    # Count all device loads and register eviction policy tunable
-    load_count = _count_device_loads(device_ir)
+    # Count all device loads/stores and register the corresponding tunables
+    load_count, store_count = _count_device_loads_and_stores(device_ir)
     _register_eviction_policy_tunable(load_count)
+    _register_epilogue_subtile_tunable(store_count)

     return device_ir
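
As an illustration of the registration above, a kernel with two device stores would get a length-2 list tunable (values here are hypothetical):

from helion.autotuner.config_fragment import EnumFragment, ListOf

# One enum entry per device store; 0 means "leave that store unsubtiled"
fragment = ListOf(EnumFragment(choices=(0, 2, 4)), length=2)
# During autotuning this yields candidate configs such as
# epilogue_subtiling=[0, 0], [2, 0], [0, 4], ...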

helion/_compiler/indexing_strategy.py

Lines changed: 103 additions & 1 deletion
@@ -15,6 +15,7 @@
 from .. import exc
 from .._compat import get_tensor_descriptor_fn_name
 from .ast_extension import expr_from_string
+from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
 from .device_function import DeviceFunction
 from .host_function import HostFunction
@@ -103,6 +104,12 @@ def _get_tile_with_offset_info(

     return None

+def _supports_epilogue_subtiling() -> bool:
+    env = CompileEnvironment.current()
+    if env.device.type != "cuda" or not env.settings.allow_epilogue_subtiling:
+        return False
+    # Tensor-descriptor subtile stores need compute capability 10.0 or newer
+    return torch.cuda.get_device_capability() >= (10, 0)
+

 class IndexingStrategy:
     def codegen_load(
@@ -376,6 +383,7 @@ def codegen_store(
         subscript: list[object],
         value: ast.AST,
         extra_mask: ast.AST | None,
+        epilogue_subtile: int | None,
     ) -> ast.AST:
         if not self.is_supported(state, fake_tensor, subscript, extra_mask):
             return PointerIndexingStrategy().codegen_store(
@@ -384,6 +392,14 @@
         assert extra_mask is None
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)

+        config = DeviceFunction.current().config
+        if _supports_epilogue_subtiling() and config.epilogue_subtiling:
+            subtile_store = self._codegen_epilogue_subtile_store(
+                state, fake_tensor, indexing, indexing.reshape_store(state, value)
+            )
+            if subtile_store is not None:
+                return subtile_store
+
         # Apply permutation to the value being stored if needed
         desc_arg = indexing.tensor_descriptor_arg(state)
         store_value = indexing.reshape_store(state, value)
@@ -394,12 +406,102 @@
                 f"tl.permute({{store_val}}, {desc_arg.permutation!r})",
                 store_val=store_value,
             )

         return expr_from_string(
             f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})",
             value=store_value,
         )

+    def _codegen_epilogue_subtile_store(
+        self,
+        state: CodegenState,
+        fake_tensor: torch.Tensor,
+        indexing: BlockedSubscriptIndexing,
+        store_value: ast.AST,
+    ) -> ast.AST | None:
+        # Currently supports only 2D tiles without permutations
+        if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2:
+            return None
+
+        env = CompileEnvironment.current()
+        block_m, block_n = indexing.block_shape
+        try:
+            block_n_hint = env.size_hint(block_n)
+        except Exception:
+            return None
+
+        if block_n_hint % 2 != 0:
+            return None
+
+        # Bail out on permuted layouts before mutating the block shape, so the
+        # caller can still fall back to the regular store path
+        if indexing.tensor_descriptor_arg(state).permutation is not None:
+            return None
+
+        device_fn = state.device_function
+        codegen = state.codegen
+
+        block_m_str = device_fn.literal_expr(block_m)
+        block_n_str = device_fn.literal_expr(block_n)
+        # Halve the N dimension and re-register the descriptor so it matches
+        # the subtile shape
+        indexing.block_shape[1] //= 2
+        indexing.tensor_descriptor_arg(state)
+
+        block_n_half_str = f"({block_n_str} // 2)"
+
+        # Lift the store value into a temporary variable for reuse
+        acc_var = codegen.lift(store_value, prefix="acc")
+
+        # Split [M, N] into [M, 2, N//2] and swap the last two axes so
+        # tl.split yields the left and right column halves
+        reshape_expr = expr_from_string(
+            "tl.reshape({acc}, [{dim_m}, 2, {dim_half}])",
+            acc=acc_var,
+            dim_m=expr_from_string(block_m_str),
+            dim_half=expr_from_string(block_n_half_str),
+        )
+        reshape_var = codegen.lift(reshape_expr, prefix="acc")
+
+        permute_expr = expr_from_string(
+            "tl.permute({acc}, [0, 2, 1])",
+            acc=reshape_var,
+        )
+        permute_var = codegen.lift(permute_expr, prefix="acc")
+
+        acc0_name = codegen.tmpvar(prefix="acc")
+        acc1_name = codegen.tmpvar(prefix="acc")
+        codegen.add_statement(
+            statement_from_string(
+                f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
+                acc=permute_var,
+            )
+        )
+        acc0 = expr_from_string(acc0_name)
+        acc1 = expr_from_string(acc1_name)
+
+        desc_name = indexing.tensor_descriptor(state)
+        offset0 = expr_from_string(indexing.offsets[0])
+        offset1 = expr_from_string(indexing.offsets[1])
+
+        # First subtile store
+        codegen.add_statement(
+            statement_from_string(
+                f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
+                off0=offset0,
+                off1=offset1,
+                value=acc0,
+            )
+        )

+        # Second subtile store, shifted by half the block along N, emitted as
+        # the expression returned to the caller
+        offset1_shifted = expr_from_string(
+            "({offset} + {half})",
+            offset=expr_from_string(indexing.offsets[1]),
+            half=expr_from_string(block_n_half_str),
+        )
+
+        return expr_from_string(
+            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
+            off0=offset0,
+            off1=offset1_shifted,
+            value=acc1,
+        )

 class StackIndexingStrategy:
     """

helion/autotuner/config_spec.py

Lines changed: 11 additions & 0 deletions
@@ -52,10 +52,12 @@
         "pid_type",
         "indexing",
         "load_eviction_policies",
+        "epilogue_subtiling",
     ]
 )
 VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
 VALID_EVICTION_POLICIES = ("", "first", "last")
+VALID_EPILOGUE_SUBTILE_SIZES = (0, 2, 4)


 @dataclasses.dataclass
@@ -105,6 +107,9 @@ class ConfigSpec:
             EnumFragment(choices=VALID_EVICTION_POLICIES), length=0
         )
     )
+    epilogue_subtiling: ListOf = dataclasses.field(
+        default_factory=lambda: ListOf(
+            EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=0
+        )
+    )

     @staticmethod
     def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
@@ -208,6 +213,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
             "range_flattens",
             "static_ranges",
             "load_eviction_policies",
+            "epilogue_subtiling",
         ):
             if not config.get(name):
                 config.pop(name, None)
@@ -217,6 +223,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
         config.setdefault(
             "load_eviction_policies", self.load_eviction_policies.default()
         )
+        config.setdefault("epilogue_subtiling", self.epilogue_subtiling.default())
         # TODO(jansel): include num_ctas and max_nreg

         for name, values in (
@@ -231,6 +238,9 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
             else:
                 config[name] = values[0]

+        # Subtiling only applies to tensor-descriptor stores with block sizes >= 16
+        if config["indexing"] != "tensor_descriptor" or any(
+            block_size < 16 for block_size in config["block_sizes"]
+        ):
+            for i in range(len(config["epilogue_subtiling"])):
+                config["epilogue_subtiling"][i] = 0
         # Set default values for grid indices when pid_type is not persistent
         pid_type = config["pid_type"]
         if pid_type in ("flat", "xyz") and self.grid_block_ids:
@@ -289,6 +299,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Config:
             "indexing": fn(EnumFragment(self._valid_indexing_types())),
             "pid_type": fn(EnumFragment(self.allowed_pid_types)),
             "load_eviction_policies": fn(self.load_eviction_policies),
+            "epilogue_subtiling": fn(self.epilogue_subtiling),
         }
         # Add tunable parameters
         config.update(
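
A hypothetical before/after for the normalize() clamp above (dict values invented for illustration):

# Sketch, assuming a config dict as normalize() sees it: pointer indexing
# (or any block size below 16) forces the subtile entries back to 0.
config = {
    "indexing": "pointer",
    "block_sizes": [64, 64, 32],
    "epilogue_subtiling": [2],
}
# After ConfigSpec.normalize(config), the tunable is disabled:
#     config["epilogue_subtiling"] == [0]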

helion/language/memory_ops.py

Lines changed: 4 additions & 2 deletions

@@ -55,6 +55,7 @@ def _(
     index: list[object],
     value: torch.Tensor | torch.SymInt | float,
     extra_mask: torch.Tensor | None = None,
+    epilogue_subtile: int | None = None,
 ) -> tuple[
     torch.Tensor | tuple,
     list[object],
@@ -68,10 +69,10 @@ def _(
     index = Tile._tiles_to_sizes(index)

     if isinstance(tensor, StackTensor):
-        return (tuple(tensor), index, value, extra_mask)
+        return (tuple(tensor), index, value, extra_mask, epilogue_subtile)

     if isinstance(tensor, torch.Tensor):
-        return (tensor, index, value, extra_mask)
+        return (tensor, index, value, extra_mask, epilogue_subtile)

     raise NotImplementedError(f"Cannot store to type: {type(tensor)}")

@@ -82,6 +83,7 @@ def _(
     index: list[object],
     value: torch.Tensor | torch.SymInt | float,
     extra_mask: torch.Tensor | None = None,
+    epilogue_subtile: int | None = None,
 ) -> None:
     return None
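
For orientation, the new trailing parameter slots in after extra_mask; whether user code passes epilogue_subtile directly through hl.store (rather than leaving it to the autotuner) is an assumption here:

import helion.language as hl

# Inside a Helion kernel body (sketch): explicit None leaves the decision
# to the tuned config, mirroring the defaults in the overloads above.
hl.store(out, [tile_m, tile_n], acc, extra_mask=None, epilogue_subtile=None)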

helion/runtime/config.py

Lines changed: 6 additions & 0 deletions
@@ -39,6 +39,7 @@ def __init__(
         num_stages: int | None = None,
         pid_type: PidTypeLiteral | None = None,
         indexing: IndexingLiteral | None = None,
+        epilogue_subtiling: list[int] | None = None,
         # For user-defined properties
         **kwargs: object,
     ) -> None:
@@ -61,6 +62,7 @@ def __init__(
             num_stages: Number of stages for software pipelining.
             pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
             indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
+            epilogue_subtiling: Per-store epilogue subtile sizes; 0 disables subtiling for that store.
             **kwargs: Additional user-defined configuration parameters.
         """
         self.config = {}
@@ -81,6 +83,7 @@ def __init__(
             "num_stages": num_stages,
             "indexing": indexing,
             "pid_type": pid_type,
+            "epilogue_subtiling": epilogue_subtiling,
         }
         for key, value in core_props.items():
             if value is not None:
@@ -206,6 +209,9 @@ def load_eviction_policies(self) -> list[str]:
     def indexing(self) -> IndexingLiteral:
         return self.config.get("indexing", "pointer")  # type: ignore[return-value]

+    @property
+    def epilogue_subtiling(self) -> list[int]:
+        return cast("list[int]", self.config.get("epilogue_subtiling", []))


 def _to_hashable(x: object) -> object:
     if isinstance(x, list):
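
Finally, a usage sketch for the new Config field (all values illustrative):

import helion

# One list entry per device store; a non-zero entry requests epilogue
# subtiling for that store, 0 keeps the plain full-tile store.
config = helion.Config(
    block_sizes=[128, 256, 64],
    indexing="tensor_descriptor",
    epilogue_subtiling=[2],
)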
