
Commit 3ae89e1

Add epilogue subtiling
stack-info: PR: #948, branch: PaulZhang12/stack/14
1 parent 2cb6d17 commit 3ae89e1

File tree

14 files changed: +925 -92 lines

examples/matmul.py

Lines changed: 2 additions & 0 deletions

@@ -34,6 +34,8 @@
         "range_unroll_factors": [0, 0],
         "range_num_stages": [0, 0],
     },
+    allow_epilogue_subtiling=True,
+    autotune_effort="quick",
 )
 def matmul(
     x: Tensor,
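
For orientation, this is how the new settings read at the call site; a minimal sketch assuming the example's decorator is @helion.kernel and eliding the existing config entries shown above:

    import torch
    import helion

    @helion.kernel(
        # ... existing config (block sizes, unroll factors, num stages) elided ...
        allow_epilogue_subtiling=True,  # opt this kernel into the epilogue-subtile tunable
        autotune_effort="quick",        # keep autotuning short for the example
    )
    def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        ...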

helion/_compiler/device_function.py

Lines changed: 62 additions & 22 deletions

@@ -462,8 +462,9 @@ def tensor_descriptor_arg(
         self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
     ) -> TensorDescriptorArg:
         host_function = HostFunction.current()
-        block_size_expr = ", ".join(map(self.literal_expr, block_size))
+        block_size_expr = ", ".join(self.literal_expr(dim) for dim in block_size)
         key = (fake_value, block_size_expr)
+
         if key not in self._tensor_descriptor_args:
             origin = host_function.tensor_to_origin[fake_value]
             desc_name = self.new_var(origin.suggest_var_name() + "_desc")
@@ -556,22 +557,6 @@ def _format_constexpr_value(self, value: object) -> str:
         if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
             value = value._sympy_()
 
-        # Handle sympy expressions (sanitize by replacing triton_helpers functions)
-        if isinstance(value, sympy.Expr):
-            sanitized = value.replace(  # pyright: ignore[reportAttributeAccessIssue]
-                lambda node: isinstance(node, sympy.Function)
-                and getattr(node.func, "__name__", "")
-                == "triton_helpers.div_floor_integer",
-                lambda node: sympy.floor(node.args[0] / node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
-            ).replace(  # pyright: ignore[reportAttributeAccessIssue]
-                lambda node: isinstance(node, sympy.Function)
-                and getattr(node.func, "__name__", "")
-                == "triton_helpers.remainder_integer",
-                lambda node: sympy.Mod(node.args[0], node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
-            )
-            expr = cast("sympy.Expr", sanitized)
-            return HostFunction.current().sympy_expr(expr)
-
         return HostFunction.current().literal_expr(value)
 
     def _tensor_property(
@@ -749,11 +734,19 @@ def current() -> DeviceFunction:
 
 
 class HelionTritonPrinter(TritonPrinter):
-    """Custom Triton printer that avoids wrapping float literals in tl.full().
-
-    Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
-    via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
-    literal, letting downstream type promotion and casts handle dtype.
+    """Custom Triton printer that does the following:
+
+    - Avoids wrapping float literals in tl.full().
+      Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
+      via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
+      literal, letting downstream type promotion and casts handle dtype.
+
+    - Avoids triton_helpers.div_floor_integer(...) calls when both operands are
+      provably non-negative integers. TritonPrinter by default converts
+      floor(u1/2) to triton_helpers.div_floor_integer(...). We override this to
+      emit u1 // 2 only when the numerator is known to be non-negative and the
+      denominator is a positive integer, so that we keep helper calls for cases
+      that rely on floor semantics with mixed signs.
     """
 
     def _print_Float(self, expr: sympy.Expr) -> str:
@@ -762,6 +755,53 @@ def _print_Float(self, expr: sympy.Expr) -> str:
     def _print_ToFloat(self, expr: sympy.Expr) -> str:
         return f"{expr} + 0.0"
 
+    def _is_nonnegative(self, expr: sympy.Expr) -> bool:
+        if expr.is_nonnegative is True or expr.is_zero is True:
+            return True
+        if expr.is_positive is True:
+            return True
+        try:
+            host_fn = HostFunction.current()
+        except NoCurrentFunction:
+            host_fn = None
+        if host_fn is not None:
+            origin_info = host_fn.expr_to_origin.get(expr)
+            if origin_info and isinstance(
+                origin_info.origin, (BlockSizeOrigin, TensorSizeOrigin)
+            ):
+                return True
+        if isinstance(expr, sympy.Symbol) and expr.name.startswith("_BLOCK_SIZE_"):
+            return True
+        if isinstance(expr, sympy.Number):
+            return bool(expr >= 0)
+        return False
+
+    def _format_trunc_div(self, lhs: sympy.Expr, rhs: sympy.Expr) -> str:
+        lhs_str = self._print(lhs)
+        rhs_str = self._print(rhs)
+        if not (lhs.is_Integer or lhs.is_Symbol):
+            lhs_str = f"({lhs_str})"
+        if not (rhs.is_Integer or rhs.is_Symbol):
+            rhs_str = f"({rhs_str})"
+        return f"{lhs_str} // {rhs_str}"
+
+    def _print_floor(self, expr: sympy.Expr) -> str:
+        inner = expr.args[0]
+        numer, denom = inner.as_numer_denom()
+        if (
+            isinstance(denom, sympy.Integer)
+            and denom > 1
+            and self._is_nonnegative(numer)
+        ):
+            return self._format_trunc_div(numer, denom)
+        return super()._print_floor(expr)
+
+    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
+        lhs, rhs = expr.args
+        if isinstance(rhs, sympy.Integer) and rhs > 0 and self._is_nonnegative(lhs):
+            return self._format_trunc_div(lhs, rhs)
+        return super()._print_FloorDiv(expr)
+
 
 def texpr(expr: sympy.Expr) -> str:
     return HelionTritonPrinter().doprint(expr)
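
To make the new printing rule concrete, here is a small standalone sketch (plain sympy, not Helion's actual printer class) of the decision _print_floor now makes: emit a truncating // only when the numerator is provably non-negative and the denominator is a positive integer constant, and otherwise keep the triton_helpers fallback:

    import sympy

    def print_floor_sketch(expr: sympy.Expr) -> str:
        # Mirrors the override: floor(numer / denom) -> "numer // denom" only in the safe case.
        inner = expr.args[0]
        numer, denom = inner.as_numer_denom()
        if isinstance(denom, sympy.Integer) and denom > 1 and numer.is_nonnegative:
            return f"{numer} // {denom}"
        return f"triton_helpers.div_floor_integer({numer}, {denom})"

    u1 = sympy.Symbol("u1", integer=True, nonnegative=True)
    s0 = sympy.Symbol("s0", integer=True)  # sign unknown
    print(print_floor_sketch(sympy.floor(u1 / 2)))  # -> u1 // 2
    print(print_floor_sketch(sympy.floor(s0 / 2)))  # -> triton_helpers.div_floor_integer(s0, 2)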

helion/_compiler/device_ir.py

Lines changed: 22 additions & 0 deletions

@@ -63,6 +63,7 @@
 from .type_propagation import _eval_binary
 from .type_propagation import _eval_compare
 from .type_propagation import _eval_unary
+from .utils import _allow_epilogue_subtiling
 
 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -1161,6 +1162,23 @@ def _register_load_store_tunables(
     )
 
 
+def _register_epilogue_subtile_tunable(store_count: int) -> None:
+    """Register the epilogue subtile tunable for all device stores."""
+    if store_count == 0:
+        return
+
+    from ..autotuner.config_fragment import EnumFragment
+    from ..autotuner.config_fragment import ListOf
+    from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES
+
+    env = CompileEnvironment.current()
+    # Register a tunable for epilogue subtile for all device stores
+    fragment = ListOf(
+        EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=store_count
+    )
+    env.config_spec.epilogue_subtiling = fragment
+
+
 def lower_to_device_ir(func: HostFunction) -> DeviceIR:
     device_ir = DeviceIR()
     with func, device_ir, compile_lock:
@@ -1191,6 +1209,10 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
             total_load_count, loads_without_eviction_policy, store_count
         )
 
+        # Epilogue subtiling only for Blackwell
+        if _allow_epilogue_subtiling():
+            _register_epilogue_subtile_tunable(store_count)
+
     return device_ir
 
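
For intuition, the fragment registered above yields one independent choice per device store, so a tuned config carries a per-store list. A hypothetical illustration (the concrete contents of VALID_EPILOGUE_SUBTILE_SIZES are an assumption here; the store codegen below treats 0 as "no subtiling" and currently only splits by 2):

    # Hypothetical tuned values for a kernel with two device stores:
    # split the first store's BLOCK_N tile in half, leave the second store alone.
    epilogue_subtiling = [2, 0]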

helion/_compiler/indexing_strategy.py

Lines changed: 168 additions & 2 deletions

@@ -15,10 +15,12 @@
 from .. import exc
 from .._compat import get_tensor_descriptor_fn_name
 from .ast_extension import expr_from_string
+from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
 from .device_function import DeviceFunction
 from .host_function import HostFunction
 from .tile_strategy import DeviceLoopState
+from .utils import _allow_epilogue_subtiling
 from .utils import compute_slice_size
 from .variable_origin import BlockSizeOrigin
 
@@ -352,7 +354,6 @@ def codegen_load(
         )
         assert extra_mask is None
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
-
         # Load from tensor descriptor with permuted offsets
         load_expr = expr_from_string(
             f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})"
@@ -382,23 +383,188 @@ def codegen_store(
         )
         assert extra_mask is None
         indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
+        store_value = indexing.reshape_store(state, value)
 
+        config = DeviceFunction.current().config
+        epilogue_subtiles = state.config.epilogue_subtiling
         # Apply permutation to the value being stored if needed
         desc_arg = indexing.tensor_descriptor_arg(state)
-        store_value = indexing.reshape_store(state, value)
 
         if desc_arg.permutation is not None:
             # Apply permutation to the value
             store_value = expr_from_string(
                 f"tl.permute({{store_val}}, {desc_arg.permutation!r})",
                 store_val=store_value,
             )
+
+        if _allow_epilogue_subtiling() and (
+            idx := state.device_function.device_store_index
+        ) <= len(epilogue_subtiles):
+            subtile_split = epilogue_subtiles[idx - 1]
+
+            subtile_codegen = self._codegen_epilogue_subtile_store(
+                state,
+                fake_tensor,
+                indexing,
+                store_value,
+                subtile_split,
+                config,
+            )
+            if subtile_codegen is not None:
+                return subtile_codegen
+
+            if "pointwise_in" in state.fx_node.meta:
+                # We still need to codegen pointwise if subtile_codegen is None
+                store_value = self._apply_pointwise_to_subtile(
+                    state, state.fx_node.meta["pointwise_in"], store_value
+                )
 
         return expr_from_string(
             f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})",
             value=store_value,
         )
 
+    def _apply_pointwise_to_subtile(
+        self, state: CodegenState, pointwise_node: torch.fx.Node, subtile_value: ast.AST
+    ) -> ast.AST:
+        """Apply a pointwise operation to a subtile value.
+
+        Args:
+            state: The codegen state
+            pointwise_node: The FX node representing the pointwise operation
+            subtile_value: The AST for the subtile value to apply the operation to
+
+        Returns:
+            AST for the result after applying the pointwise operation
+        """
+        from torch._inductor import ir
+
+        from .inductor_lowering import PointwiseLowering
+        from .inductor_lowering import install_inductor_kernel_handlers
+
+        lowering = pointwise_node.meta["lowering"]
+        assert isinstance(lowering, PointwiseLowering)
+
+        # Get the pointwise buffer
+        buffer = lowering.buffer
+        assert isinstance(buffer.data, ir.Pointwise)
+
+        # Create a temporary variable for the subtile
+        codegen = state.codegen
+        subtile_var = codegen.lift(subtile_value, prefix="subtile")
+
+        # Set up the inductor kernel handlers with the subtile as input
+        with install_inductor_kernel_handlers(
+            codegen, {lowering.input_names[0]: subtile_var}
+        ):
+            # Generate the pointwise operation
+            indices = [sympy.Symbol(f"i{n}") for n in range(len(buffer.data.ranges))]
+            from .inductor_lowering import _unpack_opsvalue
+
+            result_name = _unpack_opsvalue(buffer.data.inner_fn(indices))
+            return expr_from_string(result_name)
+
+    def _codegen_epilogue_subtile_store(
+        self,
+        state: CodegenState,
+        fake_tensor: torch.Tensor,
+        indexing: BlockedSubscriptIndexing,
+        store_value: ast.AST,
+        subtile_split: int,
+        config: Config,
+    ) -> ast.AST | None:
+        env = CompileEnvironment.current()
+        block_m, block_n = indexing.block_shape
+        block_n_hint = env.size_hint(block_n)
+        block_idx = env.get_block_id(block_n)
+        block_size = env.block_sizes[block_idx].from_config(config)
+
+        if "pointwise_in" in state.fx_node.meta:
+            fused_pointwise_node = state.fx_node.meta["pointwise_in"]
+            assert fused_pointwise_node == state.fx_node.args[2]
+        else:
+            fused_pointwise_node = None
+
+        # Currently support 2D tiles without permutations
+        if (
+            len(indexing.block_shape) != 2
+            or len(indexing.offsets) != 2
+            or subtile_split == 0
+            or block_n_hint % 2 != 0
+            or block_size <= 16
+        ):
+            return None
+
+        device_fn = state.device_function
+        codegen = state.codegen
+
+        block_m_str = device_fn.literal_expr(block_m)
+        block_n_str = device_fn.literal_expr(block_n)
+        indexing.block_shape[1] //= subtile_split
+
+        # TODO(PaulZhang12): Support more epilogue subtile configs besides 2
+        block_n_half_str = f"({block_n_str} // {subtile_split})"
+
+        # Lift the store value into a temporary variable for reuse
+        acc_var = codegen.lift(store_value, prefix="acc")
+
+        reshape_expr = expr_from_string(
+            "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)",
+            acc=acc_var,
+            dim_m=expr_from_string(block_m_str),
+            dim_half=expr_from_string(block_n_half_str),
+        )
+        reshape_var = codegen.lift(reshape_expr, prefix="acc")
+
+        acc0_name = codegen.tmpvar(prefix="acc")
+        acc1_name = codegen.tmpvar(prefix="acc")
+        codegen.add_statement(
+            statement_from_string(
+                f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
+                acc=reshape_var,
+            )
+        )
+
+        # Now apply the pointwise operation per-subtile if we have one
+        if fused_pointwise_node is not None:
+            acc0 = self._apply_pointwise_to_subtile(
+                state, fused_pointwise_node, expr_from_string(acc0_name)
+            )
+            acc1 = self._apply_pointwise_to_subtile(
+                state, fused_pointwise_node, expr_from_string(acc1_name)
+            )
        else:
+            acc0 = expr_from_string(acc0_name)
+            acc1 = expr_from_string(acc1_name)
+
+        desc_name = indexing.tensor_descriptor(state)
+        offset0 = expr_from_string(indexing.offsets[0])
+        offset1 = expr_from_string(indexing.offsets[1])
+
+        # First subtile store
+        codegen.add_statement(
+            statement_from_string(
+                f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
+                off0=offset0,
+                off1=offset1,
+                value=acc0,
+            )
+        )
+
+        offset1_shifted = expr_from_string(
+            "({offset} + {half})",
+            offset=expr_from_string(indexing.offsets[1]),
+            half=expr_from_string(block_n_half_str),
+        )
+
+        # Emit second subtile store as the expression returned to the caller
+        return expr_from_string(
+            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
+            off0=offset0,
+            off1=offset1_shifted,
+            value=acc1,
+        )
+
 
 class StackIndexingStrategy:
     """
