diff --git a/examples/matmul.py b/examples/matmul.py index b3c3ca4d8..bc45b17cd 100644 --- a/examples/matmul.py +++ b/examples/matmul.py @@ -28,12 +28,13 @@ @helion.kernel( # static_shapes=True gives a performance boost for matmuls static_shapes=True, - # Disable autotung over unrolling/range_num_stages + # Disable autotuning over range_num_stages # tl.dot is pipelined with num_stages autotune_config_overrides={ "range_unroll_factors": [0, 0], "range_num_stages": [0, 0], }, + allow_epilogue_subtiling=True, ) def matmul( x: Tensor, diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 6cec94000..ffb929e93 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -462,8 +462,9 @@ def tensor_descriptor_arg( self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt] ) -> TensorDescriptorArg: host_function = HostFunction.current() - block_size_expr = ", ".join(map(self.literal_expr, block_size)) + block_size_expr = ", ".join(self.literal_expr(dim) for dim in block_size) key = (fake_value, block_size_expr) + if key not in self._tensor_descriptor_args: origin = host_function.tensor_to_origin[fake_value] desc_name = self.new_var(origin.suggest_var_name() + "_desc") @@ -556,22 +557,6 @@ def _format_constexpr_value(self, value: object) -> str: if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)): value = value._sympy_() - # Handle sympy expressions (sanitize by replacing triton_helpers functions) - if isinstance(value, sympy.Expr): - sanitized = value.replace( # pyright: ignore[reportAttributeAccessIssue] - lambda node: isinstance(node, sympy.Function) - and getattr(node.func, "__name__", "") - == "triton_helpers.div_floor_integer", - lambda node: sympy.floor(node.args[0] / node.args[1]), # pyright: ignore[reportAttributeAccessIssue] - ).replace( # pyright: ignore[reportAttributeAccessIssue] - lambda node: isinstance(node, sympy.Function) - and getattr(node.func, "__name__", "") - == "triton_helpers.remainder_integer", - lambda node: sympy.Mod(node.args[0], node.args[1]), # pyright: ignore[reportAttributeAccessIssue] - ) - expr = cast("sympy.Expr", sanitized) - return HostFunction.current().sympy_expr(expr) - return HostFunction.current().literal_expr(value) def _tensor_property( @@ -749,11 +734,19 @@ def current() -> DeviceFunction: class HelionTritonPrinter(TritonPrinter): - """Custom Triton printer that avoids wrapping float literals in tl.full(). - - Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value - via tl.full([], , tl.float64). We override this to emit the raw numeric - literal, letting downstream type promotion and casts handle dtype. + """Custom Triton printer that does the following: + + - Avoids wrapping float literals in tl.full(). + Inductor's default TritonPrinter prints SymPy Float as a 0-D Triton value + via tl.full([], , tl.float64). We override this to emit the raw numeric + literal, letting downstream type promotion and casts handle dtype. + + - Avoids triton_helpers.div_floor_integer(...) calls when both operands are + provably non-negative integers. TritonPrinter by default converts + floor(u1/2) to triton_helpers.div_floor_integer(...). We override this to + emit u1 // 2 only when the numerator is known to be non-negative and the + denominator is a positive integer, so that we keep helper calls for cases + that rely on floor semantics with mixed signs. 
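+
+      For example, Helion block-size symbols are known to be non-negative, so
+      floor(_BLOCK_SIZE_1 / 2) is emitted as `_BLOCK_SIZE_1 // 2` rather than
+      `triton_helpers.div_floor_integer(_BLOCK_SIZE_1, 2)`; expressions whose
+      sign is unknown still fall back to the helper call.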
""" def _print_Float(self, expr: sympy.Expr) -> str: @@ -762,6 +755,59 @@ def _print_Float(self, expr: sympy.Expr) -> str: def _print_ToFloat(self, expr: sympy.Expr) -> str: return f"{expr} + 0.0" + def _is_nonnegative(self, expr: sympy.Expr) -> bool: + if expr.is_nonnegative is True or expr.is_zero is True: + return True + if expr.is_positive is True: + return True + try: + host_fn = HostFunction.current() + except NoCurrentFunction: + host_fn = None + if host_fn is not None: + origin_info = host_fn.expr_to_origin.get(expr) + if origin_info and isinstance( + origin_info.origin, (BlockSizeOrigin, TensorSizeOrigin) + ): + return True + if isinstance(expr, sympy.Symbol) and expr.name.startswith("_BLOCK_SIZE_"): + return True + if isinstance(expr, sympy.Number): + return bool(expr >= 0) + return False + + def _format_trunc_div(self, lhs: sympy.Expr, rhs: sympy.Expr) -> str: + lhs_str = self._print(lhs) + rhs_str = self._print(rhs) + if not (lhs.is_Integer or lhs.is_Symbol): + lhs_str = f"({lhs_str})" + if not (rhs.is_Integer or rhs.is_Symbol): + rhs_str = f"({rhs_str})" + return f"{lhs_str} // {rhs_str}" + + def _print_floor(self, expr: sympy.Expr) -> str: + inner = expr.args[0] + numer, denom = inner.as_numer_denom() # pyright: ignore[reportAttributeAccessIssue] + if ( + isinstance(denom, sympy.Integer) + and denom > 1 + and isinstance(numer, sympy.Expr) + and self._is_nonnegative(numer) + ): + return self._format_trunc_div(numer, denom) + return super()._print_floor(expr) + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + lhs, rhs = expr.args + if ( + isinstance(rhs, sympy.Integer) + and rhs > 0 + and isinstance(lhs, sympy.Expr) + and self._is_nonnegative(lhs) + ): + return self._format_trunc_div(lhs, rhs) + return super()._print_FloorDiv(expr) + def texpr(expr: sympy.Expr) -> str: return HelionTritonPrinter().doprint(expr) diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index bb119db86..e529e1260 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -63,6 +63,7 @@ from .type_propagation import _eval_binary from .type_propagation import _eval_compare from .type_propagation import _eval_unary +from .utils import _use_epilogue_subtile if TYPE_CHECKING: from collections.abc import Callable @@ -1191,6 +1192,11 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: total_load_count, loads_without_eviction_policy, store_count ) + if _use_epilogue_subtile(): + for graph in device_ir.graphs: + # Epilogue subtiling only for Blackwell + epilogue_subtiling_pass(graph.graph, store_count) + return device_ir @@ -1348,3 +1354,69 @@ def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None: user.args = tuple(new_args) if len(node.users) == 0: graph.erase_node(node) + + +def epilogue_subtiling_pass(graph: torch.fx.Graph, store_count: int) -> None: + """ + Replace epilogue subtile with a tunable value. 
+ """ + if store_count == 0: + return + + from ..autotuner.config_fragment import EnumFragment + from ..autotuner.config_fragment import ListOf + from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES + from .inductor_lowering import PointwiseLowering + + env = CompileEnvironment.current() + # Register a tunable for epilogue subtile for all device stores + fragment = ListOf( + EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=store_count + ) + env.config_spec.epilogue_subtiling = fragment + + def collect_pointwise_epilogue_nodes( + store_node: torch.fx.Node, + ) -> dict[torch.fx.Node, None]: + """Recursively collect all pointwise nodes that can be subtiled in the epilogue. + + Starting from a store node, traverse backwards through all input nodes, + collecting pointwise operations until we hit non-pointwise nodes. + Only include pointwise nodes that have a single user to ensure they can be fused. + """ + # dict to preserve order + pointwise_nodes = {} + visited = set() + stack = [store_node.args[2]] # Start with the value being stored + + while stack: + current = stack.pop() + if current in visited or not isinstance(current, torch.fx.Node): + continue + + visited.add(current) + + lowering = current.meta.get("lowering") + # Check if this is a pointwise operation with only one user + if isinstance(lowering, PointwiseLowering) and len(current.users) == 1: + if current not in pointwise_nodes: + pointwise_nodes[current] = None + stack.extend(current.all_input_nodes) + + return pointwise_nodes + + from ..language import store as store_api + + stores = set() + + for node in graph.nodes: + if node.op == "call_function" and node.target == store_api: + stores.add(node) + # Collect all pointwise nodes that can be subtiled in the epilogue + pointwise_nodes = collect_pointwise_epilogue_nodes(node) + if pointwise_nodes: + # Mark all collected pointwise nodes for epilogue subtiling + for pw_node in pointwise_nodes: + pw_node.meta["epilogue_subtile"] = True + # Store the set of pointwise nodes in the store node's metadata + node.meta["pointwise_epilogue_nodes"] = pointwise_nodes diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index b6c8eaa66..23fc0ced6 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -15,10 +15,12 @@ from .. 
import exc from .._compat import get_tensor_descriptor_fn_name from .ast_extension import expr_from_string +from .ast_extension import statement_from_string from .compile_environment import CompileEnvironment from .device_function import DeviceFunction from .host_function import HostFunction from .tile_strategy import DeviceLoopState +from .utils import _use_epilogue_subtile from .utils import compute_slice_size from .variable_origin import BlockSizeOrigin @@ -104,6 +106,143 @@ def _get_tile_with_offset_info( return None +# Codegen pointwise ops to sub-tiles/tile +def _apply_pointwise_to_subtile( + state: CodegenState, tile_values: list[ast.AST] +) -> list[ast.AST]: + from torch._inductor import ir + + from .inductor_lowering import PointwiseLowering + from .inductor_lowering import install_inductor_kernel_handlers + + if not ( + isinstance(state.fx_node, torch.fx.Node) + and "pointwise_epilogue_nodes" in state.fx_node.meta + ): + return tile_values + + pointwise_nodes = list(reversed(state.fx_node.meta["pointwise_epilogue_nodes"])) + for pw_node in pointwise_nodes: + lowering = pw_node.meta["lowering"] + assert isinstance(lowering, PointwiseLowering) + + buffer = lowering.buffer + assert isinstance(buffer.data, ir.Pointwise) + + for i, tile in enumerate(tile_values): + codegen = state.codegen + subtile_var = codegen.lift(tile, prefix="subtile") + + with install_inductor_kernel_handlers( + codegen, dict.fromkeys(lowering.input_names, subtile_var) + ): + # Generate the pointwise operation + indices = [ + sympy.Symbol(f"i{n}") for n in range(len(buffer.data.ranges)) + ] + from .inductor_lowering import _unpack_opsvalue + + result_name = _unpack_opsvalue(buffer.data.inner_fn(indices)) + tile_values[i] = expr_from_string(result_name) + + return tile_values + + +# Get sub-tile size from autotune config +def _get_subtile_split(state: CodegenState) -> int: + epilogue_subtiles = state.config.epilogue_subtiling + idx: int = int(state.device_function.device_store_index) + if idx > len(epilogue_subtiles): + return -1 + + return epilogue_subtiles[idx - 1] + + +# Common func for output shape of tensor descriptor/subscript indexing +# for 2D tensors +def _get_output_shape( + indexing: BlockedSubscriptIndexing | SubscriptIndexing, + state: CodegenState, + fake_tensor: torch.Tensor | None = None, + subscript: list[object] | None = None, +) -> list[int | torch.SymInt]: + if isinstance(indexing, SubscriptIndexing): + assert fake_tensor is not None and subscript is not None + # Pointer Indexing + output_shape = SubscriptIndexing.compute_shape(fake_tensor, subscript, state) + block_m, block_n = output_shape + else: + assert isinstance(indexing, BlockedSubscriptIndexing) + output_shape = indexing.block_shape + + return output_shape + + +def _can_epilogue_subtile_with_output_shape( + output_shape: list[int | torch.SymInt], +) -> bool: + env = CompileEnvironment.current() + config = DeviceFunction.current().config + + if len(output_shape) != 2: + return False + + block_m, block_n = output_shape + block_n_hint = env.size_hint(block_n) + block_idx = env.get_block_id(block_n) + if not block_idx: + return False + block_size = env.block_sizes[block_idx].from_config(config) + + if not block_size: + return False + + return not (block_n_hint % 2 != 0 or block_size <= 16) + + +def _get_accumulator_subtiles( + state: CodegenState, + store_value: ast.AST, + block_m_str: str, + block_n_half_str: str, + fake_tensor: torch.Tensor | None = None, + subscript: list[object] | None = None, +) -> tuple[ast.AST, ast.AST]: + # Get 
the output shape from SubscriptIndexing + + codegen = state.codegen + + block_n_half_expr = expr_from_string(block_n_half_str) + + # Lift the store value into a temporary variable for reuse + acc_var = codegen.lift(store_value, prefix="acc") + + # Reshape and split the accumulator + reshape_expr = expr_from_string( + "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)", + acc=acc_var, + dim_m=expr_from_string(block_m_str), + dim_half=block_n_half_expr, + ) + reshape_var = codegen.lift(reshape_expr, prefix="acc") + + acc0_name = codegen.tmpvar(prefix="acc") + acc1_name = codegen.tmpvar(prefix="acc") + codegen.add_statement( + statement_from_string( + f"{acc0_name}, {acc1_name} = tl.split({{acc}})", + acc=reshape_var, + ) + ) + + # Apply pointwise operations to each subtile if present + acc0 = expr_from_string(acc0_name) + acc1 = expr_from_string(acc1_name) + acc0, acc1 = _apply_pointwise_to_subtile(state, [acc0, acc1]) + + return acc0, acc1 + + class IndexingStrategy: def codegen_load( self, @@ -125,6 +264,17 @@ def codegen_store( ) -> ast.AST: raise NotImplementedError + def codegen_store_subtile( + self, + state: CodegenState, + fake_tensor: torch.Tensor, + indexing: BlockedSubscriptIndexing | SubscriptIndexing, + store_value: ast.AST, + extra_mask: ast.AST | None, + subscript: list[object] | None, + ) -> ast.AST | None: + raise NotImplementedError + @staticmethod def select(indexing_literal: IndexingLiteral) -> IndexingStrategy: if indexing_literal == "pointer": @@ -178,6 +328,23 @@ def codegen_store( ) -> ast.AST: indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask) name = state.device_function.tensor_arg(fake_tensor).name + + # Try epilogue subtiling if enabled + if _use_epilogue_subtile(): + subtile_value = self.codegen_store_subtile( + state, + fake_tensor, + indexing, + value, + extra_mask=extra_mask, + subscript=subscript, + ) + + if subtile_value: + return subtile_value + + (value,) = _apply_pointwise_to_subtile(state, [value]) + return expr_from_string( f"tl.store({name} + {{offset}}, {{value}}, {{mask}})", value=value, @@ -185,6 +352,148 @@ def codegen_store( mask=indexing.mask_expr, ) + def codegen_store_subtile( + self, + state: CodegenState, + fake_tensor: torch.Tensor, + indexing: BlockedSubscriptIndexing | SubscriptIndexing, + store_value: ast.AST, + extra_mask: ast.AST | None, + subscript: list[object] | None, + ) -> ast.AST | None: + """Generate epilogue subtiling for pointer-based stores. + + This splits the store value and offsets to perform multiple smaller stores, + which can improve performance by reducing TMA overhead. 
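+
+        Returns None when the subtile size chosen for this store is 0 or the
+        output shape is not a supported 2D tile; the caller then falls back to
+        a single tl.store. With a subtile split of 2 the generated code is
+        roughly (names illustrative)::
+
+            acc = tl.reshape(value, [BLOCK_M, 2, BLOCK_N // 2]).permute(0, 2, 1)
+            acc0, acc1 = tl.split(acc)
+            tl.store(ptr + offsets_lo, acc0, mask_lo)
+            tl.store(ptr + offsets_hi, acc1, mask_hi)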
+ """ + assert isinstance(indexing, SubscriptIndexing) + + subtile_split = _get_subtile_split(state) + env = CompileEnvironment.current() + codegen = state.codegen + device_fn = state.device_function + output_shape = _get_output_shape(indexing, state, fake_tensor, subscript) + + if subtile_split <= 0 or not _can_epilogue_subtile_with_output_shape( + output_shape + ): + return None + + block_m, block_n = output_shape + + block_idx: int | None = env.get_block_id(block_n) + block_idx_m = env.get_block_id(block_m) + + if block_idx is None or block_idx_m is None: + return None + + block_n_str = device_fn.literal_expr(block_n) + block_n_half_str = f"({block_n_str} // {subtile_split})" + block_m_str = device_fn.literal_expr(block_m) + + acc0, acc1 = _get_accumulator_subtiles( + state, + store_value, + block_m_str=block_m_str, + block_n_half_str=block_n_half_str, + fake_tensor=fake_tensor, + subscript=subscript, + ) + + name = state.device_function.tensor_arg(fake_tensor).name + + # Generate sliced index variables for N dimension + # Get the index variable for the N dimension (block_idx) + offset_n_var = codegen.offset_var(block_idx) + + # Create sliced indices for each subtile + # First subtile: indices_n[:block_n_half] + index_n_0_name = codegen.tmpvar(prefix="indices_n") + codegen.add_statement( + statement_from_string( + f"{index_n_0_name} = ({offset_n_var} + tl.arange(0, {block_n_half_str})).to(tl.int32)" + ) + ) + + # Second subtile: indices_n[block_n_half:] + index_n_1_name = codegen.tmpvar(prefix="indices_n") + codegen.add_statement( + statement_from_string( + f"{index_n_1_name} = ({offset_n_var} + {block_n_half_str} + tl.arange(0, {block_n_half_str})).to(tl.int32)" + ) + ) + + # Reconstruct the offset expressions for each subtile + # We need to replace the N dimension index with the sliced versions + stride_n = state.device_function.tensor_stride(fake_tensor, -1).name + stride_m = state.device_function.tensor_stride(fake_tensor, -2).name + index_m_var = codegen.index_var(block_idx_m) + + # Build offset for first subtile + offset_0 = expr_from_string( + f"{index_m_var}[:, None] * {stride_m} + {index_n_0_name}[None, :] * {stride_n}" + ) + + # Build offset for second subtile (note: need to add block_n_half to base for second half) + offset_1 = expr_from_string( + f"{index_m_var}[:, None] * {stride_m} + {index_n_1_name}[None, :] * {stride_n}" + ) + + # Generate masks for each subtile if masking is needed + mask_0 = indexing.mask_expr + mask_1 = indexing.mask_expr + + if indexing.has_mask(): + # Need to slice the mask as well for N dimension + mask_n_var = codegen.mask_var(block_idx) + if mask_n_var is not None: + # Original mask structure: mask_m[:, None] & mask_n[None, :] + # Need to slice mask_n for each subtile + mask_n_0_name = codegen.tmpvar(prefix="mask_n") + mask_n_1_name = codegen.tmpvar(prefix="mask_n") + + codegen.add_statement( + statement_from_string( + f"{mask_n_0_name} = {index_n_0_name} < {stride_m}" + ) + ) + codegen.add_statement( + statement_from_string( + f"{mask_n_1_name} = {index_n_1_name} < {stride_m}" + ) + ) + + # Reconstruct masks with sliced components + mask_m_var = codegen.mask_var(block_idx_m) + if mask_m_var is not None: + mask_0 = expr_from_string( + f"{mask_m_var}[:, None] & {mask_n_0_name}[None, :]" + ) + mask_1 = expr_from_string( + f"{mask_m_var}[:, None] & {mask_n_1_name}[None, :]" + ) + else: + mask_0 = expr_from_string(f"{mask_n_0_name}[None, :]") + mask_1 = expr_from_string(f"{mask_n_1_name}[None, :]") + + # First subtile store + codegen.add_statement( + 
statement_from_string( + f"tl.store({name} + {{offset}}, {{value}}, {{mask}})", + value=acc0, + offset=offset_0, + mask=mask_0, + ) + ) + + # Second subtile store - return as the result + return expr_from_string( + f"tl.store({name} + {{offset}}, {{value}}, {{mask}})", + value=acc1, + offset=offset_1, + mask=mask_1, + ) + class BlockPtrIndexingStrategy(IndexingStrategy): """Use block_ptr to load/store from tensors""" @@ -352,7 +661,6 @@ def codegen_load( ) assert extra_mask is None indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) - # Load from tensor descriptor with permuted offsets load_expr = expr_from_string( f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})" @@ -380,12 +688,12 @@ def codegen_store( return PointerIndexingStrategy().codegen_store( state, fake_tensor, subscript, value, extra_mask ) + assert extra_mask is None indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) - + store_value = indexing.reshape_store(state, value) # Apply permutation to the value being stored if needed desc_arg = indexing.tensor_descriptor_arg(state) - store_value = indexing.reshape_store(state, value) if desc_arg.permutation is not None: # Apply permutation to the value @@ -394,11 +702,88 @@ def codegen_store( store_val=store_value, ) + if _use_epilogue_subtile(): + subtile_value = self.codegen_store_subtile( + state, + fake_tensor, + indexing, + store_value, + extra_mask=None, + subscript=None, + ) + + if subtile_value: + return subtile_value + (store_value,) = _apply_pointwise_to_subtile(state, [store_value]) + return expr_from_string( f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})", value=store_value, ) + def codegen_store_subtile( + self, + state: CodegenState, + fake_tensor: torch.Tensor, + indexing: BlockedSubscriptIndexing | SubscriptIndexing, + store_value: ast.AST, + extra_mask: ast.AST | None, + subscript: list[object] | None, + ) -> ast.AST | None: + assert isinstance(indexing, BlockedSubscriptIndexing) + + subtile_split = _get_subtile_split(state) + output_shape = _get_output_shape(indexing, state) + if subtile_split <= 0 or not _can_epilogue_subtile_with_output_shape( + output_shape + ): + return None + + codegen = state.codegen + + block_m, block_n = output_shape + device_fn = state.device_function + block_n_str = device_fn.literal_expr(block_n) + block_n_half_str = f"({block_n_str} // {subtile_split})" + block_m_str = device_fn.literal_expr(block_m) + + acc0, acc1 = _get_accumulator_subtiles( + state, + store_value, + block_m_str=block_m_str, + block_n_half_str=block_n_half_str, + ) + + indexing.block_shape[1] //= subtile_split + + desc_name = indexing.tensor_descriptor(state) + offset0 = expr_from_string(indexing.offsets[0]) + offset1 = expr_from_string(indexing.offsets[1]) + + # First subtile store + codegen.add_statement( + statement_from_string( + f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})", + off0=offset0, + off1=offset1, + value=acc0, + ) + ) + + offset1_shifted = expr_from_string( + "({offset} + {half})", + offset=expr_from_string(indexing.offsets[1]), + half=expr_from_string(block_n_half_str), + ) + + # Emit second subtile store as the expression returned to the caller + return expr_from_string( + f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})", + off0=offset0, + off1=offset1_shifted, + value=acc1, + ) + class StackIndexingStrategy: """ diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index 
49db6d664..f61b8b47d 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -87,6 +87,8 @@ def prepare_graph_lowerings(graph: torch.fx.Graph) -> None: + from .utils import _use_epilogue_subtile + with compile_lock: graph_lowering = GraphLowering( _LazyGraphModule({}, graph), @@ -104,6 +106,25 @@ def prepare_graph_lowerings(graph: torch.fx.Graph) -> None: with node.meta["location"]: prepare_node_lowering(graph_lowering, node) + if _use_epilogue_subtile(): + from ..language import store as store_api + + stores = set() + + for node in graph.nodes: + if node.op == "call_function" and node.target == store_api: + stores.add(node) + value_node = node.args[2] + if not isinstance(value_node, torch.fx.Node): + continue + lowering = value_node.meta.get("lowering") + if ( + isinstance(lowering, PointwiseLowering) + and len(value_node.users) == 1 + ): + value_node.meta["epilogue_subtile"] = True + node.meta["pointwise_in"] = value_node + def prepare_node_lowering( graph_lowering: GraphLowering, @@ -1396,6 +1417,12 @@ def _collect_multi_outputs( def run_node(self, n: Node) -> object: if n.op == "call_function": + # Skip codegen for nodes that can potentially be subtiled + if n.meta.get("epilogue_subtile", False): + arg = n.args[0] + assert isinstance(arg, Node) + return self.env[arg] + with self._set_current_node(n), n.meta["location"], V.set_current_node(n): try: lowering: Lowering = n.meta["lowering"] diff --git a/helion/_compiler/utils.py b/helion/_compiler/utils.py index 0992514af..85c14e97f 100644 --- a/helion/_compiler/utils.py +++ b/helion/_compiler/utils.py @@ -1,9 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import torch +import torch def compute_slice_size( @@ -29,3 +26,13 @@ def compute_slice_size( start = slice_obj.start if slice_obj.start is not None else 0 stop = slice_obj.stop if slice_obj.stop is not None else original_size return stop - start + + +def _use_epilogue_subtile() -> bool: + from .compile_environment import CompileEnvironment + + return ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (10, 0) + and CompileEnvironment.current().settings.allow_epilogue_subtiling + ) diff --git a/helion/_testing.py b/helion/_testing.py index 87f459558..603979b3c 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -683,6 +683,7 @@ def check_example( static_shapes: bool | None = None, atol: float = 1e-1, rtol: float = 1e-2, + allow_epilogue_subtiling: bool | None = None, **kwargs: object, ) -> str: """Helper used in unit tests to run a single example kernel and check its output.""" @@ -691,6 +692,10 @@ def check_example( assert static_shapes in (True, False) kernel_fn.settings.static_shapes = static_shapes + if allow_epilogue_subtiling is not None: + assert allow_epilogue_subtiling in (True, False) + kernel_fn.settings.allow_epilogue_subtiling = allow_epilogue_subtiling + code, result = code_and_output( kernel_fn, args, diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index bf7cf1270..22ab1c18a 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -52,10 +52,12 @@ "pid_type", "indexing", "load_eviction_policies", + "epilogue_subtiling", ] ) VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved") VALID_EVICTION_POLICIES = ("", "first", "last") +VALID_EPILOGUE_SUBTILE_SIZES = (0, 2) @dataclasses.dataclass @@ -110,6 +112,11 @@ class ConfigSpec: 
EnumFragment(choices=ConfigSpec._valid_indexing_types()), length=0 ) ) + epilogue_subtiling: ListOf = dataclasses.field( + default_factory=lambda: ListOf( + EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=0 + ) + ) @staticmethod def _valid_indexing_types() -> tuple[IndexingLiteral, ...]: @@ -214,6 +221,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: "static_ranges", "load_eviction_policies", "indexing", + "epilogue_subtiling", ): if not config.get(name): config.pop(name, None) @@ -224,6 +232,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: "load_eviction_policies", self.load_eviction_policies.default() ) config.setdefault("indexing", self.indexing.default()) + config.setdefault("epilogue_subtiling", self.epilogue_subtiling.default()) # TODO(jansel): include num_ctas and max_nreg for name, values in (("pid_type", VALID_PID_TYPES),): @@ -297,6 +306,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "indexing": fn(self.indexing), "pid_type": fn(EnumFragment(self.allowed_pid_types)), "load_eviction_policies": fn(self.load_eviction_policies), + "epilogue_subtiling": fn(self.epilogue_subtiling), } # Add tunable parameters config.update( @@ -316,9 +326,11 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "static_ranges", "load_eviction_policies", "indexing", + "epilogue_subtiling", ): if not config.get(name): config.pop(name, None) + self.normalize(config) return helion.Config(**config) diff --git a/helion/runtime/config.py b/helion/runtime/config.py index 1d13d8884..616874197 100644 --- a/helion/runtime/config.py +++ b/helion/runtime/config.py @@ -40,6 +40,7 @@ def __init__( num_stages: int | None = None, pid_type: PidTypeLiteral | None = None, indexing: IndexingLiteral | list[IndexingLiteral] | None = None, + epilogue_subtiling: list[int] | None = None, # For user-defined properties **kwargs: object, ) -> None: @@ -68,6 +69,7 @@ def __init__( indexing=["pointer", "block_ptr", "tensor_descriptor"] - Empty/omitted (all loads/stores default to "pointer") Valid strategies: "pointer", "tensor_descriptor", "block_ptr" + epilogue_subtiling: Whether to use subtiling for epilogue. **kwargs: Additional user-defined configuration parameters. 
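+
+        Each entry in epilogue_subtiling corresponds to one device store and is
+        either 0 (keep a single store) or 2 (split the store into two
+        half-width subtiles), e.g. epilogue_subtiling=[2] for a kernel with a
+        single store.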
""" self.config = {} @@ -88,6 +90,7 @@ def __init__( "num_stages": num_stages, "indexing": indexing, "pid_type": pid_type, + "epilogue_subtiling": epilogue_subtiling, } for key, value in core_props.items(): if value is not None: @@ -217,6 +220,10 @@ def indexing(self) -> IndexingLiteral | list[IndexingLiteral]: "IndexingLiteral | list[IndexingLiteral]", self.config.get("indexing", []) ) + @property + def epilogue_subtiling(self) -> list[int]: + return cast("list[int]", self.config.get("epilogue_subtiling", [])) # type: ignore[return-value] + def _to_hashable(x: object) -> object: if isinstance(x, list): diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 54f329c7a..7d56c6a04 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -342,6 +342,11 @@ class _Settings: _env_get_bool, "HELION_ALLOW_WARP_SPECIALIZE", True ) ) + allow_epilogue_subtiling: bool = dataclasses.field( + default_factory=functools.partial( + _env_get_bool, "HELION_ALLOW_EPILOGUE_SUBTILING", False + ) + ) debug_dtype_asserts: bool = dataclasses.field( default_factory=functools.partial( _env_get_bool, "HELION_DEBUG_DTYPE_ASSERTS", False @@ -400,6 +405,7 @@ class Settings(_Settings): "Accepts HELION_AUTOTUNE_CONFIG_OVERRIDES='{\"num_warps\":4}'." ), "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.", + "allow_epilogue_subtiling": "If True, allow epilogue subtiling on TMA stores for CUDA devices.", "debug_dtype_asserts": "If True, emit tl.static_assert checks for dtype after each device node.", "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.", "autotuner_fn": ( diff --git a/scripts/dictionary.txt b/scripts/dictionary.txt index cc8bd9c53..faf1c0bab 100644 --- a/scripts/dictionary.txt +++ b/scripts/dictionary.txt @@ -1,2 +1,4 @@ NotIn readd +numer +subtile diff --git a/test/test_autotuner.expected b/test/test_autotuner.expected index 67c9e5821..e2a1b080d 100644 --- a/test/test_autotuner.expected +++ b/test/test_autotuner.expected @@ -2,40 +2,40 @@ This file is automatically generated by assertExpectedJournal calls in test_auto Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. 
--- assertExpectedJournal(TestAutotuner.test_config_fragment0) -helion.Config(block_sizes=[16, 16, 16], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None]) -helion.Config(block_sizes=[32, 128, 64], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, True], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, True]) -helion.Config(block_sizes=[64, 16, 64], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=32, pid_type='persistent_blocked', range_flattens=[False, False], range_multi_buffers=[True, None], range_num_stages=[0, 0], range_unroll_factors=[4, 3], range_warp_specializes=[False, None]) -helion.Config(block_sizes=[16, 64, 512], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[4], load_eviction_policies=['last', ''], loop_orders=[[1, 0]], num_stages=3, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[None, True], range_num_stages=[4, 0], range_unroll_factors=[4, 0], range_warp_specializes=[None, True]) -helion.Config(block_sizes=[16, 16, 32], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[16], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=1, num_warps=1, pid_type='persistent_blocked', range_flattens=[False, None], range_multi_buffers=[True, None], range_num_stages=[3, 4], range_unroll_factors=[2, 0], range_warp_specializes=[False, None]) -helion.Config(block_sizes=[16, 128, 16], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[64], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=4, num_warps=16, pid_type='persistent_blocked', range_flattens=[False, True], range_multi_buffers=[None, False], range_num_stages=[3, 0], range_unroll_factors=[0, 1], range_warp_specializes=[True, None]) -helion.Config(block_sizes=[16, 16, 16], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=5, num_warps=32, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[None, True], range_num_stages=[2, 2], range_unroll_factors=[0, 4], range_warp_specializes=[False, False]) -helion.Config(block_sizes=[16, 16, 16], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['first', 'last'], loop_orders=[[1, 0]], num_stages=6, num_warps=4, pid_type='persistent_blocked', range_flattens=[False, None], range_multi_buffers=[True, None], range_num_stages=[1, 1], range_unroll_factors=[2, 0], range_warp_specializes=[False, True]) -helion.Config(block_sizes=[64, 64, 16], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[2], load_eviction_policies=['', 'last'], loop_orders=[[0, 1]], num_stages=7, num_warps=32, pid_type='persistent_blocked', range_flattens=[None, False], range_multi_buffers=[None, True], range_num_stages=[2, 4], 
range_unroll_factors=[0, 0], range_warp_specializes=[None, True]) -helion.Config(block_sizes=[16, 16, 16], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[1], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=7, num_warps=1, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[None, None], range_num_stages=[4, 0], range_unroll_factors=[1, 3], range_warp_specializes=[True, None]) +helion.Config(block_sizes=[16, 16, 16], epilogue_subtiling=[], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None]) +helion.Config(block_sizes=[32, 128, 64], epilogue_subtiling=[], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, True], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, True]) +helion.Config(block_sizes=[64, 16, 64], epilogue_subtiling=[], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=32, pid_type='persistent_blocked', range_flattens=[False, False], range_multi_buffers=[True, None], range_num_stages=[0, 0], range_unroll_factors=[4, 3], range_warp_specializes=[False, None]) +helion.Config(block_sizes=[16, 64, 512], epilogue_subtiling=[], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[4], load_eviction_policies=['last', ''], loop_orders=[[1, 0]], num_stages=3, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[None, True], range_num_stages=[4, 0], range_unroll_factors=[4, 0], range_warp_specializes=[None, True]) +helion.Config(block_sizes=[16, 16, 32], epilogue_subtiling=[], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[16], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=1, num_warps=1, pid_type='persistent_blocked', range_flattens=[False, None], range_multi_buffers=[True, None], range_num_stages=[3, 4], range_unroll_factors=[2, 0], range_warp_specializes=[False, None]) +helion.Config(block_sizes=[16, 128, 16], epilogue_subtiling=[], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[64], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=4, num_warps=16, pid_type='persistent_blocked', range_flattens=[False, True], range_multi_buffers=[None, False], range_num_stages=[3, 0], range_unroll_factors=[0, 1], range_warp_specializes=[True, None]) +helion.Config(block_sizes=[16, 16, 16], epilogue_subtiling=[], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=5, num_warps=32, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[None, True], range_num_stages=[2, 2], range_unroll_factors=[0, 4], range_warp_specializes=[False, False]) +helion.Config(block_sizes=[16, 16, 16], epilogue_subtiling=[], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['first', 
'last'], loop_orders=[[1, 0]], num_stages=6, num_warps=4, pid_type='persistent_blocked', range_flattens=[False, None], range_multi_buffers=[True, None], range_num_stages=[1, 1], range_unroll_factors=[2, 0], range_warp_specializes=[False, True]) +helion.Config(block_sizes=[64, 64, 16], epilogue_subtiling=[], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[2], load_eviction_policies=['', 'last'], loop_orders=[[0, 1]], num_stages=7, num_warps=32, pid_type='persistent_blocked', range_flattens=[None, False], range_multi_buffers=[None, True], range_num_stages=[2, 4], range_unroll_factors=[0, 0], range_warp_specializes=[None, True]) +helion.Config(block_sizes=[16, 16, 16], epilogue_subtiling=[], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[1], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=7, num_warps=1, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[None, None], range_num_stages=[4, 0], range_unroll_factors=[1, 3], range_warp_specializes=[True, None]) --- assertExpectedJournal(TestAutotuner.test_config_fragment1) -helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[4, 256, 256], flatten_loops=[False], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[4], load_eviction_policies=['', ''], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True]) -helion.Config(block_sizes=[1, 64, 128], flatten_loops=[True], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['first', 'last'], loop_orders=[[1, 2, 0]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[8, 1, 16], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['last', 'last'], loop_orders=[[2, 1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[True], range_unroll_factors=[3], range_warp_specializes=[None]) -helion.Config(block_sizes=[4, 32, 4], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[32], load_eviction_policies=['last', 'last'], loop_orders=[[2, 0, 1]], num_stages=6, num_warps=16, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[None]) -helion.Config(block_sizes=[4, 2, 1], flatten_loops=[True], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=5, num_warps=2, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[3], range_warp_specializes=[False]) -helion.Config(block_sizes=[2, 32, 64], flatten_loops=[True], indexing=['tensor_descriptor', 'pointer', 'tensor_descriptor'], l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 0, 1]], 
num_stages=6, num_warps=16, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[2], range_warp_specializes=[None]) -helion.Config(block_sizes=[4, 8, 16], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[0, 1, 2]], num_stages=8, num_warps=8, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[4, 16, 16], flatten_loops=[False], indexing=['pointer', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 0, 1]], num_stages=8, num_warps=16, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[False]) -helion.Config(block_sizes=[4, 1, 2], flatten_loops=[True], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=7, num_warps=32, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[1], range_warp_specializes=[True]) +helion.Config(block_sizes=[8, 16, 16], epilogue_subtiling=[], flatten_loops=[False], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[4, 256, 256], epilogue_subtiling=[], flatten_loops=[False], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[4], load_eviction_policies=['', ''], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True]) +helion.Config(block_sizes=[1, 64, 128], epilogue_subtiling=[], flatten_loops=[True], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['first', 'last'], loop_orders=[[1, 2, 0]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[8, 1, 16], epilogue_subtiling=[], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['last', 'last'], loop_orders=[[2, 1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[True], range_unroll_factors=[3], range_warp_specializes=[None]) +helion.Config(block_sizes=[4, 32, 4], epilogue_subtiling=[], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[32], load_eviction_policies=['last', 'last'], loop_orders=[[2, 0, 1]], num_stages=6, num_warps=16, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[None]) +helion.Config(block_sizes=[4, 2, 1], epilogue_subtiling=[], flatten_loops=[True], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=5, num_warps=2, pid_type='persistent_interleaved', range_flattens=[True], 
range_multi_buffers=[False], range_unroll_factors=[3], range_warp_specializes=[False]) +helion.Config(block_sizes=[2, 32, 64], epilogue_subtiling=[], flatten_loops=[True], indexing=['tensor_descriptor', 'pointer', 'tensor_descriptor'], l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 0, 1]], num_stages=6, num_warps=16, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[2], range_warp_specializes=[None]) +helion.Config(block_sizes=[4, 8, 16], epilogue_subtiling=[], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[0, 1, 2]], num_stages=8, num_warps=8, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 16, 16], epilogue_subtiling=[], flatten_loops=[False], indexing=['pointer', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 0, 1]], num_stages=8, num_warps=16, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[False]) +helion.Config(block_sizes=[4, 1, 2], epilogue_subtiling=[], flatten_loops=[True], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=7, num_warps=32, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[1], range_warp_specializes=[True]) --- assertExpectedJournal(TestAutotuner.test_config_warp_specialize_unroll) -helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[4, 256, 256], flatten_loops=[False], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[4], load_eviction_policies=['', ''], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[1, 64, 128], flatten_loops=[True], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['first', 'last'], loop_orders=[[1, 2, 0]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[8, 1, 16], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['last', 'last'], loop_orders=[[2, 1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[4, 32, 4], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[32], load_eviction_policies=['last', 'last'], loop_orders=[[2, 0, 1]], num_stages=6, num_warps=16, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[True], 
range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[4, 2, 1], flatten_loops=[True], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=5, num_warps=2, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[2, 32, 64], flatten_loops=[True], indexing=['tensor_descriptor', 'pointer', 'tensor_descriptor'], l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 0, 1]], num_stages=6, num_warps=16, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[4, 8, 16], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[0, 1, 2]], num_stages=8, num_warps=8, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[4, 16, 16], flatten_loops=[False], indexing=['pointer', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 0, 1]], num_stages=8, num_warps=16, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[4, 1, 2], flatten_loops=[True], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=7, num_warps=32, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[8, 16, 16], epilogue_subtiling=[], flatten_loops=[False], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[4, 256, 256], epilogue_subtiling=[], flatten_loops=[False], indexing=['tensor_descriptor', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[4], load_eviction_policies=['', ''], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[1, 64, 128], epilogue_subtiling=[], flatten_loops=[True], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[1], load_eviction_policies=['first', 'last'], loop_orders=[[1, 2, 0]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[8, 1, 16], epilogue_subtiling=[], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['last', 'last'], loop_orders=[[2, 1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 32, 4], epilogue_subtiling=[], 
flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[32], load_eviction_policies=['last', 'last'], loop_orders=[[2, 0, 1]], num_stages=6, num_warps=16, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 2, 1], epilogue_subtiling=[], flatten_loops=[True], indexing=['pointer', 'pointer', 'pointer'], l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=5, num_warps=2, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[2, 32, 64], epilogue_subtiling=[], flatten_loops=[True], indexing=['tensor_descriptor', 'pointer', 'tensor_descriptor'], l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 0, 1]], num_stages=6, num_warps=16, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 8, 16], epilogue_subtiling=[], flatten_loops=[True], indexing=['pointer', 'tensor_descriptor', 'pointer'], l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[0, 1, 2]], num_stages=8, num_warps=8, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 16, 16], epilogue_subtiling=[], flatten_loops=[False], indexing=['pointer', 'tensor_descriptor', 'tensor_descriptor'], l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 0, 1]], num_stages=8, num_warps=16, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 1, 2], epilogue_subtiling=[], flatten_loops=[True], indexing=['tensor_descriptor', 'tensor_descriptor', 'pointer'], l2_groupings=[16], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=7, num_warps=32, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True]) --- assertExpectedJournal(TestAutotuner.test_save_load_config) { diff --git a/test/test_autotuner.py b/test/test_autotuner.py index ce2d4abdb..566ba55fc 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -201,9 +201,11 @@ def test_config_fragment0(self): torch.randn([512, 512], device=DEVICE), torch.randn([512, 512], device=DEVICE), ) + examples_matmul.settings.allow_epilogue_subtiling = False spec = examples_matmul.bind(args).config_spec configs = ConfigGeneration(spec).random_population(10) self.assertExpectedJournal("\n".join(map(repr, configs))) + examples_matmul.settings.allow_epilogue_subtiling = True @patch( "helion.autotuner.config_generation.warps_to_threads", diff --git a/test/test_debug_utils.expected b/test/test_debug_utils.expected index 52a62565e..b461f698c 100644 --- a/test/test_debug_utils.expected +++ b/test/test_debug_utils.expected @@ -8,7 +8,7 @@ import helion.language as hl import torch from torch._dynamo.testing import rand_strided -@helion.kernel(config=helion.Config(block_sizes=[32], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], 
range_warp_specializes=[]), static_shapes=True) +@helion.kernel(config=helion.Config(block_sizes=[32], epilogue_subtiling=[], indexing=['pointer', 'pointer'], load_eviction_policies=[''], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True) def kernel(x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) n = x.shape[0] diff --git a/test/test_examples.expected b/test/test_examples.expected index 4e2b2cc7d..4cf56dcc8 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -6682,15 +6682,27 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]] from __future__ import annotations import torch +import helion import triton import triton.language as tl from torch._inductor.runtime import triton_helpers from helion.runtime import default_launcher as _default_launcher import test.test_examples as _global_source0 +# src[matmul.py:N]: def matmul( +# src[matmul.py:N]: x: Tensor, +# src[matmul.py:N]: y: Tensor, +# src[matmul.py:N-N]: ... +helion.runtime.set_triton_allocator() @triton.jit def _helion_matmul(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + x_desc = tl.make_tensor_descriptor(x, [1024, 1024], [1024, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_2]) + y_desc = tl.make_tensor_descriptor(y, [1024, 1024], [1024, 1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1]) + # src[matmul.py:N]: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) + out_desc = tl.make_tensor_descriptor(out, [1024, 1024], [1024, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) + out_desc_1 = tl.make_tensor_descriptor(out, [1024, 1024], [1024, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1 // 2]) # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0) num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1) @@ -6711,14 +6723,474 @@ def _helion_matmul(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con acc_copy = acc acc_copy_0 = acc_copy # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) - load = tl.load(tl.make_block_ptr(x, [1024, 1024], [1024, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero') - load_1 = tl.load(tl.make_block_ptr(y, [1024, 1024], [1024, 1], [offset_2, offset_1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero') + load = x_desc.load([offset_0, offset_2]) + load_1 = y_desc.load([offset_2, offset_1]) + acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + # src[matmul.py:N]: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) + acc_0 = tl.reshape(acc, [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 2]).permute(0, 2, 1) + acc_2, acc_3 = tl.split(acc_0) + v_0 = tl.full([], 0, tl.int32) + v_1 = triton_helpers.maximum(v_0, acc_2) + v_2 = tl.full([], 0, tl.int32) + v_3 = triton_helpers.maximum(v_2, acc_3) + v_4 = tl.cast(v_1, tl.float16) + v_5 = tl.cast(v_3, tl.float16) + out_desc_1.store([offset_0, offset_1], v_4) + out_desc_1.store([offset_0, offset_1 + _BLOCK_SIZE_1 // 2], v_5) + +def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher): + """ + Performs matrix multiplication of x and y with an optional epilogue function. 
+ Args: + x (Tensor): Left matrix of shape [m, k]. + y (Tensor): Right matrix of shape [k, n]. + epilogue (Callable, optional): Function applied to the accumulator and tile indices + after the matmul. Defaults to identity (no change). + Returns: + Tensor: Resulting matrix of shape [m, n]. + """ + # src[matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[matmul.py:N]: out = torch.empty( + # src[matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 64 + _BLOCK_SIZE_1 = 64 + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 16 + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N-N]: ... + _launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4) + # src[matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestExamples.test_template_via_closure2_subtile_size_0) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +import test.test_examples as _global_source0 +# src[matmul.py:N]: def matmul( +# src[matmul.py:N]: x: Tensor, +# src[matmul.py:N]: y: Tensor, +# src[matmul.py:N-N]: ... 
+helion.runtime.set_triton_allocator() + +@triton.jit +def _helion_matmul(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + x_desc = tl.make_tensor_descriptor(x, [1024, 1024], [1024, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_2]) + y_desc = tl.make_tensor_descriptor(y, [1024, 1024], [1024, 1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1]) + # src[matmul.py:N]: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) + out_desc = tl.make_tensor_descriptor(out, [1024, 1024], [1024, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 64 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 64 + group_size_m = min(num_pid_m - first_pid_m, 64) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2): + acc_copy = acc + acc_copy_0 = acc_copy + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + load = x_desc.load([offset_0, offset_2]) + load_1 = y_desc.load([offset_2, offset_1]) acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) # src[matmul.py:N]: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) v_0 = tl.full([], 0, tl.int32) v_1 = triton_helpers.maximum(v_0, acc) v_2 = tl.cast(v_1, tl.float16) - tl.store(tl.make_block_ptr(out, [1024, 1024], [1024, 1], [offset_0, offset_1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), v_2, boundary_check=[0, 1]) + out_desc.store([offset_0, offset_1], v_2) + +def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher): + """ + Performs matrix multiplication of x and y with an optional epilogue function. + Args: + x (Tensor): Left matrix of shape [m, k]. + y (Tensor): Right matrix of shape [k, n]. + epilogue (Callable, optional): Function applied to the accumulator and tile indices + after the matmul. Defaults to identity (no change). + Returns: + Tensor: Resulting matrix of shape [m, n]. 
+ """ + # src[matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[matmul.py:N]: out = torch.empty( + # src[matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 64 + _BLOCK_SIZE_1 = 64 + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 16 + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N-N]: ... + _launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4) + # src[matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestExamples.test_template_via_closure2_subtile_size_2) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +import test.test_examples as _global_source0 +# src[matmul.py:N]: def matmul( +# src[matmul.py:N]: x: Tensor, +# src[matmul.py:N]: y: Tensor, +# src[matmul.py:N-N]: ... +helion.runtime.set_triton_allocator() + +@triton.jit +def _helion_matmul(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + x_desc = tl.make_tensor_descriptor(x, [1024, 1024], [1024, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_2]) + y_desc = tl.make_tensor_descriptor(y, [1024, 1024], [1024, 1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1]) + # src[matmul.py:N]: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) + out_desc = tl.make_tensor_descriptor(out, [1024, 1024], [1024, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) + out_desc_1 = tl.make_tensor_descriptor(out, [1024, 1024], [1024, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1 // 2]) + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 64 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 64 + group_size_m = min(num_pid_m - first_pid_m, 64) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2): + acc_copy = acc + acc_copy_0 = acc_copy + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + load = x_desc.load([offset_0, offset_2]) + load_1 = y_desc.load([offset_2, offset_1]) + acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), 
acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + # src[matmul.py:N]: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) + acc_0 = tl.reshape(acc, [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 2]).permute(0, 2, 1) + acc_2, acc_3 = tl.split(acc_0) + v_0 = tl.full([], 0, tl.int32) + v_1 = triton_helpers.maximum(v_0, acc_2) + v_2 = tl.full([], 0, tl.int32) + v_3 = triton_helpers.maximum(v_2, acc_3) + v_4 = tl.cast(v_1, tl.float16) + v_5 = tl.cast(v_3, tl.float16) + out_desc_1.store([offset_0, offset_1], v_4) + out_desc_1.store([offset_0, offset_1 + _BLOCK_SIZE_1 // 2], v_5) + +def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher): + """ + Performs matrix multiplication of x and y with an optional epilogue function. + Args: + x (Tensor): Left matrix of shape [m, k]. + y (Tensor): Right matrix of shape [k, n]. + epilogue (Callable, optional): Function applied to the accumulator and tile indices + after the matmul. Defaults to identity (no change). + Returns: + Tensor: Resulting matrix of shape [m, n]. + """ + # src[matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[matmul.py:N]: out = torch.empty( + # src[matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 64 + _BLOCK_SIZE_1 = 64 + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 16 + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N-N]: ... 
+ _launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4) + # src[matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestExamples.test_template_via_closure3) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +import test.test_examples as _global_source0 + +@triton.jit +def _helion_matmul(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 64 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 64 + group_size_m = min(num_pid_m - first_pid_m, 64) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + acc_copy = acc + acc_copy_0 = acc_copy + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + load = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None) + load_1 = tl.load(y + (indices_2[:, None] * 1024 + indices_1[None, :] * 1), None) + acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + # src[matmul.py:N]: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) + acc_0 = tl.reshape(acc, [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 2]).permute(0, 2, 1) + acc_2, acc_3 = tl.split(acc_0) + v_0 = tl.full([], 0, tl.int32) + v_1 = triton_helpers.maximum(v_0, acc_2) + v_2 = tl.full([], 0, tl.int32) + v_3 = triton_helpers.maximum(v_2, acc_3) + v_4 = 1.0 + v_5 = v_1 + v_4 + v_6 = 1.0 + v_7 = v_3 + v_6 + v_8 = tl.sigmoid(tl.cast(v_5, tl.float32)) + v_9 = tl.sigmoid(tl.cast(v_7, tl.float32)) + v_10 = tl.cast(v_8, tl.float16) + v_11 = tl.cast(v_9, tl.float16) + indices_n_0 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + indices_n_1 = (offset_1 + _BLOCK_SIZE_1 // 2 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + tl.store(out + (indices_0[:, None] * 1024 + indices_n_0[None, :] * 1), v_10, None) + tl.store(out + (indices_0[:, None] * 1024 + indices_n_1[None, :] * 1), v_11, None) + +def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher): + """ + Performs matrix multiplication of x and y with an optional epilogue function. + Args: + x (Tensor): Left matrix of shape [m, k]. + y (Tensor): Right matrix of shape [k, n]. + epilogue (Callable, optional): Function applied to the accumulator and tile indices + after the matmul. 
Defaults to identity (no change). + Returns: + Tensor: Resulting matrix of shape [m, n]. + """ + # src[matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[matmul.py:N]: out = torch.empty( + # src[matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 64 + _BLOCK_SIZE_1 = 64 + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 16 + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N-N]: ... + _launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4) + # src[matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestExamples.test_template_via_closure3_subtile_size_0) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +import test.test_examples as _global_source0 + +@triton.jit +def _helion_matmul(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 64 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 64 + group_size_m = min(num_pid_m - first_pid_m, 64) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + acc_copy = acc + acc_copy_0 = acc_copy + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + load = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None) + load_1 = tl.load(y + (indices_2[:, None] * 1024 + indices_1[None, :] * 1), None) + acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + # src[matmul.py:N]: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) + v_0 = tl.full([], 0, tl.int32) + v_1 = triton_helpers.maximum(v_0, acc) + v_2 = 1.0 + v_3 = v_1 + v_2 + v_4 = tl.sigmoid(tl.cast(v_3, tl.float32)) + v_5 = tl.cast(v_4, tl.float16) + tl.store(out + (indices_0[:, None] * 1024 + 
indices_1[None, :] * 1), v_5, None) + +def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher): + """ + Performs matrix multiplication of x and y with an optional epilogue function. + Args: + x (Tensor): Left matrix of shape [m, k]. + y (Tensor): Right matrix of shape [k, n]. + epilogue (Callable, optional): Function applied to the accumulator and tile indices + after the matmul. Defaults to identity (no change). + Returns: + Tensor: Resulting matrix of shape [m, n]. + """ + # src[matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[matmul.py:N]: out = torch.empty( + # src[matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 64 + _BLOCK_SIZE_1 = 64 + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 16 + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N-N]: ... + _launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4) + # src[matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestExamples.test_template_via_closure3_subtile_size_2) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from helion.runtime import default_launcher as _default_launcher + +import test.test_examples as _global_source0 + +@triton.jit +def _helion_matmul(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(1024, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(1024, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 64 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 64 + group_size_m = min(num_pid_m - first_pid_m, 64) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[matmul.py:N]: for tile_k in hl.tile(k): + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + acc_copy = acc + acc_copy_0 = acc_copy + # src[matmul.py:N]: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + load = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None) + load_1 = tl.load(y + (indices_2[:, None] * 1024 + 
indices_1[None, :] * 1), None) + acc = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + # src[matmul.py:N]: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n)) + acc_0 = tl.reshape(acc, [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 2]).permute(0, 2, 1) + acc_2, acc_3 = tl.split(acc_0) + v_0 = tl.full([], 0, tl.int32) + v_1 = triton_helpers.maximum(v_0, acc_2) + v_2 = tl.full([], 0, tl.int32) + v_3 = triton_helpers.maximum(v_2, acc_3) + v_4 = 1.0 + v_5 = v_1 + v_4 + v_6 = 1.0 + v_7 = v_3 + v_6 + v_8 = tl.sigmoid(tl.cast(v_5, tl.float32)) + v_9 = tl.sigmoid(tl.cast(v_7, tl.float32)) + v_10 = tl.cast(v_8, tl.float16) + v_11 = tl.cast(v_9, tl.float16) + indices_n_0 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + indices_n_1 = (offset_1 + _BLOCK_SIZE_1 // 2 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + tl.store(out + (indices_0[:, None] * 1024 + indices_n_0[None, :] * 1), v_10, None) + tl.store(out + (indices_0[:, None] * 1024 + indices_n_1[None, :] * 1), v_11, None) def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]], Tensor]=lambda acc, tile: acc, *, _launcher=_default_launcher): """ diff --git a/test/test_examples.py b/test/test_examples.py index ae98d007d..0edcb907d 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -6,6 +6,8 @@ from packaging import version import torch import torch.nn.functional as F +from torch.testing._internal.common_utils import instantiate_parametrized_tests +from torch.testing._internal.common_utils import parametrize import helion from helion import _compat @@ -231,12 +233,16 @@ def test_template_via_closure0(self): torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16), lambda acc, tile: torch.relu(acc + bias[tile]), ) + + # Disallow epilogue subtiling, currently unable to handle bias + # addition self.assertExpectedJournal( check_example( "matmul", args, torch.relu(args[0] @ args[1] + bias), fn_name="matmul", + allow_epilogue_subtiling=False, block_sizes=[64, 64, 16], loop_orders=[[0, 1]], num_warps=2, @@ -255,12 +261,16 @@ def test_template_via_closure1(self): torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16), lambda acc, tile: torch.relu(acc + bias[tile]), ) + + # Disallow epilogue subtiling, currently unable to handle bias + # addition self.assertExpectedJournal( check_example( "matmul", args, torch.relu(args[0] @ args[1] + bias), fn_name="matmul", + allow_epilogue_subtiling=False, block_sizes=[64, 64, 16], loop_orders=[[0, 1]], num_warps=2, @@ -270,8 +280,14 @@ def test_template_via_closure1(self): ) ) - @patch.object(_compat, "_supports_tensor_descriptor", lambda: False) - def test_template_via_closure2(self): + @parametrize("subtile_size", [0, 2]) + @patch.object(_compat, "_supports_tensor_descriptor", lambda: True) + def test_template_via_closure2(self, subtile_size: int): + # Skip subtiling for non blackwell + if subtile_size == 2 and not ( + torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10 + ): + return args = ( torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16), torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16), @@ -282,13 +298,45 @@ def test_template_via_closure2(self): "matmul", args, torch.relu(args[0] @ args[1]), + allow_epilogue_subtiling=True, fn_name="matmul", block_sizes=[64, 64, 16], loop_orders=[[0, 1]], num_warps=2, num_stages=4, - indexing="block_ptr", + indexing="tensor_descriptor", l2_grouping=64, + epilogue_subtiling=[subtile_size], + ) + ) 
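For reference, a minimal torch sketch (illustrative only, not part of this diff; the tile sizes are assumed) of the accumulator split that the subtile_size=2 journals above emit via tl.reshape / permute / tl.split before the two half-width stores:

    import torch

    BLOCK_M, BLOCK_N = 64, 64
    acc = torch.randn(BLOCK_M, BLOCK_N)

    # Mirrors: tl.reshape(acc, [BLOCK_M, 2, BLOCK_N // 2]).permute(0, 2, 1) -> tl.split
    halves = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1)
    left, right = halves.unbind(dim=-1)

    # The split yields the two contiguous column halves of the tile, so storing
    # them at offset_1 and offset_1 + BLOCK_N // 2 covers the same region as the
    # single full-width store in the subtile_size=0 journals.
    assert torch.equal(left, acc[:, : BLOCK_N // 2])
    assert torch.equal(right, acc[:, BLOCK_N // 2 :])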
+ + @parametrize("subtile_size", [0, 2]) + @patch.object(_compat, "_supports_tensor_descriptor", lambda: True) + def test_template_via_closure3(self, subtile_size: int): + # Skip subtiling for non blackwell + if subtile_size == 2 and not ( + torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10 + ): + return + args = ( + torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16), + torch.randn([1024, 1024], device=DEVICE, dtype=torch.float16), + lambda x, _: torch.nn.functional.sigmoid(torch.nn.functional.relu(x) + 1.0), + ) + self.assertExpectedJournal( + check_example( + "matmul", + args, + torch.sigmoid(torch.relu(args[0] @ args[1]) + 1.0), + allow_epilogue_subtiling=True, + fn_name="matmul", + block_sizes=[64, 64, 16], + loop_orders=[[0, 1]], + num_warps=2, + num_stages=4, + indexing="pointer", + l2_grouping=64, + epilogue_subtiling=[subtile_size], ) ) @@ -1813,5 +1861,7 @@ def test_grpo_loss_bwd(self): ) +instantiate_parametrized_tests(TestExamples) + if __name__ == "__main__": unittest.main() diff --git a/test/test_matmul.expected b/test/test_matmul.expected index 0284e8237..721420053 100644 --- a/test/test_matmul.expected +++ b/test/test_matmul.expected @@ -265,6 +265,460 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]] # src[matmul.py:N]: return out return out +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile_pointer_mask_subtile_0) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_matmul_static_shapes(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(130, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(130, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 130 + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < 130 + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 130, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < 130 + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = tl.load(x + (indices_0[:, None] * 130 + indices_2[None, :] * 1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * 130 + indices_1[None, :] * 1), mask_2[:, None] & mask_1[None, :], other=0) + mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) + v_0 = tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + v_2 = tl.cast(acc, 
tl.bfloat16) + tl.store(out + (indices_0[:, None] * 130 + indices_1[None, :] * 1), v_2, mask_0[:, None] & mask_1[None, :]) + +def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... + _launcher(_helion_matmul_static_shapes, (triton.cdiv(130, _BLOCK_SIZE_0) * triton.cdiv(130, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile_pointer_mask_subtile_2) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_matmul_static_shapes(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(130, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(130, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 130 + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < 130 + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 130, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < 130 + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = tl.load(x + (indices_0[:, None] * 130 + indices_2[None, :] * 1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * 130 + indices_1[None, :] * 1), mask_2[:, None] & mask_1[None, :], other=0) + mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) + v_0 = 
tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + acc_0 = tl.reshape(acc, [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 2]).permute(0, 2, 1) + acc_1, acc_2 = tl.split(acc_0) + v_2 = tl.cast(acc_1, tl.bfloat16) + v_3 = tl.cast(acc_2, tl.bfloat16) + indices_n_0 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + indices_n_1 = (offset_1 + _BLOCK_SIZE_1 // 2 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + mask_n_0 = indices_n_0 < 130 + mask_n_1 = indices_n_1 < 130 + tl.store(out + (indices_0[:, None] * 130 + indices_n_0[None, :] * 1), v_2, mask_0[:, None] & mask_n_0[None, :]) + tl.store(out + (indices_0[:, None] * 130 + indices_n_1[None, :] * 1), v_3, mask_0[:, None] & mask_n_1[None, :]) + +def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... + _launcher(_helion_matmul_static_shapes, (triton.cdiv(130, _BLOCK_SIZE_0) * triton.cdiv(130, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile_subtile_0) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +# src[test_matmul.py:N]: def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: +# src[test_matmul.py:N]: m, k = x.size() +# src[test_matmul.py:N]: k2, n = y.size() +# src[test_matmul.py:N-N]: ... 
+helion.runtime.set_triton_allocator() + +@triton.jit +def _helion_matmul_static_shapes(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + x_desc = tl.make_tensor_descriptor(x, [128, 128], [128, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_2]) + y_desc = tl.make_tensor_descriptor(y, [128, 128], [128, 1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1]) + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + out_desc = tl.make_tensor_descriptor(out, [128, 128], [128, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(128, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(128, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = x_desc.load([offset_0, offset_2]) + load_1 = y_desc.load([offset_2, offset_1]) + mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) + v_0 = tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + v_2 = tl.cast(acc, tl.bfloat16) + out_desc.store([offset_0, offset_1], v_2) + +def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... 
+ _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile_subtile_2) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_matmul_static_shapes(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(128, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(128, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = tl.load(x + (indices_0[:, None] * 128 + indices_2[None, :] * 1), None) + load_1 = tl.load(y + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None) + mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) + v_0 = tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + acc_0 = tl.reshape(acc, [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 2]).permute(0, 2, 1) + acc_1, acc_2 = tl.split(acc_0) + v_2 = tl.cast(acc_1, tl.bfloat16) + v_3 = tl.cast(acc_2, tl.bfloat16) + indices_n_0 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + indices_n_1 = (offset_1 + _BLOCK_SIZE_1 // 2 + tl.arange(0, _BLOCK_SIZE_1 // 2)).to(tl.int32) + tl.store(out + (indices_0[:, None] * 128 + indices_n_0[None, :] * 1), v_2, None) + tl.store(out + (indices_0[:, None] * 128 + indices_n_1[None, :] * 1), v_3, None) + +def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # 
src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... + _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile_tensor_descriptor_subtile_0) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +# src[test_matmul.py:N]: def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: +# src[test_matmul.py:N]: m, k = x.size() +# src[test_matmul.py:N]: k2, n = y.size() +# src[test_matmul.py:N-N]: ... +helion.runtime.set_triton_allocator() + +@triton.jit +def _helion_matmul_static_shapes(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + x_desc = tl.make_tensor_descriptor(x, [128, 128], [128, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_2]) + y_desc = tl.make_tensor_descriptor(y, [128, 128], [128, 1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1]) + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + out_desc = tl.make_tensor_descriptor(out, [128, 128], [128, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(128, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(128, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = x_desc.load([offset_0, offset_2]) + load_1 = y_desc.load([offset_2, offset_1]) + mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) + v_0 = tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + v_2 = tl.cast(acc, tl.bfloat16) + out_desc.store([offset_0, offset_1], v_2) + +def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], 
dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... + _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_matmul.py:N]: return out + return out + +--- assertExpectedJournal(TestMatmul.test_matmul_epilogue_subtile_tensor_descriptor_subtile_2) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +# src[test_matmul.py:N]: def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: +# src[test_matmul.py:N]: m, k = x.size() +# src[test_matmul.py:N]: k2, n = y.size() +# src[test_matmul.py:N-N]: ... +helion.runtime.set_triton_allocator() + +@triton.jit +def _helion_matmul_static_shapes(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + x_desc = tl.make_tensor_descriptor(x, [128, 128], [128, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_2]) + y_desc = tl.make_tensor_descriptor(y, [128, 128], [128, 1], [_BLOCK_SIZE_2, _BLOCK_SIZE_1]) + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + out_desc = tl.make_tensor_descriptor(out, [128, 128], [128, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) + out_desc_1 = tl.make_tensor_descriptor(out, [128, 128], [128, 1], [_BLOCK_SIZE_0, _BLOCK_SIZE_1 // 2]) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + num_pid_m = tl.cdiv(128, _BLOCK_SIZE_0) + num_pid_n = tl.cdiv(128, _BLOCK_SIZE_1) + inner_2d_pid = tl.program_id(0) + num_pid_in_group = 4 * num_pid_n + group_id = inner_2d_pid // num_pid_in_group + first_pid_m = group_id * 4 + group_size_m = min(num_pid_m - first_pid_m, 4) + pid_0 = first_pid_m + inner_2d_pid % num_pid_in_group % group_size_m + pid_1 = inner_2d_pid % num_pid_in_group // group_size_m + offset_0 = pid_0 * _BLOCK_SIZE_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + acc_copy = acc + acc_copy_0 = acc_copy + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + load = x_desc.load([offset_0, offset_2]) + load_1 = y_desc.load([offset_2, offset_1]) + mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16) + v_0 = tl.cast(mm, tl.float32) + acc = acc_copy_0 + v_0 + # src[test_matmul.py:N]: out[tile_m, tile_n] = acc + acc_0 = tl.reshape(acc, [_BLOCK_SIZE_0, 2, _BLOCK_SIZE_1 // 
2]).permute(0, 2, 1) + acc_1, acc_2 = tl.split(acc_0) + v_2 = tl.cast(acc_1, tl.bfloat16) + v_3 = tl.cast(acc_2, tl.bfloat16) + out_desc_1.store([offset_0, offset_1], v_2) + out_desc_1.store([offset_0, offset_1 + _BLOCK_SIZE_1 // 2], v_3) + +def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + # src[test_matmul.py:N]: m, k = x.size() + m, k = x.size() + # src[test_matmul.py:N]: k2, n = y.size() + k2, n = y.size() + # src[test_matmul.py:N]: assert k == k2, f"size mismatch {k} != {k2}" + assert k == k2, f'size mismatch {k} != {k2}' + # src[test_matmul.py:N]: out = torch.empty( + # src[test_matmul.py:N]: [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + # src[test_matmul.py:N]: ) + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N]: acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + _BLOCK_SIZE_2 = 32 + # src[test_matmul.py:N]: for tile_m, tile_n in hl.tile([m, n]): + # src[test_matmul.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + # src[test_matmul.py:N]: for tile_k in hl.tile(k): + # src[test_matmul.py:N-N]: ... + _launcher(_helion_matmul_static_shapes, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1) + # src[test_matmul.py:N]: return out + return out + --- assertExpectedJournal(TestMatmul.test_matmul_packed_rhs) from __future__ import annotations diff --git a/test/test_matmul.py b/test/test_matmul.py index fa76013e4..4fd36581b 100644 --- a/test/test_matmul.py +++ b/test/test_matmul.py @@ -5,6 +5,8 @@ from unittest.mock import patch import torch +from torch.testing._internal.common_utils import instantiate_parametrized_tests +from torch.testing._internal.common_utils import parametrize import helion from helion import Config @@ -56,20 +58,23 @@ def matmul_without_addmm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out -@helion.kernel(static_shapes=True) -def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - m, k = x.size() - k2, n = y.size() - assert k == k2, f"size mismatch {k} != {k2}" - out = torch.empty( - [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device - ) - for tile_m, tile_n in hl.tile([m, n]): - acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) - for tile_k in hl.tile(k): - acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) - out[tile_m, tile_n] = acc - return out +def matmul_static_shapes_wrapper(epilogue_subtiling: bool = False): + @helion.kernel(static_shapes=True, allow_epilogue_subtiling=epilogue_subtiling) + def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + k2, n = y.size() + assert k == k2, f"size mismatch {k} != {k2}" + out = torch.empty( + [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device + ) + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc + return out + + return matmul_static_shapes class TestMatmul(RefEagerTestBase, TestCase): @@ -153,7 +158,7 @@ def test_matmul_static_shapes0(self): torch.randn([128, 128], device=DEVICE, dtype=torch.float32), ) code, output = 
code_and_output( - matmul_static_shapes, + matmul_static_shapes_wrapper(), args, block_sizes=[16, 16, 16], l2_grouping=4, @@ -168,7 +173,7 @@ def test_matmul_static_shapes1(self): torch.randn([128, 128], device=DEVICE, dtype=torch.float32), ) code, output = code_and_output( - matmul_static_shapes, + matmul_static_shapes_wrapper(), args, block_sizes=[16, 16, 16], l2_grouping=4, @@ -182,7 +187,7 @@ def test_matmul_static_shapes2(self): torch.randn([127, 128], device=DEVICE, dtype=torch.float32), ) code, output = code_and_output( - matmul_static_shapes, + matmul_static_shapes_wrapper(), args, block_sizes=[16, 16, 16], l2_grouping=4, @@ -196,7 +201,7 @@ def test_matmul_static_shapes3(self): torch.randn([128, 127], device=DEVICE, dtype=torch.float32), ) code, output = code_and_output( - matmul_static_shapes, + matmul_static_shapes_wrapper(), args, block_sizes=[16, 16, 16], l2_grouping=4, @@ -341,6 +346,50 @@ def matmul_with_packed_b( torch.testing.assert_close(C, expected, atol=5e-2, rtol=1e-3) self.assertExpectedJournal(code) + @unittest.skipIf( + DEVICE.type != "cuda" or torch.cuda.get_device_capability() < (10, 0), + "Epilogue Subtiling requires CUDA compute capability >= 10.0", + ) + @parametrize("subtile", (0, 2)) + def test_matmul_epilogue_subtile_tensor_descriptor(self, subtile: int): + args = ( + torch.randn([128, 128], device=DEVICE, dtype=torch.bfloat16), + torch.randn([128, 128], device=DEVICE, dtype=torch.bfloat16), + ) + code, output = code_and_output( + matmul_static_shapes_wrapper(epilogue_subtiling=True), + args, + block_sizes=[32, 32, 32], + l2_grouping=4, + indexing="tensor_descriptor", + epilogue_subtiling=[subtile], + ) + torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2) + self.assertExpectedJournal(code) + + @unittest.skipIf( + DEVICE.type != "cuda" or torch.cuda.get_device_capability() < (10, 0), + "Epilogue Subtiling requires CUDA compute capability >= 10.0", + ) + @parametrize("subtile", (0, 2)) + def test_matmul_epilogue_subtile_pointer_mask(self, subtile: int): + args = ( + torch.randn([130, 130], device=DEVICE, dtype=torch.bfloat16), + torch.randn([130, 130], device=DEVICE, dtype=torch.bfloat16), + ) + code, output = code_and_output( + matmul_static_shapes_wrapper(epilogue_subtiling=True), + args, + block_sizes=[32, 32, 32], + l2_grouping=4, + indexing="pointer", + epilogue_subtiling=[subtile], + ) + torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2) + self.assertExpectedJournal(code) + + +instantiate_parametrized_tests(TestMatmul) if __name__ == "__main__": unittest.main() diff --git a/test/test_register_tunable.expected b/test/test_register_tunable.expected index a71a84290..e31b9ba26 100644 --- a/test/test_register_tunable.expected +++ b/test/test_register_tunable.expected @@ -2,7 +2,7 @@ This file is automatically generated by assertExpectedJournal calls in test_regi Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. 
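A hedged usage sketch (assumptions: unspecified Config fields fall back to defaults, and the kernel body mirrors matmul_static_shapes from the diff above) of pinning the new epilogue_subtiling knob in a Config instead of leaving it to the autotuner; the key names are taken from the Config reprs printed in these journals:

    import torch
    import helion
    import helion.language as hl

    @helion.kernel(
        static_shapes=True,
        allow_epilogue_subtiling=True,
        config=helion.Config(
            block_sizes=[32, 32, 32],
            # In these journals, 2 splits the epilogue store into two column
            # halves, while 0 keeps the single full-width store.
            epilogue_subtiling=[2],
            num_warps=4,
            num_stages=1,
        ),
    )
    def matmul_pinned(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        m, k = x.size()
        k2, n = y.size()
        assert k == k2, f"size mismatch {k} != {k2}"
        out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
        for tile_m, tile_n in hl.tile([m, n]):
            acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n])
            out[tile_m, tile_n] = acc
        return out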
 --- assertExpectedJournal(TestRegisterTunable.test_integer_fragment)
-helion.Config(block_sizes=[32], indexing=['pointer', 'pointer'], load_eviction_policies=[''], multiplier=3, num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[])
+helion.Config(block_sizes=[32], epilogue_subtiling=[], indexing=['pointer', 'pointer'], load_eviction_policies=[''], multiplier=3, num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[])
 
 --- assertExpectedJournal(TestRegisterTunable.test_integer_fragment)
 from __future__ import annotations
diff --git a/test/test_type_propagation.expected b/test/test_type_propagation.expected
index dfe5229cb..a6672efa4 100644
--- a/test/test_type_propagation.expected
+++ b/test/test_type_propagation.expected
@@ -1025,35 +1025,35 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]]
 Tensor: Resulting matrix of shape [m, n].
 """
 m, k =
-# Call: SequenceType((LiteralType(512), LiteralType(512))) SourceOrigin(location=)
+# Call: SequenceType((LiteralType(512), LiteralType(512))) SourceOrigin(location=)
 # Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size')
 # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='x')
 x.size()
 k2, n =
-# Call: SequenceType((LiteralType(512), LiteralType(512))) SourceOrigin(location=)
+# Call: SequenceType((LiteralType(512), LiteralType(512))) SourceOrigin(location=)
 # Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='y'), key='size')
 # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='y')
 y.size()
 assert
-# Compare: LiteralType(True) SourceOrigin(location=)
-# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1)
+# Compare: LiteralType(True) SourceOrigin(location=)
+# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1)
 k ==
-# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0)
+# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0)
 k2,
-# JoinedStr: str SourceOrigin(location=)
+# JoinedStr: str SourceOrigin(location=)
 f'size mismatch {k} != {k2}'
 out =
-# Call: TensorType([512, 512], torch.float32) SourceOrigin(location=)
+# Call: TensorType([512, 512], torch.float32) SourceOrigin(location=)
 # Attribute: CallableType(_VariableFunctionsClass.empty) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty')
 # Name: PythonModuleType(torch) GlobalOrigin(name='torch')
 torch.empty(
-# List: SequenceType([LiteralType(512), LiteralType(512)]) SourceOrigin(location=)
+# List: SequenceType([LiteralType(512), LiteralType(512)]) SourceOrigin(location=)
 [
-# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0)
+# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0)
 m,
-# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1)
+# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1)
 n], dtype=
-# Call: LiteralType(torch.float32) SourceOrigin(location=)
+# Call: LiteralType(torch.float32) SourceOrigin(location=)
 # Attribute: CallableType(_VariableFunctionsClass.promote_types) AttributeOrigin(value=GlobalOrigin(name='torch'), key='promote_types')
 # Name: PythonModuleType(torch) GlobalOrigin(name='torch')
 torch.promote_types(
@@ -1069,26 +1069,26 @@
 x.device)
 # For: loop_type=GRID
 for tile_m, tile_n in
-# Call: IterType(SequenceType((TileIndexType(0), TileIndexType(1)))) SourceOrigin(location=)
+# Call: IterType(SequenceType((TileIndexType(0), TileIndexType(1)))) SourceOrigin(location=)
 # Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
 # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
 hl.tile(
-# List: SequenceType([LiteralType(512), LiteralType(512)]) SourceOrigin(location=)
+# List: SequenceType([LiteralType(512), LiteralType(512)]) SourceOrigin(location=)
 [
-# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0)
+# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=0)
 m,
-# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1)
+# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1)
 n]):
 acc =
-# Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
+# Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
 # Attribute: CallableType(zeros) AttributeOrigin(value=GlobalOrigin(name='hl'), key='zeros')
 # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
 hl.zeros(
-# List: SequenceType([TileIndexType(0), TileIndexType(1)]) DeviceOrigin(location=)
+# List: SequenceType([TileIndexType(0), TileIndexType(1)]) DeviceOrigin(location=)
 [
-# Name: TileIndexType(0) SourceOrigin(location=)
+# Name: TileIndexType(0) SourceOrigin(location=)
 tile_m,
-# Name: TileIndexType(1) SourceOrigin(location=)
+# Name: TileIndexType(1) SourceOrigin(location=)
 tile_n], dtype=
 # Attribute: LiteralType(torch.float32) AttributeOrigin(value=GlobalOrigin(name='torch'), key='float32')
 # Name: PythonModuleType(torch) GlobalOrigin(name='torch')
@@ -1096,61 +1096,61 @@
 torch.float32)
 # For: loop_type=DEVICE
 for tile_k in
-# Call: IterType(TileIndexType(2)) DeviceOrigin(location=)
+# Call: IterType(TileIndexType(2)) DeviceOrigin(location=)
 # Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
 # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
 hl.tile(
-# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1)
+# Name: LiteralType(512) GetItemOrigin(value=SourceOrigin(location=), key=1)
 k):
 acc =
-# Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
+# Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
 # Attribute: CallableType(_VariableFunctionsClass.addmm) AttributeOrigin(value=GlobalOrigin(name='torch'), key='addmm')
 # Name: PythonModuleType(torch) GlobalOrigin(name='torch')
 torch.addmm(
-# Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
+# Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
 acc,
-# Subscript: TensorType([block_size_0, block_size_2], torch.float32) DeviceOrigin(location=)
+# Subscript: TensorType([block_size_0, block_size_2], torch.float32) DeviceOrigin(location=)
 # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='x')
 x[
-# Name: TileIndexType(0) SourceOrigin(location=)
+# Name: TileIndexType(0) SourceOrigin(location=)
 tile_m,
-# Name: TileIndexType(2) DeviceOrigin(location=)
+# Name: TileIndexType(2) DeviceOrigin(location=)
 tile_k],
-# Subscript: TensorType([block_size_2, block_size_1], torch.float32) DeviceOrigin(location=)
+# Subscript: TensorType([block_size_2, block_size_1], torch.float32) DeviceOrigin(location=)
 # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='y')
 y[
-# Name: TileIndexType(2) DeviceOrigin(location=)
+# Name: TileIndexType(2) DeviceOrigin(location=)
 tile_k,
-# Name: TileIndexType(1) SourceOrigin(location=)
+# Name: TileIndexType(1) SourceOrigin(location=)
 tile_n])
-# Subscript: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
-# Name: TensorType([512, 512], torch.float32) SourceOrigin(location=)
+# Subscript: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
+# Name: TensorType([512, 512], torch.float32) SourceOrigin(location=)
 out[
-# Name: TileIndexType(0) SourceOrigin(location=)
+# Name: TileIndexType(0) SourceOrigin(location=)
 tile_m,
-# Name: TileIndexType(1) SourceOrigin(location=)
+# Name: TileIndexType(1) SourceOrigin(location=)
 tile_n] =
-# Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
+# Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
 # Name: CallableType() ArgumentOrigin(name='epilogue')
 epilogue(
-# Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
+# Name: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=)
 acc,
-# Tuple: SequenceType((TileIndexType(0), TileIndexType(1))) DeviceOrigin(location=)
+# Tuple: SequenceType((TileIndexType(0), TileIndexType(1))) DeviceOrigin(location=)
 (
-# Name: TileIndexType(0) SourceOrigin(location=)
+# Name: TileIndexType(0) SourceOrigin(location=)
 tile_m,
-# Name: TileIndexType(1) SourceOrigin(location=)
+# Name: TileIndexType(1) SourceOrigin(location=)
 tile_n))
 return
-# Name: TensorType([512, 512], torch.float32) SourceOrigin(location=)
+# Name: TensorType([512, 512], torch.float32) SourceOrigin(location=)
 out
 def for_loop_0(arg0_1: "f32[u0, u1]"):
-    # File: .../matmul.py:N in matmul, code: for tile_k in hl.tile(k):
+    # File: .../matmul.py:62 in matmul, code: for tile_k in hl.tile(k):
     _new_var: "f32[u0, u1]" = helion_language__tracing_ops__new_var(arg0_1)
-    # File: .../matmul.py:N in matmul, code: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+    # File: .../matmul.py:63 in matmul, code: acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
     x: "f32[512, 512]" = helion_language__tracing_ops__host_tensor('x')
     sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(arg0_1, 0)
     block_size_2: "Sym(u2)" = helion_language__tracing_ops__get_symnode('block_size_2')
@@ -1162,17 +1162,17 @@ def for_loop_0(arg0_1: "f32[u0, u1]"):
     return [acc]
 def root_graph_1():
-    # File: .../matmul.py:N in matmul, code: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+    # File: .../matmul.py:61 in matmul, code: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
     block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
     block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1')
     acc: "f32[u0, u1]" = helion_language_creation_ops_full([block_size_0, block_size_1], 0.0, torch.float32, None)
-    # File: .../matmul.py:N in matmul, code: for tile_k in hl.tile(k):
+    # File: .../matmul.py:62 in matmul, code: for tile_k in hl.tile(k):
     _for_loop = helion_language__tracing_ops__for_loop(0, [0], [512], [acc])
     getitem: "f32[u0, u1]" = _for_loop[0]; _for_loop = None
     _phi: "f32[u0, u1]" = helion_language__tracing_ops__phi(acc, getitem); acc = getitem = None
-    # File: .../matmul.py:N in matmul, code: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n))
+    # File: .../matmul.py:64 in matmul, code: out[tile_m, tile_n] = epilogue(acc, (tile_m, tile_n))
     out: "f32[512, 512]" = helion_language__tracing_ops__host_tensor('out')
     store = helion_language_memory_ops_store(out, [block_size_0, block_size_1], _phi, None); out = block_size_0 = block_size_1 = _phi = store = None
     return None
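
The snippet below is an illustrative sketch, not part of the patch: it mirrors the new parametrized tests above to show how the epilogue_subtiling knob is exercised end to end. It assumes the test-local matmul_static_shapes_wrapper helper and the code_and_output/DEVICE utilities from helion._testing are importable, and it needs a GPU with compute capability >= 10.0, matching the skip condition in the tests.

    import torch

    from helion._testing import DEVICE, code_and_output
    # matmul_static_shapes_wrapper is the test-local helper used above; this
    # import path is hypothetical and depends on how the test module is laid out.
    from test_matmul import matmul_static_shapes_wrapper

    args = (
        torch.randn([128, 128], device=DEVICE, dtype=torch.bfloat16),
        torch.randn([128, 128], device=DEVICE, dtype=torch.bfloat16),
    )
    # epilogue_subtiling=[2] requests the subtiled epilogue for the kernel's
    # single store; [0] is the other parametrized case and appears to leave the
    # store unsplit.
    code, output = code_and_output(
        matmul_static_shapes_wrapper(epilogue_subtiling=True),
        args,
        block_sizes=[32, 32, 32],
        l2_grouping=4,
        indexing="tensor_descriptor",
        epilogue_subtiling=[2],
    )
    torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2)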