Skip to content

Commit 88d46a8

Browse files
committed
Add epilogue subtiling
stack-info: PR: #948, branch: PaulZhang12/stack/14
1 parent 9660804 commit 88d46a8

File tree

6 files changed

+171
-17
lines changed

6 files changed

+171
-17
lines changed

helion/_compiler/compile_environment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
9999
self.device_load_count = (
100100
0 # Track number of loads in all device code for eviction policy tuning
101101
)
102+
self.device_store_count = 0 # Track number of stores for subtiling
102103

103104
def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
104105
from .device_function import contains_only_block_size_symbols

helion/_compiler/device_function.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
250250
self.rng_seed_count = 0
251251
self.device_load_index = 0 # Track which load in device code we're generating (for eviction policy tuning)
252252
# Name of the RNG seed buffer parameter in kernel signature
253+
self.device_store_index = (
254+
0 # Track which store in device code we're generating (for subtiling)
255+
)
253256
self.rng_seed_buffer_param_name = None
254257

255258
def has_rng_ops(self) -> bool:
@@ -420,9 +423,15 @@ def tensor_arg(
420423
def tensor_descriptor_arg(
421424
self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
422425
) -> TensorDescriptorArg:
426+
import re
427+
423428
host_function = HostFunction.current()
424429
block_size_expr = ", ".join(map(self.literal_expr, block_size))
430+
pattern = r"triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)"
431+
replacement = r"\1 // \2"
432+
block_size_expr = re.sub(pattern, replacement, block_size_expr)
425433
key = (fake_value, block_size_expr)
434+
426435
if key not in self._tensor_descriptor_args:
427436
origin = host_function.tensor_to_origin[fake_value]
428437
desc_name = self.new_var(origin.suggest_var_name() + "_desc")

helion/_compiler/device_ir.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,7 @@ def visit_For(self, node: ast.For) -> None:
10761076
self.generic_visit(node)
10771077

10781078

1079-
def _count_device_loads(device_ir: DeviceIR) -> int:
1079+
def _count_device_loads_and_stores(device_ir: DeviceIR) -> int:
10801080
"""Count the number of load operations in all device code for eviction policy tuning."""
10811081
from ..language import memory_ops
10821082

@@ -1087,26 +1087,29 @@ def _count_device_loads(device_ir: DeviceIR) -> int:
10871087
if info.new_graph_id is not None
10881088
}
10891089

1090-
load_count = 0
1090+
load_count, store_count = 0, 0
10911091
# Walk all graphs except rolled duplicates
10921092
for graph_info in device_ir.graphs:
10931093
if graph_info.graph_id in rolled_graph_ids:
10941094
continue
10951095

10961096
for node in graph_info.graph.nodes:
10971097
# Check if this is a load operation
1098-
if node.op == "call_function" and node.target is memory_ops.load:
1099-
# Only count loads without explicit eviction policy
1100-
# (user can still specify eviction_policy to override tuning)
1101-
# Check kwargs first, then check if 4th arg (eviction_policy) is None
1102-
eviction_policy_arg = node.kwargs.get("eviction_policy")
1103-
if eviction_policy_arg is None:
1104-
# Check if eviction_policy was passed as positional arg (index 3)
1105-
if len(node.args) >= 4:
1106-
eviction_policy_arg = node.args[3]
1098+
if node.op == "call_function":
1099+
if node.target is memory_ops.load:
1100+
# Only count loads without explicit eviction policy
1101+
# (user can still specify eviction_policy to override tuning)
1102+
# Check kwargs first, then check if 4th arg (eviction_policy) is None
1103+
eviction_policy_arg = node.kwargs.get("eviction_policy")
11071104
if eviction_policy_arg is None:
1108-
load_count += 1
1109-
return load_count
1105+
# Check if eviction_policy was passed as positional arg (index 3)
1106+
if len(node.args) >= 4:
1107+
eviction_policy_arg = node.args[3]
1108+
if eviction_policy_arg is None:
1109+
load_count += 1
1110+
elif node.target is memory_ops.store:
1111+
store_count += 1
1112+
return load_count, store_count
11101113

11111114

11121115
def _register_eviction_policy_tunable(load_count: int) -> None:
@@ -1125,6 +1128,24 @@ def _register_eviction_policy_tunable(load_count: int) -> None:
11251128
env.device_load_count = load_count
11261129

11271130

1131+
def _register_epilogue_subtile_tunable(store_count: int) -> None:
1132+
"""Register the epilogue subtile tunable for all device stores."""
1133+
if store_count == 0:
1134+
return
1135+
1136+
from ..autotuner.config_fragment import EnumFragment
1137+
from ..autotuner.config_fragment import ListOf
1138+
from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES
1139+
1140+
env = CompileEnvironment.current()
1141+
# Register a tunable for epilogue subtile for all device stores
1142+
fragment = ListOf(
1143+
EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=store_count
1144+
)
1145+
env.config_spec.epilogue_subtiling = fragment
1146+
env.device_store_count = store_count
1147+
1148+
11281149
def lower_to_device_ir(func: HostFunction) -> DeviceIR:
11291150
device_ir = DeviceIR()
11301151
with func, device_ir, compile_lock:
@@ -1148,8 +1169,9 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
11481169
CompileEnvironment.current().config_spec.disallow_pid_type("xyz")
11491170

11501171
# Count all device loads and register eviction policy tunable
1151-
load_count = _count_device_loads(device_ir)
1172+
load_count, store_count = _count_device_loads_and_stores(device_ir)
11521173
_register_eviction_policy_tunable(load_count)
1174+
_register_epilogue_subtile_tunable(store_count)
11531175

11541176
return device_ir
11551177

helion/_compiler/indexing_strategy.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .. import exc
1616
from .._compat import get_tensor_descriptor_fn_name
1717
from .ast_extension import expr_from_string
18+
from .ast_extension import statement_from_string
1819
from .compile_environment import CompileEnvironment
1920
from .device_function import DeviceFunction
2021
from .host_function import HostFunction
@@ -353,7 +354,6 @@ def codegen_load(
353354
)
354355
assert extra_mask is None
355356
indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
356-
357357
# Load from tensor descriptor with permuted offsets
358358
load_expr = expr_from_string(
359359
f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})"
@@ -383,10 +383,24 @@ def codegen_store(
383383
)
384384
assert extra_mask is None
385385
indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
386+
store_value = indexing.reshape_store(state, value)
387+
388+
config = DeviceFunction.current().config
389+
epilogue_subtiles = state.config.epilogue_subtiling
390+
if torch.cuda.get_device_capability() >= (9, 0) and (
391+
idx := state.device_function.device_store_index
392+
) < len(epilogue_subtiles):
393+
subtile_split = epilogue_subtiles[idx]
394+
state.device_function.device_store_index += 1
395+
396+
subtile_codegen = self._codegen_epilogue_subtile_store(
397+
state, fake_tensor, indexing, store_value, subtile_split, config
398+
)
399+
if subtile_codegen is not None:
400+
return subtile_codegen
386401

387402
# Apply permutation to the value being stored if needed
388403
desc_arg = indexing.tensor_descriptor_arg(state)
389-
store_value = indexing.reshape_store(state, value)
390404

391405
if desc_arg.permutation is not None:
392406
# Apply permutation to the value
@@ -400,6 +414,95 @@ def codegen_store(
400414
value=store_value,
401415
)
402416

417+
def _codegen_epilogue_subtile_store(
418+
self,
419+
state: CodegenState,
420+
fake_tensor: torch.Tensor,
421+
indexing: BlockedSubscriptIndexing,
422+
store_value: ast.AST,
423+
subtile_split: int,
424+
config: Config,
425+
) -> ast.AST | None:
426+
# Currently support 2D tiles without permutations
427+
if (
428+
len(indexing.block_shape) != 2
429+
or len(indexing.offsets) != 2
430+
or subtile_split == 0
431+
):
432+
return None
433+
434+
env = CompileEnvironment.current()
435+
block_m, block_n = indexing.block_shape
436+
try:
437+
block_n_hint = env.size_hint(block_n)
438+
block_idx = env.get_block_id(block_n)
439+
block_size = env.block_sizes[block_idx].from_config(config)
440+
except Exception:
441+
return None
442+
443+
if block_n_hint % 2 != 0 or block_size <= 16:
444+
return None
445+
446+
device_fn = state.device_function
447+
codegen = state.codegen
448+
449+
block_m_str = device_fn.literal_expr(block_m)
450+
block_n_str = device_fn.literal_expr(block_n)
451+
indexing.block_shape[1] //= subtile_split
452+
453+
# TODO(PaulZhang12): Support more epilogue subtile configs besides 2
454+
block_n_half_str = f"({block_n_str} // {subtile_split})"
455+
456+
# Lift the store value into a temporary variable for reuse
457+
acc_var = codegen.lift(store_value, prefix="acc")
458+
459+
reshape_expr = expr_from_string(
460+
"tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)",
461+
acc=acc_var,
462+
dim_m=expr_from_string(block_m_str),
463+
dim_half=expr_from_string(block_n_half_str),
464+
)
465+
reshape_var = codegen.lift(reshape_expr, prefix="acc")
466+
467+
acc0_name = codegen.tmpvar(prefix="acc")
468+
acc1_name = codegen.tmpvar(prefix="acc")
469+
codegen.add_statement(
470+
statement_from_string(
471+
f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
472+
acc=reshape_var,
473+
)
474+
)
475+
acc0 = expr_from_string(acc0_name)
476+
acc1 = expr_from_string(acc1_name)
477+
478+
desc_name = indexing.tensor_descriptor(state)
479+
offset0 = expr_from_string(indexing.offsets[0])
480+
offset1 = expr_from_string(indexing.offsets[1])
481+
482+
# First subtile store
483+
codegen.add_statement(
484+
statement_from_string(
485+
f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
486+
off0=offset0,
487+
off1=offset1,
488+
value=acc0,
489+
)
490+
)
491+
492+
offset1_shifted = expr_from_string(
493+
"({offset} + {half})",
494+
offset=expr_from_string(indexing.offsets[1]),
495+
half=expr_from_string(block_n_half_str),
496+
)
497+
498+
# Emit second subtile store as the expression returned to the caller
499+
return expr_from_string(
500+
f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
501+
off0=offset0,
502+
off1=offset1_shifted,
503+
value=acc1,
504+
)
505+
403506

404507
class StackIndexingStrategy:
405508
"""

helion/autotuner/config_spec.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@
5252
"pid_type",
5353
"indexing",
5454
"load_eviction_policies",
55+
"epilogue_subtiling",
5556
]
5657
)
5758
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
5859
VALID_EVICTION_POLICIES = ("", "first", "last")
60+
VALID_EPILOGUE_SUBTILE_SIZES = (0, 2)
5961

6062

6163
@dataclasses.dataclass
@@ -105,11 +107,16 @@ class ConfigSpec:
105107
EnumFragment(choices=VALID_EVICTION_POLICIES), length=0
106108
)
107109
)
110+
epilogue_subtiling: ListOf = dataclasses.field(
111+
default_factory=lambda: ListOf(
112+
EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=0
113+
)
114+
)
108115

109116
@staticmethod
110117
def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
111118
return (
112-
("pointer", "block_ptr", "tensor_descriptor")
119+
("pointer", "tensor_descriptor")
113120
if supports_tensor_descriptor()
114121
else ("pointer", "block_ptr")
115122
)
@@ -208,6 +215,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
208215
"range_flattens",
209216
"static_ranges",
210217
"load_eviction_policies",
218+
"epilogue_subtiling",
211219
):
212220
if not config.get(name):
213221
config.pop(name, None)
@@ -217,6 +225,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
217225
config.setdefault(
218226
"load_eviction_policies", self.load_eviction_policies.default()
219227
)
228+
config.setdefault("epilogue_subtiling", self.epilogue_subtiling.default())
220229
# TODO(jansel): include num_ctas and max_nreg
221230

222231
for name, values in (
@@ -289,6 +298,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
289298
"indexing": fn(EnumFragment(self._valid_indexing_types())),
290299
"pid_type": fn(EnumFragment(self.allowed_pid_types)),
291300
"load_eviction_policies": fn(self.load_eviction_policies),
301+
"epilogue_subtiling": fn(self.epilogue_subtiling),
292302
}
293303
# Add tunable parameters
294304
config.update(
@@ -307,9 +317,11 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
307317
"range_flattens",
308318
"static_ranges",
309319
"load_eviction_policies",
320+
"epilogue_subtiling",
310321
):
311322
if not config.get(name):
312323
config.pop(name, None)
324+
313325
self.normalize(config)
314326
return helion.Config(**config)
315327

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
num_stages: int | None = None,
4040
pid_type: PidTypeLiteral | None = None,
4141
indexing: IndexingLiteral | None = None,
42+
epilogue_subtiling: list[int] | None = None,
4243
# For user-defined properties
4344
**kwargs: object,
4445
) -> None:
@@ -61,6 +62,7 @@ def __init__(
6162
num_stages: Number of stages for software pipelining.
6263
pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
6364
indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
65+
epilogue_subtiling: Whether to use subtiling for epilogue.
6466
**kwargs: Additional user-defined configuration parameters.
6567
"""
6668
self.config = {}
@@ -81,6 +83,7 @@ def __init__(
8183
"num_stages": num_stages,
8284
"indexing": indexing,
8385
"pid_type": pid_type,
86+
"epilogue_subtiling": epilogue_subtiling,
8487
}
8588
for key, value in core_props.items():
8689
if value is not None:
@@ -206,6 +209,10 @@ def load_eviction_policies(self) -> list[str]:
206209
def indexing(self) -> IndexingLiteral:
207210
return self.config.get("indexing", "pointer") # type: ignore[return-value]
208211

212+
@property
213+
def epilogue_subtiling(self) -> bool:
214+
return cast("list[int]", self.config.get("epilogue_subtiling", [])) # type: ignore[return-value]
215+
209216

210217
def _to_hashable(x: object) -> object:
211218
if isinstance(x, list):

0 commit comments

Comments
 (0)