Skip to content

Commit 3e336b1

Browse files
committed
Add epilogue subtiling
stack-info: PR: #948, branch: PaulZhang12/stack/14
1 parent efc520e commit 3e336b1

File tree

12 files changed

+614
-41
lines changed

12 files changed

+614
-41
lines changed

examples/matmul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"range_unroll_factors": [0, 0],
3535
"range_num_stages": [0, 0],
3636
},
37+
allow_epilogue_subtiling=True,
3738
)
3839
def matmul(
3940
x: Tensor,

helion/_compiler/compile_environment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
9999
self.device_load_count = (
100100
0 # Track number of loads in all device code for eviction policy tuning
101101
)
102+
self.device_store_count = 0 # Track number of stores for subtiling
102103

103104
def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
104105
from .device_function import contains_only_block_size_symbols

helion/_compiler/device_function.py

Lines changed: 65 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
250250
self.rng_seed_count = 0
251251
self.device_load_index = 0 # Track which load in device code we're generating (for eviction policy tuning)
252252
# Name of the RNG seed buffer parameter in kernel signature
253+
self.device_store_index = (
254+
0 # Track which store in device code we're generating (for subtiling)
255+
)
253256
self.rng_seed_buffer_param_name = None
254257

255258
def has_rng_ops(self) -> bool:
@@ -421,8 +424,9 @@ def tensor_descriptor_arg(
421424
self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
422425
) -> TensorDescriptorArg:
423426
host_function = HostFunction.current()
424-
block_size_expr = ", ".join(map(self.literal_expr, block_size))
427+
block_size_expr = ", ".join(self.literal_expr(dim) for dim in block_size)
425428
key = (fake_value, block_size_expr)
429+
426430
if key not in self._tensor_descriptor_args:
427431
origin = host_function.tensor_to_origin[fake_value]
428432
desc_name = self.new_var(origin.suggest_var_name() + "_desc")
@@ -515,22 +519,6 @@ def _format_constexpr_value(self, value: object) -> str:
515519
if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
516520
value = value._sympy_()
517521

518-
# Handle sympy expressions (sanitize by replacing triton_helpers functions)
519-
if isinstance(value, sympy.Expr):
520-
sanitized = value.replace( # pyright: ignore[reportAttributeAccessIssue]
521-
lambda node: isinstance(node, sympy.Function)
522-
and getattr(node.func, "__name__", "")
523-
== "triton_helpers.div_floor_integer",
524-
lambda node: sympy.floor(node.args[0] / node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
525-
).replace( # pyright: ignore[reportAttributeAccessIssue]
526-
lambda node: isinstance(node, sympy.Function)
527-
and getattr(node.func, "__name__", "")
528-
== "triton_helpers.remainder_integer",
529-
lambda node: sympy.Mod(node.args[0], node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
530-
)
531-
expr = cast("sympy.Expr", sanitized)
532-
return HostFunction.current().sympy_expr(expr)
533-
534522
return HostFunction.current().literal_expr(value)
535523

536524
def _tensor_property(
@@ -708,11 +696,19 @@ def current() -> DeviceFunction:
708696

709697

710698
class HelionTritonPrinter(TritonPrinter):
711-
"""Custom Triton printer that avoids wrapping float literals in tl.full().
712-
713-
Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
714-
via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
715-
literal, letting downstream type promotion and casts handle dtype.
699+
"""Custom Triton printer that does the following:
700+
701+
- Avoids wrapping float literals in tl.full().
702+
Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
703+
via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
704+
literal, letting downstream type promotion and casts handle dtype.
705+
706+
- Avoids triton_helpers.div_floor_integer(...) calls when both operands are
707+
provably non-negative integers. TritonPrinter by default converts
708+
floor(u1/2) to triton_helpers.div_floor_integer(...). We override this to
709+
emit u1 // 2 only when the numerator is known to be non-negative and the
710+
denominator is a positive integer, so that we keep helper calls for cases
711+
that rely on floor semantics with mixed signs.
716712
"""
717713

718714
def _print_Float(self, expr: sympy.Expr) -> str:
@@ -721,6 +717,53 @@ def _print_Float(self, expr: sympy.Expr) -> str:
721717
def _print_ToFloat(self, expr: sympy.Expr) -> str:
722718
return f"{expr} + 0.0"
723719

720+
def _is_nonnegative(self, expr: sympy.Expr) -> bool:
721+
if expr.is_nonnegative is True or expr.is_zero is True:
722+
return True
723+
if expr.is_positive is True:
724+
return True
725+
try:
726+
host_fn = HostFunction.current()
727+
except NoCurrentFunction:
728+
host_fn = None
729+
if host_fn is not None:
730+
origin_info = host_fn.expr_to_origin.get(expr)
731+
if origin_info and isinstance(
732+
origin_info.origin, (BlockSizeOrigin, TensorSizeOrigin)
733+
):
734+
return True
735+
if isinstance(expr, sympy.Symbol) and expr.name.startswith("_BLOCK_SIZE_"):
736+
return True
737+
if isinstance(expr, sympy.Number):
738+
return bool(expr >= 0)
739+
return False
740+
741+
def _format_trunc_div(self, lhs: sympy.Expr, rhs: sympy.Expr) -> str:
742+
lhs_str = self._print(lhs)
743+
rhs_str = self._print(rhs)
744+
if not (lhs.is_Integer or lhs.is_Symbol):
745+
lhs_str = f"({lhs_str})"
746+
if not (rhs.is_Integer or rhs.is_Symbol):
747+
rhs_str = f"({rhs_str})"
748+
return f"{lhs_str} // {rhs_str}"
749+
750+
def _print_floor(self, expr: sympy.Expr) -> str:
751+
inner = expr.args[0]
752+
numer, denom = inner.as_numer_denom()
753+
if (
754+
isinstance(denom, sympy.Integer)
755+
and denom > 1
756+
and self._is_nonnegative(numer)
757+
):
758+
return self._format_trunc_div(numer, denom)
759+
return super()._print_floor(expr)
760+
761+
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
762+
lhs, rhs = expr.args
763+
if isinstance(rhs, sympy.Integer) and rhs > 0 and self._is_nonnegative(lhs):
764+
return self._format_trunc_div(lhs, rhs)
765+
return super()._print_FloorDiv(expr)
766+
724767

725768
def texpr(expr: sympy.Expr) -> str:
726769
return HelionTritonPrinter().doprint(expr)

helion/_compiler/device_ir.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from .type_propagation import _eval_binary
6464
from .type_propagation import _eval_compare
6565
from .type_propagation import _eval_unary
66+
from .utils import _allow_epilogue_subtiling
6667

6768
if TYPE_CHECKING:
6869
from collections.abc import Callable
@@ -1076,7 +1077,7 @@ def visit_For(self, node: ast.For) -> None:
10761077
self.generic_visit(node)
10771078

10781079

1079-
def _count_device_loads(device_ir: DeviceIR) -> int:
1080+
def _count_device_loads_and_stores(device_ir: DeviceIR) -> int:
10801081
"""Count the number of load operations in all device code for eviction policy tuning."""
10811082
from ..language import memory_ops
10821083

@@ -1087,26 +1088,29 @@ def _count_device_loads(device_ir: DeviceIR) -> int:
10871088
if info.new_graph_id is not None
10881089
}
10891090

1090-
load_count = 0
1091+
load_count, store_count = 0, 0
10911092
# Walk all graphs except rolled duplicates
10921093
for graph_info in device_ir.graphs:
10931094
if graph_info.graph_id in rolled_graph_ids:
10941095
continue
10951096

10961097
for node in graph_info.graph.nodes:
10971098
# Check if this is a load operation
1098-
if node.op == "call_function" and node.target is memory_ops.load:
1099-
# Only count loads without explicit eviction policy
1100-
# (user can still specify eviction_policy to override tuning)
1101-
# Check kwargs first, then check if 4th arg (eviction_policy) is None
1102-
eviction_policy_arg = node.kwargs.get("eviction_policy")
1103-
if eviction_policy_arg is None:
1104-
# Check if eviction_policy was passed as positional arg (index 3)
1105-
if len(node.args) >= 4:
1106-
eviction_policy_arg = node.args[3]
1099+
if node.op == "call_function":
1100+
if node.target is memory_ops.load:
1101+
# Only count loads without explicit eviction policy
1102+
# (user can still specify eviction_policy to override tuning)
1103+
# Check kwargs first, then check if 4th arg (eviction_policy) is None
1104+
eviction_policy_arg = node.kwargs.get("eviction_policy")
11071105
if eviction_policy_arg is None:
1108-
load_count += 1
1109-
return load_count
1106+
# Check if eviction_policy was passed as positional arg (index 3)
1107+
if len(node.args) >= 4:
1108+
eviction_policy_arg = node.args[3]
1109+
if eviction_policy_arg is None:
1110+
load_count += 1
1111+
elif node.target is memory_ops.store:
1112+
store_count += 1
1113+
return load_count, store_count
11101114

11111115

11121116
def _register_eviction_policy_tunable(load_count: int) -> None:
@@ -1125,6 +1129,24 @@ def _register_eviction_policy_tunable(load_count: int) -> None:
11251129
env.device_load_count = load_count
11261130

11271131

1132+
def _register_epilogue_subtile_tunable(store_count: int) -> None:
1133+
"""Register the epilogue subtile tunable for all device stores."""
1134+
if store_count == 0:
1135+
return
1136+
1137+
from ..autotuner.config_fragment import EnumFragment
1138+
from ..autotuner.config_fragment import ListOf
1139+
from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES
1140+
1141+
env = CompileEnvironment.current()
1142+
# Register a tunable for epilogue subtile for all device stores
1143+
fragment = ListOf(
1144+
EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=store_count
1145+
)
1146+
env.config_spec.epilogue_subtiling = fragment
1147+
env.device_store_count = store_count
1148+
1149+
11281150
def lower_to_device_ir(func: HostFunction) -> DeviceIR:
11291151
device_ir = DeviceIR()
11301152
with func, device_ir, compile_lock:
@@ -1148,9 +1170,13 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
11481170
CompileEnvironment.current().config_spec.disallow_pid_type("xyz")
11491171

11501172
# Count all device loads and register eviction policy tunable
1151-
load_count = _count_device_loads(device_ir)
1173+
load_count, store_count = _count_device_loads_and_stores(device_ir)
11521174
_register_eviction_policy_tunable(load_count)
11531175

1176+
# Epilogue subtiling only for Blackwell
1177+
if _allow_epilogue_subtiling():
1178+
_register_epilogue_subtile_tunable(store_count)
1179+
11541180
return device_ir
11551181

11561182

0 commit comments

Comments
 (0)