From 0ef154fd62a142a41f8de0d3a232980c13f573e9 Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Wed, 15 Oct 2025 13:57:57 -0700 Subject: [PATCH] Add epilogue subtiling stack-info: PR: https://github.com/pytorch/helion/pull/948, branch: PaulZhang12/stack/14 --- examples/matmul.py | 3 +- helion/_compiler/device_function.py | 84 ++++-- helion/_compiler/device_ir.py | 67 +++++ helion/_compiler/indexing_strategy.py | 377 ++++++++++++++++++++++++- helion/_compiler/inductor_lowering.py | 24 ++ helion/_compiler/utils.py | 12 +- helion/autotuner/config_spec.py | 12 + helion/runtime/config.py | 7 + helion/runtime/settings.py | 6 + test/test_examples.expected | 279 ++++++++++++++++++- test/test_examples.py | 1 + test/test_matmul.expected | 378 ++++++++++++++++++++++++++ test/test_matmul.py | 85 ++++-- test/test_type_propagation.expected | 88 +++--- 14 files changed, 1330 insertions(+), 93 deletions(-) diff --git a/examples/matmul.py b/examples/matmul.py index b3c3ca4d8..bc45b17cd 100644 --- a/examples/matmul.py +++ b/examples/matmul.py @@ -28,12 +28,13 @@ @helion.kernel( # static_shapes=True gives a performance boost for matmuls static_shapes=True, - # Disable autotung over unrolling/range_num_stages + # Disable autotuning over range_num_stages # tl.dot is pipelined with num_stages autotune_config_overrides={ "range_unroll_factors": [0, 0], "range_num_stages": [0, 0], }, + allow_epilogue_subtiling=True, ) def matmul( x: Tensor, diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 6cec94000..c7a02fc18 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -462,8 +462,9 @@ def tensor_descriptor_arg( self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt] ) -> TensorDescriptorArg: host_function = HostFunction.current() - block_size_expr = ", ".join(map(self.literal_expr, block_size)) + block_size_expr = ", ".join(self.literal_expr(dim) for dim in block_size) key = (fake_value, block_size_expr) + if key not in self._tensor_descriptor_args: origin = host_function.tensor_to_origin[fake_value] desc_name = self.new_var(origin.suggest_var_name() + "_desc") @@ -556,22 +557,6 @@ def _format_constexpr_value(self, value: object) -> str: if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)): value = value._sympy_() - # Handle sympy expressions (sanitize by replacing triton_helpers functions) - if isinstance(value, sympy.Expr): - sanitized = value.replace( # pyright: ignore[reportAttributeAccessIssue] - lambda node: isinstance(node, sympy.Function) - and getattr(node.func, "__name__", "") - == "triton_helpers.div_floor_integer", - lambda node: sympy.floor(node.args[0] / node.args[1]), # pyright: ignore[reportAttributeAccessIssue] - ).replace( # pyright: ignore[reportAttributeAccessIssue] - lambda node: isinstance(node, sympy.Function) - and getattr(node.func, "__name__", "") - == "triton_helpers.remainder_integer", - lambda node: sympy.Mod(node.args[0], node.args[1]), # pyright: ignore[reportAttributeAccessIssue] - ) - expr = cast("sympy.Expr", sanitized) - return HostFunction.current().sympy_expr(expr) - return HostFunction.current().literal_expr(value) def _tensor_property( @@ -749,11 +734,19 @@ def current() -> DeviceFunction: class HelionTritonPrinter(TritonPrinter): - """Custom Triton printer that avoids wrapping float literals in tl.full(). - - Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value - via tl.full([], , tl.float64). 
We override this to emit the raw numeric - literal, letting downstream type promotion and casts handle dtype. + """Custom Triton printer that does the following: + + - Avoids wrapping float literals in tl.full(). + Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value + via tl.full([], , tl.float64). We override this to emit the raw numeric + literal, letting downstream type promotion and casts handle dtype. + + - Avoids triton_helpers.div_floor_integer(...) calls when both operands are + provably non-negative integers. TritonPrinter by default converts + floor(u1/2) to triton_helpers.div_floor_integer(...). We override this to + emit u1 // 2 only when the numerator is known to be non-negative and the + denominator is a positive integer, so that we keep helper calls for cases + that rely on floor semantics with mixed signs. """ def _print_Float(self, expr: sympy.Expr) -> str: @@ -762,6 +755,53 @@ def _print_Float(self, expr: sympy.Expr) -> str: def _print_ToFloat(self, expr: sympy.Expr) -> str: return f"{expr} + 0.0" + def _is_nonnegative(self, expr: sympy.Expr) -> bool: + if expr.is_nonnegative is True or expr.is_zero is True: + return True + if expr.is_positive is True: + return True + try: + host_fn = HostFunction.current() + except NoCurrentFunction: + host_fn = None + if host_fn is not None: + origin_info = host_fn.expr_to_origin.get(expr) + if origin_info and isinstance( + origin_info.origin, (BlockSizeOrigin, TensorSizeOrigin) + ): + return True + if isinstance(expr, sympy.Symbol) and expr.name.startswith("_BLOCK_SIZE_"): + return True + if isinstance(expr, sympy.Number): + return bool(expr >= 0) + return False + + def _format_trunc_div(self, lhs: sympy.Expr, rhs: sympy.Expr) -> str: + lhs_str = self._print(lhs) + rhs_str = self._print(rhs) + if not (lhs.is_Integer or lhs.is_Symbol): + lhs_str = f"({lhs_str})" + if not (rhs.is_Integer or rhs.is_Symbol): + rhs_str = f"({rhs_str})" + return f"{lhs_str} // {rhs_str}" + + def _print_floor(self, expr: sympy.Expr) -> str: + inner = expr.args[0] + numer, denom = inner.as_numer_denom() + if ( + isinstance(denom, sympy.Integer) + and denom > 1 + and self._is_nonnegative(numer) + ): + return self._format_trunc_div(numer, denom) + return super()._print_floor(expr) + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + lhs, rhs = expr.args + if isinstance(rhs, sympy.Integer) and rhs > 0 and self._is_nonnegative(lhs): + return self._format_trunc_div(lhs, rhs) + return super()._print_FloorDiv(expr) + def texpr(expr: sympy.Expr) -> str: return HelionTritonPrinter().doprint(expr) diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index bb119db86..b73ed6268 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -63,6 +63,7 @@ from .type_propagation import _eval_binary from .type_propagation import _eval_compare from .type_propagation import _eval_unary +from .utils import _allow_epilogue_subtiling if TYPE_CHECKING: from collections.abc import Callable @@ -1191,6 +1192,10 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: total_load_count, loads_without_eviction_policy, store_count ) + # Epilogue subtiling only for Blackwell + if _allow_epilogue_subtiling(): + epilogue_subtiling_pass(graph.graph, store_count) + return device_ir @@ -1348,3 +1353,65 @@ def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None: user.args = tuple(new_args) if len(node.users) == 0: graph.erase_node(node) + +def epilogue_subtiling_pass(graph: torch.fx.Graph, store_count: int) -> None: 
+ """ + Replace epilogue subtile with a tunable value. + """ + if store_count == 0: + return + + from ..autotuner.config_fragment import EnumFragment + from ..autotuner.config_fragment import ListOf + from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES + from .inductor_lowering import PointwiseLowering + + env = CompileEnvironment.current() + # Register a tunable for epilogue subtile for all device stores + fragment = ListOf( + EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=store_count + ) + env.config_spec.epilogue_subtiling = fragment + + def collect_pointwise_epilogue_nodes(store_node: torch.fx.Node): + """Recursively collect all pointwise nodes that can be subtiled in the epilogue. + + Starting from a store node, traverse backwards through all input nodes, + collecting pointwise operations until we hit non-pointwise nodes. + Only include pointwise nodes that have a single user to ensure they can be fused. + """ + # dict to preserve order + pointwise_nodes = dict() + visited = set() + stack = [store_node.args[2]] # Start with the value being stored + + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + + lowering = current.meta.get("lowering") + # Check if this is a pointwise operation with only one user + if isinstance(lowering, PointwiseLowering) and len(current.users) == 1: + if current not in pointwise_nodes: + pointwise_nodes[current] = None + stack.extend(current.all_input_nodes) + + return pointwise_nodes + + + from ..language import store as store_api + stores = set() + + for node in graph.nodes: + if node.op == "call_function" and node.target == store_api: + stores.add(node) + # Collect all pointwise nodes that can be subtiled in the epilogue + pointwise_nodes = collect_pointwise_epilogue_nodes(node) + if pointwise_nodes: + # Mark all collected pointwise nodes for epilogue subtiling + for pw_node in pointwise_nodes: + pw_node.meta["epilogue_subtile"] = True + # Store the set of pointwise nodes in the store node's metadata + node.meta["pointwise_epilogue_nodes"] = pointwise_nodes diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index b6c8eaa66..36e5a2eae 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -15,10 +15,12 @@ from .. import exc from .._compat import get_tensor_descriptor_fn_name from .ast_extension import expr_from_string +from .ast_extension import statement_from_string from .compile_environment import CompileEnvironment from .device_function import DeviceFunction from .host_function import HostFunction from .tile_strategy import DeviceLoopState +from .utils import _allow_epilogue_subtiling from .utils import compute_slice_size from .variable_origin import BlockSizeOrigin @@ -103,6 +105,46 @@ def _get_tile_with_offset_info( return None +def _apply_pointwise_to_subtile( + state: CodegenState, pointwise_node: torch.fx.Node, subtile_value: ast.AST +) -> ast.AST: + """Apply a pointwise operation to a subtile value. 
+ + Args: + state: The codegen state + pointwise_node: The FX node representing the pointwise operation + subtile_value: The AST for the subtile value to apply the operation to + + Returns: + AST for the result after applying the pointwise operation + """ + from torch._inductor import ir + + from .inductor_lowering import PointwiseLowering + from .inductor_lowering import install_inductor_kernel_handlers + + lowering = pointwise_node.meta["lowering"] + assert isinstance(lowering, PointwiseLowering) + + # Get the pointwise buffer + buffer = lowering.buffer + assert isinstance(buffer.data, ir.Pointwise) + + # Create a temporary variable for the subtile + codegen = state.codegen + subtile_var = codegen.lift(subtile_value, prefix="subtile") + + # Set up the inductor kernel handlers with the subtile as input + with install_inductor_kernel_handlers( + codegen, {lowering.input_names[0]: subtile_var} + ): + # Generate the pointwise operation + indices = [sympy.Symbol(f"i{n}") for n in range(len(buffer.data.ranges))] + from .inductor_lowering import _unpack_opsvalue + + result_name = _unpack_opsvalue(buffer.data.inner_fn(indices)) + return expr_from_string(result_name) + class IndexingStrategy: def codegen_load( @@ -178,6 +220,40 @@ def codegen_store( ) -> ast.AST: indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask) name = state.device_function.tensor_arg(fake_tensor).name + + config = DeviceFunction.current().config + epilogue_subtiles = state.config.epilogue_subtiling + + # Try epilogue subtiling if enabled + if _allow_epilogue_subtiling() and ( + idx := state.device_function.device_store_index + ) <= len(epilogue_subtiles): + subtile_split = epilogue_subtiles[idx - 1] + subtile_codegen = self._codegen_epilogue_subtile_store( + state, + fake_tensor, + indexing, + subscript, + value, + subtile_split, + config, + extra_mask, + ) + if subtile_codegen is not None: + return subtile_codegen + + if "pointwise_epilogue_nodes" in state.fx_node.meta: + # We still need to codegen pointwise if subtile_codegen is None + # Apply all pointwise operations in the epilogue + pointwise_nodes = state.fx_node.meta["pointwise_epilogue_nodes"] + # Apply pointwise operations in topological order (from inputs to outputs) + # The immediate value argument to store is the final pointwise in the chain + store_value_node = state.fx_node.args[2] + if store_value_node in pointwise_nodes: + value = _apply_pointwise_to_subtile( + state, store_value_node, value + ) + return expr_from_string( f"tl.store({name} + {{offset}}, {{value}}, {{mask}})", value=value, @@ -185,6 +261,178 @@ def codegen_store( mask=indexing.mask_expr, ) + def _codegen_epilogue_subtile_store( + self, + state: CodegenState, + fake_tensor: torch.Tensor, + indexing: SubscriptIndexing, + subscript: list[object], + store_value: ast.AST, + subtile_split: int, + config: Config, + extra_mask: ast.AST | None, + ) -> ast.AST | None: + """Generate epilogue subtiling for pointer-based stores. + + This splits the store value and offsets to perform multiple smaller stores, + which can improve performance by reducing TMA overhead. 
+ """ + env = CompileEnvironment.current() + + # Get the output shape from SubscriptIndexing + output_shape = SubscriptIndexing.compute_shape(fake_tensor, subscript, state) + + # Currently only support 2D tiles + if len(output_shape) != 2 or subtile_split == 0: + return None + + block_m, block_n = output_shape + block_n_hint = env.size_hint(block_n) + block_idx = env.get_block_id(block_n) + + if block_idx is None: + return None + + block_size = env.block_sizes[block_idx].from_config(config) + + # Check if subtiling is feasible + if block_n_hint % 2 != 0 or block_size <= 16: + return None + + device_fn = state.device_function + codegen = state.codegen + + block_m_str = device_fn.literal_expr(block_m) + block_n_str = device_fn.literal_expr(block_n) + + # TODO(PaulZhang12): Support more epilogue subtile configs besides 2 + block_n_half_str = f"({block_n_str} // {subtile_split})" + block_n_half_expr = expr_from_string(block_n_half_str) + + # Lift the store value into a temporary variable for reuse + acc_var = codegen.lift(store_value, prefix="acc") + + # Reshape and split the accumulator + reshape_expr = expr_from_string( + "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)", + acc=acc_var, + dim_m=expr_from_string(block_m_str), + dim_half=block_n_half_expr, + ) + reshape_var = codegen.lift(reshape_expr, prefix="acc") + + acc0_name = codegen.tmpvar(prefix="acc") + acc1_name = codegen.tmpvar(prefix="acc") + codegen.add_statement( + statement_from_string( + f"{acc0_name}, {acc1_name} = tl.split({{acc}})", + acc=reshape_var, + ) + ) + + # Apply pointwise operations to each subtile if present + acc0 = expr_from_string(acc0_name) + acc1 = expr_from_string(acc1_name) + if "pointwise_epilogue_nodes" in state.fx_node.meta: + pointwise_nodes = list(reversed(state.fx_node.meta["pointwise_epilogue_nodes"])) + for pointwise_node in pointwise_nodes: + acc0 = _apply_pointwise_to_subtile( + state, pointwise_node, acc0 + ) + acc1 = _apply_pointwise_to_subtile( + state, pointwise_node, acc1 + ) + + name = state.device_function.tensor_arg(fake_tensor).name + + # Generate sliced index variables for N dimension + # Get the index variable for the N dimension (block_idx) + offset_n_var = codegen.offset_var(block_idx) + + # Create sliced indices for each subtile + # First subtile: indices_n[:block_n_half] + index_n_0_name = codegen.tmpvar(prefix="indices_n") + codegen.add_statement( + statement_from_string( + f"{index_n_0_name} = ({offset_n_var} + tl.arange(0, {block_n_half_str})).to(tl.int32)" + ) + ) + + # Second subtile: indices_n[block_n_half:] + index_n_1_name = codegen.tmpvar(prefix="indices_n") + codegen.add_statement( + statement_from_string( + f"{index_n_1_name} = ({offset_n_var} + {block_n_half_str} + tl.arange(0, {block_n_half_str})).to(tl.int32)" + ) + ) + + # Reconstruct the offset expressions for each subtile + # We need to replace the N dimension index with the sliced versions + stride_n = state.device_function.tensor_stride(fake_tensor, -1).name + stride_m = state.device_function.tensor_stride(fake_tensor, -2).name + index_m_var = codegen.index_var(env.get_block_id(block_m)) + + # Build offset for first subtile + offset_0 = expr_from_string( + f"{index_m_var}[:, None] * {stride_m} + {index_n_0_name}[None, :] * {stride_n}" + ) + + # Build offset for second subtile (note: need to add block_n_half to base for second half) + offset_1 = expr_from_string( + f"{index_m_var}[:, None] * {stride_m} + {index_n_1_name}[None, :] * {stride_n}" + ) + + # Generate masks for each subtile if masking is 
needed + mask_0 = indexing.mask_expr + mask_1 = indexing.mask_expr + + if indexing.has_mask(): + # Need to slice the mask as well for N dimension + mask_n_var = codegen.mask_var(block_idx) + if mask_n_var is not None: + # Original mask structure: mask_m[:, None] & mask_n[None, :] + # Need to slice mask_n for each subtile + mask_n_0_name = codegen.tmpvar(prefix="mask_n") + mask_n_1_name = codegen.tmpvar(prefix="mask_n") + + codegen.add_statement( + statement_from_string(f"{mask_n_0_name} = {index_n_0_name} < {stride_m}") + ) + codegen.add_statement( + statement_from_string(f"{mask_n_1_name} = {index_n_1_name} < {stride_m}") + ) + + # Reconstruct masks with sliced components + mask_m_var = codegen.mask_var(env.get_block_id(block_m)) + if mask_m_var is not None: + mask_0 = expr_from_string( + f"{mask_m_var}[:, None] & {mask_n_0_name}[None, :]" + ) + mask_1 = expr_from_string( + f"{mask_m_var}[:, None] & {mask_n_1_name}[None, :]" + ) + else: + mask_0 = expr_from_string(f"{mask_n_0_name}[None, :]") + mask_1 = expr_from_string(f"{mask_n_1_name}[None, :]") + + # First subtile store + codegen.add_statement( + statement_from_string( + f"tl.store({name} + {{offset}}, {{value}}, {{mask}})", + value=acc0, + offset=offset_0, + mask=mask_0, + ) + ) + + # Second subtile store - return as the result + return expr_from_string( + f"tl.store({name} + {{offset}}, {{value}}, {{mask}})", + value=acc1, + offset=offset_1, + mask=mask_1, + ) + class BlockPtrIndexingStrategy(IndexingStrategy): """Use block_ptr to load/store from tensors""" @@ -352,7 +600,6 @@ def codegen_load( ) assert extra_mask is None indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) - # Load from tensor descriptor with permuted offsets load_expr = expr_from_string( f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})" @@ -382,10 +629,12 @@ def codegen_store( ) assert extra_mask is None indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) + store_value = indexing.reshape_store(state, value) + config = DeviceFunction.current().config + epilogue_subtiles = state.config.epilogue_subtiling # Apply permutation to the value being stored if needed desc_arg = indexing.tensor_descriptor_arg(state) - store_value = indexing.reshape_store(state, value) if desc_arg.permutation is not None: # Apply permutation to the value @@ -394,12 +643,136 @@ def codegen_store( store_val=store_value, ) + if _allow_epilogue_subtiling() and ( + idx := state.device_function.device_store_index + ) <= len(epilogue_subtiles): + subtile_split = epilogue_subtiles[idx - 1] + subtile_codegen = self._codegen_epilogue_subtile_store( + state, + fake_tensor, + indexing, + store_value, + subtile_split, + config, + ) + if subtile_codegen is not None: + return subtile_codegen + + if "pointwise_epilogue_nodes" in state.fx_node.meta: + # We still need to codegen pointwise if subtile_codegen is None + # Apply all pointwise operations in the epilogue + pointwise_nodes = state.fx_node.meta["pointwise_epilogue_nodes"] + # Apply pointwise operations in topological order (from inputs to outputs) + # The immediate value argument to store is the final pointwise in the chain + store_value_node = state.fx_node.args[2] + if store_value_node in pointwise_nodes: + store_value = _apply_pointwise_to_subtile( + state, store_value_node, store_value + ) + return expr_from_string( f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})", value=store_value, ) + def _codegen_epilogue_subtile_store( + 
self, + state: CodegenState, + fake_tensor: torch.Tensor, + indexing: BlockedSubscriptIndexing, + store_value: ast.AST, + subtile_split: int, + config: Config, + ) -> ast.AST | None: + env = CompileEnvironment.current() + block_m, block_n = indexing.block_shape + block_n_hint = env.size_hint(block_n) + block_idx = env.get_block_id(block_n) + block_size = env.block_sizes[block_idx].from_config(config) + + # Currently support 2D tiles without permutations + if ( + len(indexing.block_shape) != 2 + or len(indexing.offsets) != 2 + or subtile_split == 0 + or block_n_hint % 2 != 0 + or block_size <= 16 + ): + return None + + device_fn = state.device_function + codegen = state.codegen + + block_m_str = device_fn.literal_expr(block_m) + block_n_str = device_fn.literal_expr(block_n) + indexing.block_shape[1] //= subtile_split + + # TODO(PaulZhang12): Support more epilogue subtile configs besides 2 + block_n_half_str = f"({block_n_str} // {subtile_split})" + + # Lift the store value into a temporary variable for reuse + acc_var = codegen.lift(store_value, prefix="acc") + + reshape_expr = expr_from_string( + "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)", + acc=acc_var, + dim_m=expr_from_string(block_m_str), + dim_half=expr_from_string(block_n_half_str), + ) + reshape_var = codegen.lift(reshape_expr, prefix="acc") + + acc0_name = codegen.tmpvar(prefix="acc") + acc1_name = codegen.tmpvar(prefix="acc") + codegen.add_statement( + statement_from_string( + f"{acc0_name}, {acc1_name} = tl.split({{acc}})", + acc=reshape_var, + ) + ) + + acc0 = expr_from_string(acc0_name) + acc1 = expr_from_string(acc1_name) + if "pointwise_epilogue_nodes" in state.fx_node.meta: + pointwise_nodes = list(reversed(state.fx_node.meta["pointwise_epilogue_nodes"])) + for pointwise_node in pointwise_nodes: + acc0 = _apply_pointwise_to_subtile( + state, pointwise_node, acc0 + ) + acc1 = _apply_pointwise_to_subtile( + state, pointwise_node, acc1 + ) + + desc_name = indexing.tensor_descriptor(state) + offset0 = expr_from_string(indexing.offsets[0]) + offset1 = expr_from_string(indexing.offsets[1]) + + # First subtile store + codegen.add_statement( + statement_from_string( + f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})", + off0=offset0, + off1=offset1, + value=acc0, + ) + ) + + offset1_shifted = expr_from_string( + "({offset} + {half})", + offset=expr_from_string(indexing.offsets[1]), + half=expr_from_string(block_n_half_str), + ) + + # Emit second subtile store as the expression returned to the caller + return expr_from_string( + f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})", + off0=offset0, + off1=offset1_shifted, + value=acc1, + ) + + + class StackIndexingStrategy: """ Generate pointer math for stacking load/store to several device memory pointers sharing the same indexing. 
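
The pointer-based and tensor-descriptor `_codegen_epilogue_subtile_store` paths above both lower a subtiled store to the same Triton pattern: reshape the [BLOCK_M, BLOCK_N] accumulator to [BLOCK_M, 2, BLOCK_N // 2], permute the last two axes, `tl.split` the size-2 axis, and emit two half-width stores. The standalone sketch below is hand-written to illustrate that pattern only; the kernel name, the `stride_*` arguments, and the plain-pointer addressing are illustrative and not part of this patch (bounds masks are omitted for brevity).

import triton
import triton.language as tl


@triton.jit
def copy_with_subtiled_store(x_ptr, out_ptr, stride_m, stride_n,
                             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Load one [BLOCK_M, BLOCK_N] tile of a row-major 2D tensor.
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    tile = tl.load(x_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n)

    # [BLOCK_M, BLOCK_N] -> [BLOCK_M, 2, BLOCK_N // 2] -> [BLOCK_M, BLOCK_N // 2, 2],
    # then split the trailing size-2 axis into the two column halves.
    pair = tl.reshape(tile, [BLOCK_M, 2, BLOCK_N // 2]).permute(0, 2, 1)
    lo, hi = tl.split(pair)  # lo holds columns [0, N/2), hi holds [N/2, N)

    # Two narrower stores replace the single full-width store.
    offs_n_lo = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N // 2)
    offs_n_hi = offs_n_lo + BLOCK_N // 2
    tl.store(out_ptr + offs_m[:, None] * stride_m + offs_n_lo[None, :] * stride_n, lo)
    tl.store(out_ptr + offs_m[:, None] * stride_m + offs_n_hi[None, :] * stride_n, hi)

With tensor descriptors the generated code expresses the same split as two `out_desc_1.store([offset_0, offset_1])` / `out_desc_1.store([offset_0, offset_1 + _BLOCK_SIZE_1 // 2])` calls against a descriptor whose block shape is [_BLOCK_SIZE_0, _BLOCK_SIZE_1 // 2], as the test_matmul.expected journals below show.
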
diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index 001c0e78f..be780d3ab 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -62,6 +62,7 @@ from .node_masking import getitem_masked_value from .node_masking import inductor_masked_value from .node_masking import mask_node_inputs +from .utils import _allow_epilogue_subtiling if TYPE_CHECKING: from collections.abc import Callable @@ -105,6 +106,23 @@ def prepare_graph_lowerings(graph: torch.fx.Graph) -> None: with node.meta["location"]: prepare_node_lowering(graph_lowering, node) + if _allow_epilogue_subtiling(): + from ..language import store as store_api + stores = set() + + for node in reversed(graph.nodes): + if node.op == "call_function" and node.target == store_api: + stores.add(node) + value_node = node.args[2] + # TODO (PaulZhang12): Only support multiple layers of pointwise -> store + lowering = value_node.meta.get("lowering") + if ( + isinstance(lowering, PointwiseLowering) + and len(value_node.users) == 1 + ): + value_node.meta["epilogue_subtile"] = True + node.meta["pointwise_in"] = value_node + def prepare_node_lowering( graph_lowering: GraphLowering, @@ -1397,6 +1415,12 @@ def _collect_multi_outputs( def run_node(self, n: Node) -> object: if n.op == "call_function": + + cfg = self.cg.device_function.config + # Skip codegen for nodes that can potentially be subtiled + if n.meta.get("epilogue_subtile", False): + return self.env[n.args[0]] + with self._set_current_node(n), n.meta["location"], V.set_current_node(n): try: lowering: Lowering = n.meta["lowering"] diff --git a/helion/_compiler/utils.py b/helion/_compiler/utils.py index 0992514af..762349dc5 100644 --- a/helion/_compiler/utils.py +++ b/helion/_compiler/utils.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import torch -if TYPE_CHECKING: - import torch +from .compile_environment import CompileEnvironment def compute_slice_size( @@ -29,3 +28,10 @@ def compute_slice_size( start = slice_obj.start if slice_obj.start is not None else 0 stop = slice_obj.stop if slice_obj.stop is not None else original_size return stop - start + + +def _allow_epilogue_subtiling() -> bool: + return ( + torch.cuda.get_device_capability() >= (10, 0) + and CompileEnvironment.current().settings.allow_epilogue_subtiling + ) diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index bf7cf1270..22ab1c18a 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -52,10 +52,12 @@ "pid_type", "indexing", "load_eviction_policies", + "epilogue_subtiling", ] ) VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved") VALID_EVICTION_POLICIES = ("", "first", "last") +VALID_EPILOGUE_SUBTILE_SIZES = (0, 2) @dataclasses.dataclass @@ -110,6 +112,11 @@ class ConfigSpec: EnumFragment(choices=ConfigSpec._valid_indexing_types()), length=0 ) ) + epilogue_subtiling: ListOf = dataclasses.field( + default_factory=lambda: ListOf( + EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=0 + ) + ) @staticmethod def _valid_indexing_types() -> tuple[IndexingLiteral, ...]: @@ -214,6 +221,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: "static_ranges", "load_eviction_policies", "indexing", + "epilogue_subtiling", ): if not config.get(name): config.pop(name, None) @@ -224,6 +232,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: "load_eviction_policies", 
self.load_eviction_policies.default() ) config.setdefault("indexing", self.indexing.default()) + config.setdefault("epilogue_subtiling", self.epilogue_subtiling.default()) # TODO(jansel): include num_ctas and max_nreg for name, values in (("pid_type", VALID_PID_TYPES),): @@ -297,6 +306,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "indexing": fn(self.indexing), "pid_type": fn(EnumFragment(self.allowed_pid_types)), "load_eviction_policies": fn(self.load_eviction_policies), + "epilogue_subtiling": fn(self.epilogue_subtiling), } # Add tunable parameters config.update( @@ -316,9 +326,11 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "static_ranges", "load_eviction_policies", "indexing", + "epilogue_subtiling", ): if not config.get(name): config.pop(name, None) + self.normalize(config) return helion.Config(**config) diff --git a/helion/runtime/config.py b/helion/runtime/config.py index 1d13d8884..e80c85aac 100644 --- a/helion/runtime/config.py +++ b/helion/runtime/config.py @@ -40,6 +40,7 @@ def __init__( num_stages: int | None = None, pid_type: PidTypeLiteral | None = None, indexing: IndexingLiteral | list[IndexingLiteral] | None = None, + epilogue_subtiling: list[int] | None = None, # For user-defined properties **kwargs: object, ) -> None: @@ -68,6 +69,7 @@ def __init__( indexing=["pointer", "block_ptr", "tensor_descriptor"] - Empty/omitted (all loads/stores default to "pointer") Valid strategies: "pointer", "tensor_descriptor", "block_ptr" + epilogue_subtiling: Whether to use subtiling for epilogue. **kwargs: Additional user-defined configuration parameters. """ self.config = {} @@ -88,6 +90,7 @@ def __init__( "num_stages": num_stages, "indexing": indexing, "pid_type": pid_type, + "epilogue_subtiling": epilogue_subtiling, } for key, value in core_props.items(): if value is not None: @@ -217,6 +220,10 @@ def indexing(self) -> IndexingLiteral | list[IndexingLiteral]: "IndexingLiteral | list[IndexingLiteral]", self.config.get("indexing", []) ) + @property + def epilogue_subtiling(self) -> bool: + return cast("list[int]", self.config.get("epilogue_subtiling", [])) # type: ignore[return-value] + def _to_hashable(x: object) -> object: if isinstance(x, list): diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 48aa1e97c..a56c3b7c9 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -338,6 +338,11 @@ class _Settings: _env_get_bool, "HELION_ALLOW_WARP_SPECIALIZE", True ) ) + allow_epilogue_subtiling: bool = dataclasses.field( + default_factory=functools.partial( + _env_get_bool, "HELION_ALLOW_EPILOGUE_SUBTILING", False + ) + ) debug_dtype_asserts: bool = dataclasses.field( default_factory=functools.partial( _env_get_bool, "HELION_DEBUG_DTYPE_ASSERTS", False @@ -394,6 +399,7 @@ class Settings(_Settings): "Accepts HELION_AUTOTUNE_CONFIG_OVERRIDES='{\"num_warps\":4}'." ), "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.", + "allow_epilogue_subtiling": "If True, allow epilogue subtiling on TMA stores for CUDA devices.", "debug_dtype_asserts": "If True, emit tl.static_assert checks for dtype after each device node.", "ref_mode": "Reference mode for kernel execution. 
Can be RefMode.OFF or RefMode.EAGER.", "autotuner_fn": ( diff --git a/test/test_examples.expected b/test/test_examples.expected index e665d87b0..6186d1cad 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -3784,6 +3784,279 @@ def layer_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, mean: torch.Tensor, # src[layer_norm.py:N]: return grad_x, grad_weight, None return (grad_x, grad_weight, None) +--- assertExpectedJournal(TestExamples.test_layernorm_bwd_dwdb_no_bias) +from __future__ import annotations + +import torch +import helion.language as hl +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit + # src[layer_norm.py:N]: for tile_n in hl.tile(n): +def _helion_layer_norm_bwd_dwdb(x, grad_out, mean, rstd, dw, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + # src[layer_norm.py:N]: rows = hl.arange(0, m) + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[layer_norm.py:N]: x_blk = x[rows, tile_n].to(torch.float32) + rows = tl.arange(0, 32) + load = tl.load(x + (rows[:, None] * 64 + indices_0[None, :] * 1), None) + # src[layer_norm.py:N]: dy_blk = grad_out[rows, tile_n].to(torch.float32) + v_0 = tl.cast(load, tl.float32) + load_1 = tl.load(grad_out + (rows[:, None] * 64 + indices_0[None, :] * 1), None) + # src[layer_norm.py:N]: mean_vec = mean[rows] + v_1 = tl.cast(load_1, tl.float32) + # src[layer_norm.py:N]: rstd_vec = rstd[rows] + mean_vec = tl.load(mean + rows * 1, None) + # src[layer_norm.py:N]: x_hat_blk = (x_blk - mean_vec[:, None]) * rstd_vec[:, None] + rstd_vec = tl.load(rstd + rows * 1, None) + subscript = mean_vec[:, None] + v_2 = v_0 - subscript + subscript_1 = rstd_vec[:, None] + # src[layer_norm.py:N]: dw_tile = torch.sum(dy_blk * x_hat_blk, dim=0).to(weight.dtype) + v_3 = v_2 * subscript_1 + v_4 = v_1 * v_3 + sum_1 = tl.cast(tl.sum(v_4, 0), tl.float32) + # src[layer_norm.py:N]: dw[tile_n] = dw_tile + v_5 = tl.cast(sum_1, tl.float16) + tl.store(dw + indices_0 * 1, v_5, None) + +def layer_norm_bwd_dwdb(grad_out: torch.Tensor, x: torch.Tensor, mean: torch.Tensor, rstd: torch.Tensor, weight: torch.Tensor, compute_bias_grad: hl.constexpr=True, *, _launcher=_default_launcher): + """ + Compute gradients for weight (dW) and optionally bias (dB) parameters. + + This kernel performs reduction across the batch dimension (M) to accumulate + gradients for each feature dimension's weight and bias parameters. + + Args: + grad_out: Gradient w.r.t layer norm output [M, N] + x: Original input tensor [M, N] + mean: Per-sample mean computed in forward pass [M] + rstd: Per-sample reciprocal standard deviation from forward pass [M] + weight: Weight parameter (used only for dtype/device info) [N] + compute_bias_grad: Whether to compute bias gradient (default: True) + + Returns: + (grad_weight, grad_bias): Gradients for weight and bias (if computed), both shape [N] + grad_bias is None if compute_bias_grad is False + # src[layer_norm.py:N]: m, n = x.shape + """ + # src[layer_norm.py:N]: n = hl.specialize(n) + m, n = x.shape + # src[layer_norm.py:N]: dw = torch.empty([n], dtype=weight.dtype, device=weight.device) + n = 64 + # src[layer_norm.py:N]: if compute_bias_grad: + # src[layer_norm.py:N]: db = torch.empty([n], dtype=weight.dtype, device=weight.device) + # src[layer_norm.py:N]: else: + # src[layer_norm.py:N-N]: ... 
+ dw = torch.empty([n], dtype=weight.dtype, device=weight.device) + # src[layer_norm.py:N]: db = torch.empty([n], dtype=weight.dtype, device=weight.device) + if False: + db = torch.empty([n], dtype=weight.dtype, device=weight.device) + # src[layer_norm.py:N]: db = None + else: + # src[layer_norm.py:N]: for tile_n in hl.tile(n): + db = None + _BLOCK_SIZE_0 = 32 + # src[layer_norm.py:N]: if compute_bias_grad: + # src[layer_norm.py:N]: return dw, db + _launcher(_helion_layer_norm_bwd_dwdb, (triton.cdiv(64, _BLOCK_SIZE_0),), x, grad_out, mean, rstd, dw, _BLOCK_SIZE_0, num_warps=4, num_stages=3) + # src[layer_norm.py:N]: return dw, db + if False: + # src[layer_norm.py:N]: return dw, None + return (dw, db) + return (dw, None) + +--- assertExpectedJournal(TestExamples.test_layernorm_bwd_dx) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit + # src[layer_norm.py:N]: for tile_m in hl.tile(m): +def _helion_layer_norm_bwd_dx(x, grad_out, weight, mean, rstd, grad_x, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + # src[layer_norm.py:N]: x_tile = x[tile_m, :].to(torch.float32) + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + load = tl.load(x + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None) + # src[layer_norm.py:N]: dy_tile = grad_out[tile_m, :].to(torch.float32) + v_0 = tl.cast(load, tl.float32) + load_1 = tl.load(grad_out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None) + # src[layer_norm.py:N]: w = weight[:].to(torch.float32) + v_1 = tl.cast(load_1, tl.float32) + load_2 = tl.load(weight + indices_1 * 1, None) + # src[layer_norm.py:N]: mean_tile = mean[tile_m] + v_2 = tl.cast(load_2, tl.float32) + # src[layer_norm.py:N]: rstd_tile = rstd[tile_m] + mean_tile = tl.load(mean + indices_0 * 1, None) + # src[layer_norm.py:N]: x_hat = (x_tile - mean_tile[:, None]) * rstd_tile[:, None] + rstd_tile = tl.load(rstd + indices_0 * 1, None) + subscript = mean_tile[:, None] + v_3 = v_0 - subscript + subscript_1 = rstd_tile[:, None] + # src[layer_norm.py:N]: wdy = w * dy_tile + v_4 = v_3 * subscript_1 + v_5 = v_2[None, :] + # src[layer_norm.py:N]: c1 = torch.sum(x_hat * wdy, dim=-1) / n + v_6 = v_5 * v_1 + v_7 = v_4 * v_6 + sum_1 = tl.cast(tl.sum(v_7, 1), tl.float32) + v_8 = 0.015625 + # src[layer_norm.py:N]: c2 = torch.sum(wdy, dim=-1) / n + v_9 = sum_1 * v_8 + sum_2 = tl.cast(tl.sum(v_6, 1), tl.float32) + v_10 = 0.015625 + # src[layer_norm.py:N]: dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_tile[:, None] + v_11 = sum_2 * v_10 + subscript_2 = v_9[:, None] + v_12 = v_4 * subscript_2 + subscript_3 = v_11[:, None] + v_13 = v_12 + subscript_3 + v_14 = v_6 - v_13 + subscript_4 = rstd_tile[:, None] + # src[layer_norm.py:N]: grad_x[tile_m, :] = dx.to(x.dtype) + v_15 = v_14 * subscript_4 + v_16 = tl.cast(v_15, tl.float16) + tl.store(grad_x + (indices_0[:, None] * 64 + indices_1[None, :] * 1), v_16, None) + +def layer_norm_bwd_dx(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, mean: torch.Tensor, rstd: torch.Tensor, *, _launcher=_default_launcher): + """ + Compute gradient for input tensor (dX). + + This kernel computes per-sample gradients by performing reductions across + the feature dimension (N) for each sample in the batch. 
+ + Args: + grad_out: Gradient w.r.t layer norm output [M, N] + x: Original input tensor [M, N] + weight: Weight parameter [N] + mean: Per-sample mean computed in forward pass [M] + rstd: Per-sample reciprocal standard deviation from forward pass [M] + + Returns: + grad_x: Gradient w.r.t input tensor, shape [M, N] + # src[layer_norm.py:N]: m, n = x.shape + """ + # src[layer_norm.py:N]: grad_x = torch.empty_like(x) + m, n = x.shape + # src[layer_norm.py:N]: for tile_m in hl.tile(m): +>>>>>>> 6bf2616 ([wip] Add output source annotations) + grad_x = torch.empty_like(x) + num_blocks = (x.size(0) + m_block - 1) // m_block + grad_weight_blocks = x.new_empty([num_blocks, n], dtype=torch.float32) + grad_bias_blocks = x.new_empty([num_blocks, n], dtype=torch.float32) + _BLOCK_SIZE_0 = 32 + _RDIM_SIZE_1 = 64 +<<<<<<< HEAD + _launcher(_helion_layer_norm_bwd, (triton.cdiv(32, _BLOCK_SIZE_0),), weight, x, grad_out, mean, rstd, grad_x, grad_weight_blocks, grad_bias_blocks, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + grad_weight = grad_weight_blocks.sum(0).to(weight.dtype) + if True: + grad_bias = grad_bias_blocks.sum(0).to(weight.dtype) + return (grad_x, grad_weight, grad_bias) + return (grad_x, grad_weight, None) + +--- assertExpectedJournal(TestExamples.test_layernorm_bwd_no_bias) +from __future__ import annotations + +import torch +import helion.language as hl +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_layer_norm_bwd(weight, x, grad_out, mean, rstd, grad_x, grad_weight_blocks, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_3 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + grad_w_acc = tl.full([64], 0, tl.float32) + load = tl.load(weight + indices_3[None, :] * 1, None) + v_0 = tl.cast(load, tl.float32) + tile_end = offset_0 + _BLOCK_SIZE_0 + for offset_1 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32)): + indices_1 = offset_1 + tl.arange(0, 1).to(tl.int32) + v_0_copy = v_0 + grad_w_acc_copy = grad_w_acc + v_0_copy_0 = v_0_copy + grad_w_acc_copy_0 = grad_w_acc_copy + load_1 = tl.load(x + (indices_1[:, None] * 64 + indices_3[None, :] * 1), None) + v_1 = tl.cast(load_1, tl.float32) + load_2 = tl.load(grad_out + (indices_1[:, None] * 64 + indices_3[None, :] * 1), None) + v_2 = tl.cast(load_2, tl.float32) + mean_mb = tl.load(mean + indices_1 * 1, None) + rstd_mb = tl.load(rstd + indices_1 * 1, None) + subscript = mean_mb[:, None] + v_3 = v_1 - subscript + subscript_1 = rstd_mb[:, None] + v_4 = v_3 * subscript_1 + v_5 = v_2 * v_4 + sum_1 = tl.cast(tl.sum(v_5, 0), tl.float32) + grad_w_acc = grad_w_acc_copy_0 + sum_1 + v_7 = v_0_copy_0 * v_2 + v_8 = v_4 * v_7 + sum_2 = tl.cast(tl.sum(v_8, 1), tl.float32) + v_9 = 0.015625 + v_10 = sum_2 * v_9 + sum_3 = tl.cast(tl.sum(v_7, 1), tl.float32) + v_11 = 0.015625 + v_12 = sum_3 * v_11 + subscript_2 = v_10[:, None] + v_13 = v_4 * subscript_2 + subscript_3 = v_12[:, None] + v_14 = v_13 + subscript_3 + v_15 = v_7 - v_14 + subscript_4 = rstd_mb[:, None] + v_16 = v_15 * subscript_4 + v_17 = tl.cast(v_16, tl.float16) + tl.store(grad_x + (indices_1[:, None] * 64 + indices_3[None, :] * 1), v_17, None) + tile_id = offset_0 // _BLOCK_SIZE_0 + tl.store(grad_weight_blocks + (tile_id * 64 + indices_3 * 1), grad_w_acc, None) + +def layer_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, mean: torch.Tensor, rstd: torch.Tensor, weight: torch.Tensor, compute_bias_grad: hl.constexpr=True, *, 
_launcher=_default_launcher): + """ + Compute gradients for weight (dW) and optionally bias (dB) parameters. + + This kernel performs reduction across the batch dimension (M) to accumulate + gradients for each feature dimension's weight and bias parameters. + + Args: + grad_out: Gradient w.r.t layer norm output [M, N] + x: Original input tensor [M, N] + mean: Per-sample mean computed in forward pass [M] + rstd: Per-sample reciprocal standard deviation from forward pass [M] + weight: Weight parameter (used only for dtype/device info) [N] + compute_bias_grad: Whether to compute bias gradient (default: True) + + Returns: + (grad_x, grad_weight, grad_bias): Gradients for input, weight, and bias (if computed) + grad_bias is None if compute_bias_grad is False + """ + m_block = 32 + n = 64 + grad_x = torch.empty_like(x) + num_blocks = (x.size(0) + m_block - 1) // m_block + grad_weight_blocks = x.new_empty([num_blocks, n], dtype=torch.float32) + grad_bias_blocks = x.new_empty([num_blocks, n], dtype=torch.float32) + _BLOCK_SIZE_0 = 32 + _RDIM_SIZE_1 = 64 + _launcher(_helion_layer_norm_bwd, (triton.cdiv(32, _BLOCK_SIZE_0),), weight, x, grad_out, mean, rstd, grad_x, grad_weight_blocks, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + grad_weight = grad_weight_blocks.sum(0).to(weight.dtype) + if False: + grad_bias = grad_bias_blocks.sum(0).to(weight.dtype) + return (grad_x, grad_weight, grad_bias) + return (grad_x, grad_weight, None) +======= + # src[layer_norm.py:N]: return grad_x + _launcher(_helion_layer_norm_bwd_dx, (triton.cdiv(32, _BLOCK_SIZE_0),), x, grad_out, weight, mean, rstd, grad_x, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + return grad_x + --- assertExpectedJournal(TestExamples.test_layernorm_no_bias) from __future__ import annotations @@ -6085,7 +6358,7 @@ import triton.language as tl from torch._inductor.runtime import triton_helpers from helion.runtime import default_launcher as _default_launcher -import test.test_examples as _global_source0 +import __main__ as _global_source0 @triton.jit def _helion_matmul(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): @@ -6168,7 +6441,7 @@ import triton.language as tl from torch._inductor.runtime import triton_helpers from helion.runtime import default_launcher as _default_launcher -import test.test_examples as _global_source0 +import __main__ as _global_source0 @triton.jit def _helion_matmul(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): @@ -6248,7 +6521,7 @@ import triton.language as tl from torch._inductor.runtime import triton_helpers from helion.runtime import default_launcher as _default_launcher -import test.test_examples as _global_source0 +import __main__ as _global_source0 @triton.jit def _helion_matmul(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): diff --git a/test/test_examples.py b/test/test_examples.py index 61eb0a07d..058011aae 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -269,6 +269,7 @@ def test_template_via_closure1(self): ) @patch.object(_compat, "_supports_tensor_descriptor", lambda: False) + @patch.object(helion._compiler.utils, "_allow_epilogue_subtiling", lambda: False) def test_template_via_closure2(self): args = ( torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16), diff --git a/test/test_matmul.expected b/test/test_matmul.expected index 0284e8237..772cb9a03 100644 --- 
a/test/test_matmul.expected +++ b/test/test_matmul.expected @@ -265,6 +265,384 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]] # src[matmul.py:N]: return out return out +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +# src[test_matmul.py:N]: def matmul_static_shapes_epilogue_subtiling(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: +# src[test_matmul.py:N]: m, k = x.size() +# src[test_matmul.py:N]: k2, n = y.size() +# src[test_matmul.py:N-N]: ... +helion.runtime.set_triton_allocator() + +@triton.jit +def _helion_matmul_static_shapes_epilogue_subtiling(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + x_desc = tl.make_tensor_descriptor(x, [4096, 4096], [4096, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_2]) + y_desc = tl.make_tensor_descriptor(y, [4096, 4096], [4096, 1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1]) + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + out_desc = tl.make_tensor_descriptor(out, [4096, 4096], [4096, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) + out_desc_1 = tl.make_tensor_descriptor(out, [4096, 4096], [4096, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1 // 2]) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(4096, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(4096, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 4096, _BLOCK_SIZE_2, warp_specialize=False): + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = x_desc.load([offset_0, offset_2]) + load_1 = y_desc.load([offset_2, offset_1]) + mm = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32) + v_0 = tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + acc_0 = tl.reshape(acc, [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 2]).permute(0, 2, 1) + acc_1, acc_2 = tl.split(acc_0) + v_2 = tl.cast(acc_1, tl.bfloat16) + v_3 = tl.cast(acc_2, tl.bfloat16) + out_desc_1.store([offset_0, offset_1], v_2) + out_desc_1.store([offset_0, offset_1 + _BLOCK_SIZE_1 // 2], v_3) + +def matmul_static_shapes_epilogue_subtiling(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], 
dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 64 + _BLOCK_SIZE_1 = 64 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 64 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... + _launcher(_helion_matmul_static_shapes_epilogue_subtiling, (triton.cdiv(4096, _BLOCK_SIZE_0) * triton.cdiv(4096, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + # src[test_matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile_pointer_mask_subtile_0) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_matmul_static_shapes(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(130, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(130, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 130 + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < 130 + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 130, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < 130 + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = tl.load(x + (indices_0[:, None] * 130 + indices_2[None, :] * 1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * 130 + indices_1[None, :] * 1), mask_2[:, None] & mask_1[None, :], other=0) + mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) + v_0 = tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + v_2 = tl.cast(acc, tl.bfloat16) + tl.store(out + (indices_0[:, None] * 130 + indices_1[None, :] * 1), v_2, mask_0[:, None] & mask_1[None, :]) + +def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != 
{k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... + _launcher(_helion_matmul_static_shapes, (triton.cdiv(130, _BLOCK_SIZE_0) * triton.cdiv(130, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile_pointer_mask_subtile_2) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_matmul_static_shapes(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(130, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(130, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 130 + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < 130 + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 130, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < 130 + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = tl.load(x + (indices_0[:, None] * 130 + indices_2[None, :] * 1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * 130 + indices_1[None, :] * 1), mask_2[:, None] & mask_1[None, :], other=0) + mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) + v_0 = tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + acc_0 = tl.reshape(acc, [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 2]).permute(0, 2, 1) + acc_1, acc_2 = tl.split(acc_0) + v_2 = tl.cast(acc_1, tl.bfloat16) + v_3 = tl.cast(acc_2, tl.bfloat16) + indices_n_0 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + indices_n_1 = (offset_1 + _BLOCK_SIZE_1 // 2 + 
tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + mask_n_0 = indices_n_0 < 130 + mask_n_1 = indices_n_1 < 130 + tl.store(out + (indices_0[:, None] * 130 + indices_n_0[None, :] * 1), v_2, mask_0[:, None] & mask_n_0[None, :]) + tl.store(out + (indices_0[:, None] * 130 + indices_n_1[None, :] * 1), v_3, mask_0[:, None] & mask_n_1[None, :]) + +def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... + _launcher(_helion_matmul_static_shapes, (triton.cdiv(130, _BLOCK_SIZE_0) * triton.cdiv(130, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile_subtile_0) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +# src[test_matmul.py:N]: def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: +# src[test_matmul.py:N]: m, k = x.size() +# src[test_matmul.py:N]: k2, n = y.size() +# src[test_matmul.py:N-N]: ... 
+helion.runtime.set_triton_allocator() + +@triton.jit +def _helion_matmul_static_shapes(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + x_desc = tl.make_tensor_descriptor(x, [128, 128], [128, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_2]) + y_desc = tl.make_tensor_descriptor(y, [128, 128], [128, 1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1]) + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + out_desc = tl.make_tensor_descriptor(out, [128, 128], [128, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(128, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(128, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = x_desc.load([offset_0, offset_2]) + load_1 = y_desc.load([offset_2, offset_1]) + mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) + v_0 = tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + v_2 = tl.cast(acc, tl.bfloat16) + out_desc.store([offset_0, offset_1], v_2) + +def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... 
+ _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile_subtile_2) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_matmul_static_shapes(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(128, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(128, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = tl.load(x + (indices_0[:, None] * 128 + indices_2[None, :] * 1), None) + load_1 = tl.load(y + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None) + mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) + v_0 = tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + acc_0 = tl.reshape(acc, [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 2]).permute(0, 2, 1) + acc_1, acc_2 = tl.split(acc_0) + v_2 = tl.cast(acc_1, tl.bfloat16) + v_3 = tl.cast(acc_2, tl.bfloat16) + indices_n_0 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + indices_n_1 = (offset_1 + _BLOCK_SIZE_1 // 2 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + tl.store(out + (indices_0[:, None] * 128 + indices_n_0[None, :] * 1), v_2, None) + tl.store(out + (indices_0[:, None] * 128 + indices_n_1[None, :] * 1), v_3, None) + +def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # 
src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... + _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_matmul.py:N]: return out + return out + --- assertExpectedJournal(TestMatmul.test_matmul_packed_rhs) from __future__ import annotations diff --git a/test/test_matmul.py b/test/test_matmul.py index 026485ec3..c4db95906 100644 --- a/test/test_matmul.py +++ b/test/test_matmul.py @@ -5,6 +5,8 @@ from unittest.mock import patch import torch +from torch.testing._internal.common_utils import instantiate_parametrized_tests +from torch.testing._internal.common_utils import parametrize import helion from helion import Config @@ -55,20 +57,23 @@ def matmul_without_addmm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out -@helion.kernel(static_shapes=True) -def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - m, k = x.size() - k2, n = y.size() - assert k == k2, f"size mismatch {k} != {k2}" - out = torch.empty( - [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device - ) - for tile_m, tile_n in hl.tile([m, n]): - acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) - for tile_k in hl.tile(k): - acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) - out[tile_m, tile_n] = acc - return out +def matmul_static_shapes_wrapper(epilogue_subtiling: bool = False): + @helion.kernel(static_shapes=True, allow_epilogue_subtiling=epilogue_subtiling) + def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + k2, n = y.size() + assert k == k2, f"size mismatch {k} != {k2}" + out = torch.empty( + [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + ) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc + return out + + return matmul_static_shapes class TestMatmul(RefEagerTestBase, TestCase): @@ -152,7 +157,7 @@ def test_matmul_static_shapes0(self): torch.randn([128, 128], device=DEVICE, dtype=torch.float32), ) code, output = code_and_output( - matmul_static_shapes, + matmul_static_shapes_wrapper(), args, block_sizes=[16, 16, 16], l2_grouping=4, @@ -167,7 +172,7 @@ def test_matmul_static_shapes1(self): torch.randn([128, 128], device=DEVICE, dtype=torch.float32), ) code, output = code_and_output( - matmul_static_shapes, + matmul_static_shapes_wrapper(), args, block_sizes=[16, 16, 16], l2_grouping=4, @@ -181,7 +186,7 @@ def test_matmul_static_shapes2(self): torch.randn([127, 128], device=DEVICE, dtype=torch.float32), ) code, output = code_and_output( - matmul_static_shapes, + matmul_static_shapes_wrapper(), args, block_sizes=[16, 16, 16], l2_grouping=4, @@ -195,7 +200,7 @@ def test_matmul_static_shapes3(self): torch.randn([128, 127], device=DEVICE, dtype=torch.float32), ) code, output = code_and_output( - matmul_static_shapes, + matmul_static_shapes_wrapper(), args, block_sizes=[16, 16, 16], l2_grouping=4, @@ -339,6 +344,50 @@ def matmul_with_packed_b( torch.testing.assert_close(C, expected, atol=5e-2, rtol=1e-3) 
self.assertExpectedJournal(code) + @unittest.skipIf( + DEVICE.type != "cuda" or torch.cuda.get_device_capability() < (10, 0), + "Epilogue Subtiling requires CUDA compute capability >= 10.0", + ) + @parametrize("subtile", (0, 2)) + def test_matmul_epilogue_subtile_tensor_descriptor(self, subtile: int): + args = ( + torch.randn([128, 128], device=DEVICE, dtype=torch.bfloat16), + torch.randn([128, 128], device=DEVICE, dtype=torch.bfloat16), + ) + code, output = code_and_output( + matmul_static_shapes_wrapper(epilogue_subtiling=True), + args, + block_sizes=[32, 32, 32], + l2_grouping=4, + indexing="tensor_descriptor", + epilogue_subtiling=[subtile], + ) + torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2) + self.assertExpectedJournal(code) + + @unittest.skipIf( + DEVICE.type != "cuda" or torch.cuda.get_device_capability() < (10, 0), + "Epilogue Subtiling requires CUDA compute capability >= 10.0", + ) + @parametrize("subtile", (0, 2)) + def test_matmul_epilogue_subtile_pointer_mask(self, subtile: int): + args = ( + torch.randn([130, 130], device=DEVICE, dtype=torch.bfloat16), + torch.randn([130, 130], device=DEVICE, dtype=torch.bfloat16), + ) + code, output = code_and_output( + matmul_static_shapes_wrapper(epilogue_subtiling=True), + args, + block_sizes=[32, 32, 32], + l2_grouping=4, + indexing="pointer", + epilogue_subtiling=[subtile], + ) + torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2) + self.assertExpectedJournal(code) + + +instantiate_parametrized_tests(TestMatmul) if __name__ == "__main__": unittest.main() diff --git a/test/test_type_propagation.expected b/test/test_type_propagation.expected index 2b60ae525..fe00193c6 100644 --- a/test/test_type_propagation.expected +++ b/test/test_type_propagation.expected @@ -1025,35 +1025,35 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]] Tensor: Resulting matrix of shape [m, n]. 
""" m, k = - # Call: SequenceType((LiteralType(512), LiteralType(512))) SourceOrigin(location=) + # Call: SequenceType((LiteralType(512), LiteralType(512))) SourceOrigin(location=) # Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size') # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='x') x.size() k2, n = - # Call: SequenceType((LiteralType(512), LiteralType(512))) SourceOrigin(location=) + # Call: SequenceType((LiteralType(512), LiteralType(512))) SourceOrigin(location=) # Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='y'), key='size') # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='y') y.size() assert - # Compare: LiteralType(True) SourceOrigin(location=) - # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1) + # Compare: LiteralType(True) SourceOrigin(location=) + # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1) k == - # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0) + # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0) k2, - # JoinedStr: str SourceOrigin(location=) + # JoinedStr: str SourceOrigin(location=) f'size mismatch {k} != {k2}' out = - # Call: TensorType([512, 512], torch.float32) SourceOrigin(location=) + # Call: TensorType([512, 512], torch.float32) SourceOrigin(location=) # Attribute: CallableType(_VariableFunctionsClass.empty) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') torch.empty( - # List: SequenceType([LiteralType(512), LiteralType(512)]) SourceOrigin(location=) + # List: SequenceType([LiteralType(512), LiteralType(512)]) SourceOrigin(location=) [ - # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0) + # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0) m, - # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1) + # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1) n], dtype= - # Call: LiteralType(torch.float32) SourceOrigin(location=) + # Call: LiteralType(torch.float32) SourceOrigin(location=) # Attribute: CallableType(_VariableFunctionsClass.promote_types) AttributeOrigin(value=GlobalOrigin(name='torch'), key='promote_types') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') torch.promote_types( @@ -1069,26 +1069,26 @@ x.device) # For: loop_type=GRID for tile_m, tile_n in - # Call: IterType(SequenceType((TileIndexType(0), TileIndexType(1)))) SourceOrigin(location=) + # Call: IterType(SequenceType((TileIndexType(0), TileIndexType(1)))) SourceOrigin(location=) # Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile') # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl') hl.tile( - # List: SequenceType([LiteralType(512), LiteralType(512)]) SourceOrigin(location=) + # List: SequenceType([LiteralType(512), LiteralType(512)]) SourceOrigin(location=) [ - # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0) + # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0) m, - # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1) + # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1) n]): acc = - # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) + # Call: TensorType([block_size_0, block_size_1], torch.float32) 
DeviceOrigin(location=) # Attribute: CallableType(zeros) AttributeOrigin(value=GlobalOrigin(name='hl'), key='zeros') # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl') hl.zeros( - # List: SequenceType([TileIndexType(0), TileIndexType(1)]) DeviceOrigin(location=) + # List: SequenceType([TileIndexType(0), TileIndexType(1)]) DeviceOrigin(location=) [ - # Name: TileIndexType(0) SourceOrigin(location=) + # Name: TileIndexType(0) SourceOrigin(location=) tile_m, - # Name: TileIndexType(1) SourceOrigin(location=) + # Name: TileIndexType(1) SourceOrigin(location=) tile_n], dtype= # Attribute: LiteralType(torch.float32) AttributeOrigin(value=GlobalOrigin(name='torch'), key='float32') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') @@ -1096,61 +1096,61 @@ torch.float32) # For: loop_type=DEVICE for tile_k in - # Call: IterType(TileIndexType(2)) DeviceOrigin(location=) + # Call: IterType(TileIndexType(2)) DeviceOrigin(location=) # Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile') # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl') hl.tile( - # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1) + # Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1) k): acc = - # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) + # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) # Attribute: CallableType(_VariableFunctionsClass.addmm) AttributeOrigin(value=GlobalOrigin(name='torch'), key='addmm') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') torch.addmm( - # Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) + # Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) acc, - # Subscript: TensorType([block_size_0, block_size_2], torch.float32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_2], torch.float32) DeviceOrigin(location=) # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='x') x[ - # Name: TileIndexType(0) SourceOrigin(location=) + # Name: TileIndexType(0) SourceOrigin(location=) tile_m, - # Name: TileIndexType(2) DeviceOrigin(location=) + # Name: TileIndexType(2) DeviceOrigin(location=) tile_k], - # Subscript: TensorType([block_size_2, block_size_1], torch.float32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_2, block_size_1], torch.float32) DeviceOrigin(location=) # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='y') y[ - # Name: TileIndexType(2) DeviceOrigin(location=) + # Name: TileIndexType(2) DeviceOrigin(location=) tile_k, - # Name: TileIndexType(1) SourceOrigin(location=) + # Name: TileIndexType(1) SourceOrigin(location=) tile_n]) - # Subscript: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) - # Name: TensorType([512, 512], torch.float32) SourceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) + # Name: TensorType([512, 512], torch.float32) SourceOrigin(location=) out[ - # Name: TileIndexType(0) SourceOrigin(location=) + # Name: TileIndexType(0) SourceOrigin(location=) tile_m, - # Name: TileIndexType(1) SourceOrigin(location=) + # Name: TileIndexType(1) SourceOrigin(location=) tile_n] = - # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) + # Call: TensorType([block_size_0, block_size_1], torch.float32) 
DeviceOrigin(location=) # Name: CallableType() ArgumentOrigin(name='epilogue') epilogue( - # Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) + # Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) acc, - # Tuple: SequenceType((TileIndexType(0), TileIndexType(1))) DeviceOrigin(location=) + # Tuple: SequenceType((TileIndexType(0), TileIndexType(1))) DeviceOrigin(location=) ( - # Name: TileIndexType(0) SourceOrigin(location=) + # Name: TileIndexType(0) SourceOrigin(location=) tile_m, - # Name: TileIndexType(1) SourceOrigin(location=) + # Name: TileIndexType(1) SourceOrigin(location=) tile_n)) return - # Name: TensorType([512, 512], torch.float32) SourceOrigin(location=) + # Name: TensorType([512, 512], torch.float32) SourceOrigin(location=) out def for_loop_0(arg0_1: "f32[u0, u1]"): - # File: .../matmul.py:61 in matmul, code: for tile_k in hl.tile(k): + # File: .../matmul.py:62 in matmul, code: for tile_k in hl.tile(k): _new_var: "f32[u0, u1]" = helion_language__tracing_ops__new_var(arg0_1) - # File: .../matmul.py:62 in matmul, code: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + # File: .../matmul.py:63 in matmul, code: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) x: "f32[512, 512]" = helion_language__tracing_ops__host_tensor('x') sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(arg0_1, 0) block_size_2: "Sym(u2)" = helion_language__tracing_ops__get_symnode('block_size_2') @@ -1162,17 +1162,17 @@ def for_loop_0(arg0_1: "f32[u0, u1]"): return [acc] def root_graph_1(): - # File: .../matmul.py:60 in matmul, code: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # File: .../matmul.py:61 in matmul, code: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0') block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1') acc: "f32[u0, u1]" = helion_language_creation_ops_full([block_size_0, block_size_1], 0.0, torch.float32, None) - # File: .../matmul.py:61 in matmul, code: for tile_k in hl.tile(k): + # File: .../matmul.py:62 in matmul, code: for tile_k in hl.tile(k): _for_loop = helion_language__tracing_ops__for_loop(0, [0], [512], [acc]) getitem: "f32[u0, u1]" = _for_loop[0]; _for_loop = None _phi: "f32[u0, u1]" = helion_language__tracing_ops__phi(acc, getitem); acc = getitem = None - # File: .../matmul.py:63 in matmul, code: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) + # File: .../matmul.py:64 in matmul, code: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) out: "f32[512, 512]" = helion_language__tracing_ops__host_tensor('out') store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], _phi, None); out = block_size_0 = block_size_1 = _phi = store = None return None
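
Note (illustrative, not part of the patch): the epilogue-subtiling stores in the expected outputs above all follow one pattern — the [_BLOCK_SIZE_0, _BLOCK_SIZE_1] accumulator is reshaped to [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 2], permuted, split with tl.split into two contiguous column halves, and each half is cast and stored with its own column offsets (masked for the 130x130 pointer case, unmasked for 128x128). The standalone sketch below mirrors only that store pattern; the kernel and variable names (_epilogue_subtile_store_demo, acc_ptr, out_ptr, stride) are hypothetical and chosen here for illustration.

import torch
import triton
import triton.language as tl


@triton.jit
def _epilogue_subtile_store_demo(acc_ptr, out_ptr, stride, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    rows = tl.arange(0, BLOCK_M)
    cols = tl.arange(0, BLOCK_N)
    # Stand-in for the matmul accumulator: load one full [BLOCK_M, BLOCK_N] tile.
    acc = tl.load(acc_ptr + rows[:, None] * stride + cols[None, :])
    # Same transformation as the generated code: [M, N] -> [M, 2, N // 2] -> [M, N // 2, 2],
    # then tl.split along the trailing dimension of size 2 yields the left/right column halves.
    acc_3d = tl.reshape(acc, [BLOCK_M, 2, BLOCK_N // 2]).permute(0, 2, 1)
    left, right = tl.split(acc_3d)
    half = tl.arange(0, BLOCK_N // 2)
    half_shifted = BLOCK_N // 2 + half
    # Store each half separately, offsetting the right half by BLOCK_N // 2 columns.
    tl.store(out_ptr + rows[:, None] * stride + half[None, :], left)
    tl.store(out_ptr + rows[:, None] * stride + half_shifted[None, :], right)


if __name__ == "__main__" and torch.cuda.is_available():
    # Single-tile round trip: the two half-stores should reproduce the input tile.
    BLOCK_M = BLOCK_N = 32
    acc = torch.randn(BLOCK_M, BLOCK_N, device="cuda", dtype=torch.float32)
    out = torch.empty_like(acc)
    _epilogue_subtile_store_demo[(1,)](acc, out, BLOCK_N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    torch.testing.assert_close(out, acc)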