Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@
@helion.kernel(
# static_shapes=True gives a performance boost for matmuls
static_shapes=True,
# Disable autotung over unrolling/range_num_stages
# Disable autotuning over range_num_stages
# tl.dot is pipelined with num_stages
autotune_config_overrides={
"range_unroll_factors": [0, 0],
"range_num_stages": [0, 0],
},
allow_epilogue_subtiling=True,
)
def matmul(
x: Tensor,
Expand Down
84 changes: 62 additions & 22 deletions helion/_compiler/device_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,9 @@ def tensor_descriptor_arg(
self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
) -> TensorDescriptorArg:
host_function = HostFunction.current()
block_size_expr = ", ".join(map(self.literal_expr, block_size))
block_size_expr = ", ".join(self.literal_expr(dim) for dim in block_size)
key = (fake_value, block_size_expr)

if key not in self._tensor_descriptor_args:
origin = host_function.tensor_to_origin[fake_value]
desc_name = self.new_var(origin.suggest_var_name() + "_desc")
Expand Down Expand Up @@ -556,22 +557,6 @@ def _format_constexpr_value(self, value: object) -> str:
if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
value = value._sympy_()

# Handle sympy expressions (sanitize by replacing triton_helpers functions)
if isinstance(value, sympy.Expr):
sanitized = value.replace( # pyright: ignore[reportAttributeAccessIssue]
lambda node: isinstance(node, sympy.Function)
and getattr(node.func, "__name__", "")
== "triton_helpers.div_floor_integer",
lambda node: sympy.floor(node.args[0] / node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
).replace( # pyright: ignore[reportAttributeAccessIssue]
lambda node: isinstance(node, sympy.Function)
and getattr(node.func, "__name__", "")
== "triton_helpers.remainder_integer",
lambda node: sympy.Mod(node.args[0], node.args[1]), # pyright: ignore[reportAttributeAccessIssue]
)
expr = cast("sympy.Expr", sanitized)
return HostFunction.current().sympy_expr(expr)

return HostFunction.current().literal_expr(value)

def _tensor_property(
Expand Down Expand Up @@ -749,11 +734,19 @@ def current() -> DeviceFunction:


class HelionTritonPrinter(TritonPrinter):
"""Custom Triton printer that avoids wrapping float literals in tl.full().

Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
literal, letting downstream type promotion and casts handle dtype.
"""Custom Triton printer that does the following:

- Avoids wrapping float literals in tl.full().
Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value
via tl.full([], <val>, tl.float64). We override this to emit the raw numeric
literal, letting downstream type promotion and casts handle dtype.

- Avoids triton_helpers.div_floor_integer(...) calls when both operands are
provably non-negative integers. TritonPrinter by default converts
floor(u1/2) to triton_helpers.div_floor_integer(...). We override this to
emit u1 // 2 only when the numerator is known to be non-negative and the
denominator is a positive integer, so that we keep helper calls for cases
that rely on floor semantics with mixed signs.
"""

def _print_Float(self, expr: sympy.Expr) -> str:
Expand All @@ -762,6 +755,53 @@ def _print_Float(self, expr: sympy.Expr) -> str:
def _print_ToFloat(self, expr: sympy.Expr) -> str:
return f"{expr} + 0.0"

def _is_nonnegative(self, expr: sympy.Expr) -> bool:
if expr.is_nonnegative is True or expr.is_zero is True:
return True
if expr.is_positive is True:
return True
try:
host_fn = HostFunction.current()
except NoCurrentFunction:
host_fn = None
if host_fn is not None:
origin_info = host_fn.expr_to_origin.get(expr)
if origin_info and isinstance(
origin_info.origin, (BlockSizeOrigin, TensorSizeOrigin)
):
return True
if isinstance(expr, sympy.Symbol) and expr.name.startswith("_BLOCK_SIZE_"):
return True
if isinstance(expr, sympy.Number):
return bool(expr >= 0)
return False

def _format_trunc_div(self, lhs: sympy.Expr, rhs: sympy.Expr) -> str:
lhs_str = self._print(lhs)
rhs_str = self._print(rhs)
if not (lhs.is_Integer or lhs.is_Symbol):
lhs_str = f"({lhs_str})"
if not (rhs.is_Integer or rhs.is_Symbol):
rhs_str = f"({rhs_str})"
return f"{lhs_str} // {rhs_str}"

def _print_floor(self, expr: sympy.Expr) -> str:
inner = expr.args[0]
numer, denom = inner.as_numer_denom()
if (
isinstance(denom, sympy.Integer)
and denom > 1
and self._is_nonnegative(numer)
):
return self._format_trunc_div(numer, denom)
return super()._print_floor(expr)

def _print_FloorDiv(self, expr: sympy.Expr) -> str:
lhs, rhs = expr.args
if isinstance(rhs, sympy.Integer) and rhs > 0 and self._is_nonnegative(lhs):
return self._format_trunc_div(lhs, rhs)
return super()._print_FloorDiv(expr)


def texpr(expr: sympy.Expr) -> str:
return HelionTritonPrinter().doprint(expr)
67 changes: 67 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from .type_propagation import _eval_binary
from .type_propagation import _eval_compare
from .type_propagation import _eval_unary
from .utils import _allow_epilogue_subtiling

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -1191,6 +1192,10 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
total_load_count, loads_without_eviction_policy, store_count
)

# Epilogue subtiling only for Blackwell
if _allow_epilogue_subtiling():
epilogue_subtiling_pass(graph.graph, store_count)

return device_ir


Expand Down Expand Up @@ -1348,3 +1353,65 @@ def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
user.args = tuple(new_args)
if len(node.users) == 0:
graph.erase_node(node)

def epilogue_subtiling_pass(graph: torch.fx.Graph, store_count: int) -> None:
"""
Replace epilogue subtile with a tunable value.
"""
if store_count == 0:
return

from ..autotuner.config_fragment import EnumFragment
from ..autotuner.config_fragment import ListOf
from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES
from .inductor_lowering import PointwiseLowering

env = CompileEnvironment.current()
# Register a tunable for epilogue subtile for all device stores
fragment = ListOf(
EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=store_count
)
env.config_spec.epilogue_subtiling = fragment

def collect_pointwise_epilogue_nodes(store_node: torch.fx.Node):
"""Recursively collect all pointwise nodes that can be subtiled in the epilogue.

Starting from a store node, traverse backwards through all input nodes,
collecting pointwise operations until we hit non-pointwise nodes.
Only include pointwise nodes that have a single user to ensure they can be fused.
"""
# dict to preserve order
pointwise_nodes = dict()
visited = set()
stack = [store_node.args[2]] # Start with the value being stored

while stack:
current = stack.pop()
if current in visited:
continue
visited.add(current)

lowering = current.meta.get("lowering")
# Check if this is a pointwise operation with only one user
if isinstance(lowering, PointwiseLowering) and len(current.users) == 1:
if current not in pointwise_nodes:
pointwise_nodes[current] = None
stack.extend(current.all_input_nodes)

return pointwise_nodes


from ..language import store as store_api
stores = set()

for node in graph.nodes:
if node.op == "call_function" and node.target == store_api:
stores.add(node)
# Collect all pointwise nodes that can be subtiled in the epilogue
pointwise_nodes = collect_pointwise_epilogue_nodes(node)
if pointwise_nodes:
# Mark all collected pointwise nodes for epilogue subtiling
for pw_node in pointwise_nodes:
pw_node.meta["epilogue_subtile"] = True
# Store the set of pointwise nodes in the store node's metadata
node.meta["pointwise_epilogue_nodes"] = pointwise_nodes
Loading
Loading