diff --git a/src/xtc/backends/mlir/MlirCompilerPasses.py b/src/xtc/backends/mlir/MlirCompilerPasses.py index fee186bd..fe6120de 100644 --- a/src/xtc/backends/mlir/MlirCompilerPasses.py +++ b/src/xtc/backends/mlir/MlirCompilerPasses.py @@ -34,6 +34,7 @@ sdist_transform = None pass +from xtc.itf.schd.scheduler import ROOT_SEP from xtc.utils.ext_tools import transform_opts from .MlirProgram import RawMlirProgram @@ -293,7 +294,11 @@ def _generate_node_scheduling( schedule=schedule, root=loop_name, sched_state=sched_state ) continue - + axis_split = split_state.loop_dim_by_split.get(root) + if axis_split is not None and not ( + schedule.is_base(loop_name) or schedule.is_tile(loop_name) + ): + loop_name = root + ROOT_SEP + axis_split # Bufferization if loop_name in schedule.distributed_buffers.keys(): self._distribute_buffer( @@ -340,6 +345,7 @@ def _generate_tiling_insns( state_of_tiling: dict[str, int] = {dim: 1 for dim in schedule.dims} candidate_state_of_tiling = state_of_tiling.copy() previous_root = "" + split_state = SplitState(schedule.splits, previous_root) for loc_root, permutation in reversed(schedule.permutation.items()): if len(loc_root) == len(previous_root): # Reset the view on the state of tiling (we are jumping into @@ -348,11 +354,14 @@ def _generate_tiling_insns( else: # Update the state of tiling state_of_tiling = candidate_state_of_tiling.copy() - for loop in reversed(permutation): # The loop needs to be base or tile if not (schedule.is_tile(loop) or schedule.is_base(loop)): - continue + axis_split = split_state.loop_dim_by_split.get(loc_root) + if axis_split is not None: + loop = loc_root + ROOT_SEP + axis_split + else: + continue # Fetch the dimension knowledge dim_of_loop = schedule.dim_of_tile(loop) @@ -472,7 +481,7 @@ def _unroll( for dim_name in reversed(permutation): if ( dim_name in schedule.unrolling - and not dim_name in schedule.vectorization + and dim_name not in schedule.vectorization ): assert self._named_sequence is not None loop_unroll( diff --git a/src/xtc/backends/mlir/MlirNodeScheduler.py b/src/xtc/backends/mlir/MlirNodeScheduler.py index 27f43a42..c991bc0c 100644 --- a/src/xtc/backends/mlir/MlirNodeScheduler.py +++ b/src/xtc/backends/mlir/MlirNodeScheduler.py @@ -5,15 +5,13 @@ from typing_extensions import override from dataclasses import dataclass, asdict from pprint import pformat -from xtc.itf.schd.scheduler import DEFAULT_ROOT +from xtc.itf.schd.scheduler import DEFAULT_ROOT, ROOT_SEP __all__ = [ "MlirNodeScheduler", "MlirNodeSchedule", ] -ROOT_SEP = "/" - def basename(loop_name: str) -> str: return loop_name.split(ROOT_SEP)[-1] diff --git a/src/xtc/itf/schd/scheduler.py b/src/xtc/itf/schd/scheduler.py index 40033b37..3ea1febb 100644 --- a/src/xtc/itf/schd/scheduler.py +++ b/src/xtc/itf/schd/scheduler.py @@ -7,6 +7,9 @@ import xtc.itf DEFAULT_ROOT = "." +ROOT_SEP = "/" +SPLIT_LEFT_SEP = "[" +SPLIT_RIGHT_SEP = "]" class Scheduler(ABC): diff --git a/src/xtc/schedules/descript.py b/src/xtc/schedules/descript.py index 9433b98a..63ee7052 100644 --- a/src/xtc/schedules/descript.py +++ b/src/xtc/schedules/descript.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field import re from typing_extensions import override -from xtc.itf.schd.scheduler import Scheduler +from xtc.itf.schd.scheduler import Scheduler, ROOT_SEP, SPLIT_LEFT_SEP, SPLIT_RIGHT_SEP class ScheduleParseError(RuntimeError): @@ -60,7 +60,7 @@ class SplitDecl: def __str__(self) -> str: start_str = "" if self.start is None else str(self.start) end_str = "" if self.end is None else str(self.end) - decl = f"{self.axis}[{start_str}:{end_str}]" + decl = f"{self.axis}{SPLIT_LEFT_SEP}{start_str}:{end_str}{SPLIT_RIGHT_SEP}" return decl @@ -221,14 +221,14 @@ class ScheduleInterpreter: def __init__(self, abstract_axis: list[str]): self.abstract_axis = abstract_axis + self.root_to_dim: dict[str, str] = {} + self.dim_to_axis: dict[str, str] = {} def interpret(self, spec: ScheduleSpec, root: str) -> LoopNest: """Interpret a schedule specification into a LoopNest.""" - return self._interpret_spec(spec, root, head=[]) + return self._interpret_spec(spec, root) - def _interpret_spec( - self, spec: ScheduleSpec, root: str, head: list[str] - ) -> LoopNest: + def _interpret_spec(self, spec: ScheduleSpec, root: str) -> LoopNest: """Interpret a schedule spec recursively.""" loop_nest = LoopNest(abstract_dims=self.abstract_axis) slice = loop_nest.build_slice(root) @@ -236,7 +236,10 @@ def _interpret_spec( # Track state during interpretation sizes: dict[str, int] = {} previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis} - interchange: list[str] = list(head) + interchange: list[str] = [] + # Only the first root is not in root_to_dim + if root in self.root_to_dim: + interchange.append(self.root_to_dim[root]) for item in spec.items: if isinstance(item, SplitDecl): @@ -249,7 +252,6 @@ def _interpret_spec( elif isinstance(item, AxisDecl): loop_name = self._interpret_axis(item, interchange) self._apply_annotations(item.annotations, loop_name, sizes, slice) - # Check that all splits are complete for axis, cut in previous_cut.items(): if cut is not None and cut != 0: @@ -294,13 +296,14 @@ def _interpret_split( if axis_name not in slice.splits: slice.splits[axis_name] = {} new_dim_index = len(slice.splits[axis_name]) - new_dim_name = f"{axis_name}[{new_dim_index}]" - new_root_name = f"{root}/{new_dim_name}" + new_dim_name = f"{axis_name}{SPLIT_LEFT_SEP}{new_dim_index}{SPLIT_RIGHT_SEP}" + new_root_name = f"{root}{ROOT_SEP}{new_dim_name}" slice.splits[axis_name][new_dim_name] = x interchange.append(new_dim_name) - + self.dim_to_axis[new_dim_name] = axis_name + self.root_to_dim[new_root_name] = new_dim_name # Recursively interpret the nested schedule - inner_nest = self._interpret_spec(item.body, new_root_name, head=[axis_name]) + inner_nest = self._interpret_spec(item.body, new_root_name) loop_nest.slices += inner_nest.slices def _interpret_tile( @@ -321,7 +324,6 @@ def _interpret_tile( slice.tiles[item.axis][loop_name] = item.size sizes[loop_name] = item.size interchange.append(loop_name) - return loop_name def _interpret_axis( @@ -332,13 +334,13 @@ def _interpret_axis( """Interpret a direct axis reference. Returns the loop name.""" axis_name = item.axis self._check_axis_existence(axis_name) - # Unreachable when built from a Python dict (because keys # can't be duplicated). - if axis_name in interchange: - raise ScheduleInterpretError( - f"Axis {axis_name} is scheduled twice (or more)." - ) + for loop_name in interchange: + if self.dim_to_axis.get(loop_name, loop_name) == axis_name: + raise ScheduleInterpretError( + f"Axis {axis_name} is scheduled twice (or more)." + ) interchange.append(axis_name) return axis_name @@ -478,21 +480,23 @@ class LoopNestSlice: root: str tiles: dict[str, dict[str, int]] - splits: dict[str, dict[str, int]] = field(default_factory=dict) + splits: dict[str, dict[str, int | None]] = field(default_factory=dict) interchange: list[str] = field(default_factory=list) vectorize: list[str] = field(default_factory=list) parallelize: list[str] = field(default_factory=list) unroll: dict[str, int] = field(default_factory=dict) @property - def splits_to_sizes(self) -> dict[str, int]: - splits_to_sizes: dict[str, int] = {} + def splits_to_sizes(self) -> dict[str, int | None]: + splits_to_sizes: dict[str, int | None] = {} for axis in self.splits: last_start = None for loop_name, start in reversed(self.splits[axis].items()): - if last_start is not None: + if last_start is not None and start is not None: size_of_split = last_start - start splits_to_sizes[loop_name] = size_of_split + else: + splits_to_sizes[loop_name] = None last_start = start return splits_to_sizes @@ -557,6 +561,8 @@ def _check_tiling_consistency(self) -> None: seen_axes: dict[str, int | None] = {} for sched in self.slices: for loop_name in sched.interchange: + loop_name = mapper.splits_to_axis.get(loop_name, loop_name) + if loop_name in mapper.dims: seen_axes[loop_name] = None elif loop_name in mapper.tiles_to_axis: @@ -575,7 +581,6 @@ def _check_sizes(self): current_size_of_split: dict[str, int | None] = {} for sched in self.slices: current_size_of_tile: dict[str, int] = {} - for loop_name in sched.interchange: axis = mapper.loops_to_axis[loop_name] current_sizes = ( @@ -607,7 +612,9 @@ def _check_sizes(self): loop_name=loop_name, axis=axis, ) - current_size_of_split[axis] = loop_size + current_size_of_split[loop_name] = loop_size + elif loop_name in current_size_of_split: + current_size_of_split[axis] = current_size_of_split[loop_name] if loop_name in sched.unroll: unroll_factor = sched.unroll[loop_name] @@ -618,10 +625,13 @@ def _check_sizes(self): @staticmethod def _must_be_smaller_routine( - new_size: int, current_sizes: dict[str, int | None], loop_name: str, axis: str + new_size: int | None, + current_sizes: dict[str, int | None], + loop_name: str, + axis: str, ): old_size = current_sizes[axis] - if old_size is not None and new_size > old_size: + if old_size is not None and new_size is not None and new_size > old_size: raise ScheduleValidationError( f""" Inner loop {loop_name} on axis {axis} must be smaller than outer loop. @@ -683,10 +693,8 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: # Interpret the AST into a LoopNest interpreter = ScheduleInterpreter(self.abstract_axis) loop_nest = interpreter.interpret(ast, root=node_name) - # Validate the loop nest loop_nest.check() - # Apply the schedule to the scheduler self._apply_loop_nest(loop_nest) diff --git a/tests/filecheck/schedules/test_descript_slice_bigger.py b/tests/filecheck/schedules/test_descript_slice_bigger.py new file mode 100644 index 00000000..61f2090e --- /dev/null +++ b/tests/filecheck/schedules/test_descript_slice_bigger.py @@ -0,0 +1,195 @@ +# RUN: python %s 2>&1 | filecheck %s + +import xtc.graphs.xtc.op as O +from xtc.backends.mlir import Backend +from xtc.schedules.descript import descript_scheduler + +I, J, K, dtype = 50, 64, 64, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph) + +sch = impl.get_scheduler() +descript_scheduler( + scheduler=sch, + node_name="C", + abstract_axis=["i", "j", "k"], + spec={ + 'k': {}, + 'j': {}, + 'i[:32]': { + 'i#32': {}, + 'k#32': {}, + 'j#16': {'vectorize': True}, + }, + 'i[32:]': { + 'i#18': {}, + 'k#32': {}, + 'j#16': {'vectorize': True}, + } + } +) + +comp = impl.get_compiler( + shared_lib=True, + dump_file="matmul_descript_slice_first_bigger", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sch.schedule()) +evaluator = module.get_evaluator( + validate=True, +) +results, code, error = evaluator.evaluate() +print(f"CODE: {code}") +# CHECK: // -----// IR Dump Before transform //----- // +# CHECK-NEXT: module attributes {transform.with_named_sequence} { +# CHECK-NEXT: func.func @matmul(%arg0: memref<50x64xf32> {llvm.noalias}, %arg1: memref<64x64xf32> {llvm.noalias}, %arg2: memref<50x64xf32> {llvm.noalias}) { +# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<50x64xf32>) +# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<50x64xf32>, memref<64x64xf32>) outs(%arg2 : memref<50x64xf32>) +# CHECK-NEXT: return +# CHECK-NEXT: } +# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +# CHECK-NEXT: transform.yield +# CHECK-NEXT: } +# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op +# CHECK-NEXT: %1 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %1 tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_3 "C/k" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_5 "C/j" : !transform.any_op +# CHECK-NEXT: %2 = transform.structured.split %tiled_linalg_op_4 after 32 {dimension = 0 : i64} : !transform.any_op +# CHECK-NEXT: %3:2 = transform.split_handle %2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %3#0 tile_sizes [32, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_7 "C/i[0]/i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/i0" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %tiled_linalg_op_8 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_11 "C/i[0]/k0" : !transform.any_op +# CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_10) : (!transform.any_op) -> () +# CHECK-NEXT: %tiled_linalg_op_12, %loops_13 = transform.structured.tile_using_for %3#1 tile_sizes [18, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_13 "C/i[1]/i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_14, %loops_15 = transform.structured.tile_using_for %tiled_linalg_op_12 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_15 "C/i[1]/i0" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_16, %loops_17 = transform.structured.tile_using_for %tiled_linalg_op_14 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_17 "C/i[1]/k0" : !transform.any_op +# CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_16) : (!transform.any_op) -> () +# CHECK-NEXT: %4 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %4 { +# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %4 { +# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: transform.yield +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: +# CHECK-NEXT: // -----// IR Dump After transform //----- // +# CHECK-NEXT: module attributes {transform.with_named_sequence} { +# CHECK-NEXT: func.func @matmul(%arg0: memref<50x64xf32> {llvm.noalias}, %arg1: memref<64x64xf32> {llvm.noalias}, %arg2: memref<50x64xf32> {llvm.noalias}) { +# CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> +# CHECK-NEXT: %c18 = arith.constant 18 : index +# CHECK-NEXT: %0 = ub.poison : f32 +# CHECK-NEXT: %c16 = arith.constant 16 : index +# CHECK-NEXT: %c32 = arith.constant 32 : index +# CHECK-NEXT: %c64 = arith.constant 64 : index +# CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +# CHECK-NEXT: %c0 = arith.constant 0 : index +# CHECK-NEXT: %c50 = arith.constant 50 : index +# CHECK-NEXT: %c1 = arith.constant 1 : index +# CHECK-NEXT: scf.for %arg3 = %c0 to %c50 step %c1 { +# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 64] [1, 1] : memref<50x64xf32> to memref<1x64xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg4 = %c0 to %c64 step %c1 { +# CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x64xf32, strided<[64, 1], offset: ?>> to memref<1x1xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[64, 1], offset: ?>>) +# CHECK-NEXT: } {"./j"} +# CHECK-NEXT: } {"./i"} +# CHECK-NEXT: scf.for %arg3 = %c0 to %c64 step %c32 { +# CHECK-NEXT: %subview = memref.subview %arg0[0, %arg3] [50, 32] [1, 1] : memref<50x64xf32> to memref<50x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_1 = memref.subview %arg1[%arg3, 0] [32, 64] [1, 1] : memref<64x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_2 = memref.subview %arg2[0, 0] [50, 64] [1, 1] : memref<50x64xf32> to memref<50x64xf32, strided<[64, 1]>> +# CHECK-NEXT: scf.for %arg4 = %c0 to %c64 step %c16 { +# CHECK-NEXT: %subview_3 = memref.subview %subview_1[0, %arg4] [32, 16] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<32x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_4 = memref.subview %subview_2[0, %arg4] [50, 16] [1, 1] : memref<50x64xf32, strided<[64, 1]>> to memref<50x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_5 = memref.subview %subview[0, 0] [32, 32] [1, 1] : memref<50x32xf32, strided<[64, 1], offset: ?>> to memref<32x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_6 = memref.subview %subview_4[0, 0] [32, 16] [1, 1] : memref<50x16xf32, strided<[64, 1], offset: ?>> to memref<32x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg5 = %c0 to %c32 step %c32 { +# CHECK-NEXT: %subview_9 = memref.subview %subview_5[%arg5, 0] [32, 32] [1, 1] : memref<32x32xf32, strided<[64, 1], offset: ?>> to memref<32x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_10 = memref.subview %subview_6[%arg5, 0] [32, 16] [1, 1] : memref<32x16xf32, strided<[64, 1], offset: ?>> to memref<32x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg6 = %c0 to %c32 step %c1 { +# CHECK-NEXT: %subview_11 = memref.subview %subview_9[%arg6, 0] [1, 32] [1, 1] : memref<32x32xf32, strided<[64, 1], offset: ?>> to memref<1x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_12 = memref.subview %subview_10[%arg6, 0] [1, 16] [1, 1] : memref<32x16xf32, strided<[64, 1], offset: ?>> to memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg7 = %c0 to %c32 step %c1 { +# CHECK-NEXT: %subview_13 = memref.subview %subview_11[0, %arg7] [1, 1] [1, 1] : memref<1x32xf32, strided<[64, 1], offset: ?>> to memref<1x1xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_14 = memref.subview %subview_3[%arg7, 0] [1, 16] [1, 1] : memref<32x16xf32, strided<[64, 1], offset: ?>> to memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %1 = vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[64, 1], offset: ?>>, vector<1x1xf32> +# CHECK-NEXT: %2 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[64, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %3 = vector.transfer_read %subview_12[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[64, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %4 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> +# CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<16xf32> +# CHECK-NEXT: %7 = vector.extract %3[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<16xf32> +# CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<16xf32> into vector<1x16xf32> +# CHECK-NEXT: vector.transfer_write %9, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: } {"C/i[0]/k0"} +# CHECK-NEXT: } {"C/i[0]/i0"} +# CHECK-NEXT: } {"C/i[0]/i"} +# CHECK-NEXT: %subview_7 = memref.subview %subview[32, 0] [18, 32] [1, 1] : memref<50x32xf32, strided<[64, 1], offset: ?>> to memref<18x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_8 = memref.subview %subview_4[32, 0] [18, 16] [1, 1] : memref<50x16xf32, strided<[64, 1], offset: ?>> to memref<18x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg5 = %c0 to %c18 step %c18 { +# CHECK-NEXT: %subview_9 = memref.subview %subview_7[%arg5, 0] [18, 32] [1, 1] : memref<18x32xf32, strided<[64, 1], offset: ?>> to memref<18x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_10 = memref.subview %subview_8[%arg5, 0] [18, 16] [1, 1] : memref<18x16xf32, strided<[64, 1], offset: ?>> to memref<18x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg6 = %c0 to %c18 step %c1 { +# CHECK-NEXT: %subview_11 = memref.subview %subview_9[%arg6, 0] [1, 32] [1, 1] : memref<18x32xf32, strided<[64, 1], offset: ?>> to memref<1x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_12 = memref.subview %subview_10[%arg6, 0] [1, 16] [1, 1] : memref<18x16xf32, strided<[64, 1], offset: ?>> to memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg7 = %c0 to %c32 step %c1 { +# CHECK-NEXT: %subview_13 = memref.subview %subview_11[0, %arg7] [1, 1] [1, 1] : memref<1x32xf32, strided<[64, 1], offset: ?>> to memref<1x1xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_14 = memref.subview %subview_3[%arg7, 0] [1, 16] [1, 1] : memref<32x16xf32, strided<[64, 1], offset: ?>> to memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %1 = vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[64, 1], offset: ?>>, vector<1x1xf32> +# CHECK-NEXT: %2 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[64, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %3 = vector.transfer_read %subview_12[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[64, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %4 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> +# CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<16xf32> +# CHECK-NEXT: %7 = vector.extract %3[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<16xf32> +# CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<16xf32> into vector<1x16xf32> +# CHECK-NEXT: vector.transfer_write %9, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: } {"C/i[1]/k0"} +# CHECK-NEXT: } {"C/i[1]/i0"} +# CHECK-NEXT: } {"C/i[1]/i"} +# CHECK-NEXT: } {"C/j"} +# CHECK-NEXT: } {"C/k"} +# CHECK-NEXT: return +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: +# CHECK-NEXT: graph: +# CHECK-NEXT: name: matmul +# CHECK-NEXT: inputs: +# CHECK-NEXT: - %0 : 50x64xfloat32 +# CHECK-NEXT: - %1 : 64x64xfloat32 +# CHECK-NEXT: outputs: +# CHECK-NEXT: - %2 : 50x64xfloat32 +# CHECK-NEXT: nodes: +# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [50x64xfloat32, 64x64xfloat32] -> [50x64xfloat32] +# CHECK-NEXT: +# CHECK-NEXT: CODE: 0 diff --git a/tests/filecheck/schedules/test_descript_slice_smaller.py b/tests/filecheck/schedules/test_descript_slice_smaller.py new file mode 100644 index 00000000..bdf11093 --- /dev/null +++ b/tests/filecheck/schedules/test_descript_slice_smaller.py @@ -0,0 +1,195 @@ +# RUN: python %s 2>&1 | filecheck %s + +import xtc.graphs.xtc.op as O +from xtc.backends.mlir import Backend +from xtc.schedules.descript import descript_scheduler + +I, J, K, dtype = 50, 64, 64, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph) + +sch = impl.get_scheduler() +descript_scheduler( + scheduler=sch, + node_name="C", + abstract_axis=["i", "j", "k"], + spec={ + 'k': {}, + 'j': {}, + 'i[:18]': { + 'i#18': {}, + 'k#32': {}, + 'j#16': {'vectorize': True}, + }, + 'i[18:]': { + 'i#32': {}, + 'k#32': {}, + 'j#16': {'vectorize': True}, + } + } +) + +comp = impl.get_compiler( + shared_lib=True, + dump_file="matmul_descript_slice_first_smaller", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sch.schedule()) +evaluator = module.get_evaluator( + validate=True, +) +results, code, error = evaluator.evaluate() +print(f"CODE: {code}") +# CHECK: // -----// IR Dump Before transform //----- // +# CHECK-NEXT: module attributes {transform.with_named_sequence} { +# CHECK-NEXT: func.func @matmul(%arg0: memref<50x64xf32> {llvm.noalias}, %arg1: memref<64x64xf32> {llvm.noalias}, %arg2: memref<50x64xf32> {llvm.noalias}) { +# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<50x64xf32>) +# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<50x64xf32>, memref<64x64xf32>) outs(%arg2 : memref<50x64xf32>) +# CHECK-NEXT: return +# CHECK-NEXT: } +# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +# CHECK-NEXT: transform.yield +# CHECK-NEXT: } +# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op +# CHECK-NEXT: %1 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %1 tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_3 "C/k" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_5 "C/j" : !transform.any_op +# CHECK-NEXT: %2 = transform.structured.split %tiled_linalg_op_4 after 18 {dimension = 0 : i64} : !transform.any_op +# CHECK-NEXT: %3:2 = transform.split_handle %2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %3#0 tile_sizes [18, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_7 "C/i[0]/i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/i0" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %tiled_linalg_op_8 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_11 "C/i[0]/k0" : !transform.any_op +# CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_10) : (!transform.any_op) -> () +# CHECK-NEXT: %tiled_linalg_op_12, %loops_13 = transform.structured.tile_using_for %3#1 tile_sizes [32, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_13 "C/i[1]/i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_14, %loops_15 = transform.structured.tile_using_for %tiled_linalg_op_12 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_15 "C/i[1]/i0" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_16, %loops_17 = transform.structured.tile_using_for %tiled_linalg_op_14 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_17 "C/i[1]/k0" : !transform.any_op +# CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_16) : (!transform.any_op) -> () +# CHECK-NEXT: %4 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %4 { +# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %4 { +# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: transform.yield +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: +# CHECK-NEXT: // -----// IR Dump After transform //----- // +# CHECK-NEXT: module attributes {transform.with_named_sequence} { +# CHECK-NEXT: func.func @matmul(%arg0: memref<50x64xf32> {llvm.noalias}, %arg1: memref<64x64xf32> {llvm.noalias}, %arg2: memref<50x64xf32> {llvm.noalias}) { +# CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> +# CHECK-NEXT: %0 = ub.poison : f32 +# CHECK-NEXT: %c18 = arith.constant 18 : index +# CHECK-NEXT: %c16 = arith.constant 16 : index +# CHECK-NEXT: %c32 = arith.constant 32 : index +# CHECK-NEXT: %c64 = arith.constant 64 : index +# CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +# CHECK-NEXT: %c0 = arith.constant 0 : index +# CHECK-NEXT: %c50 = arith.constant 50 : index +# CHECK-NEXT: %c1 = arith.constant 1 : index +# CHECK-NEXT: scf.for %arg3 = %c0 to %c50 step %c1 { +# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 64] [1, 1] : memref<50x64xf32> to memref<1x64xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg4 = %c0 to %c64 step %c1 { +# CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x64xf32, strided<[64, 1], offset: ?>> to memref<1x1xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[64, 1], offset: ?>>) +# CHECK-NEXT: } {"./j"} +# CHECK-NEXT: } {"./i"} +# CHECK-NEXT: scf.for %arg3 = %c0 to %c64 step %c32 { +# CHECK-NEXT: %subview = memref.subview %arg0[0, %arg3] [50, 32] [1, 1] : memref<50x64xf32> to memref<50x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_1 = memref.subview %arg1[%arg3, 0] [32, 64] [1, 1] : memref<64x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_2 = memref.subview %arg2[0, 0] [50, 64] [1, 1] : memref<50x64xf32> to memref<50x64xf32, strided<[64, 1]>> +# CHECK-NEXT: scf.for %arg4 = %c0 to %c64 step %c16 { +# CHECK-NEXT: %subview_3 = memref.subview %subview_1[0, %arg4] [32, 16] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<32x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_4 = memref.subview %subview_2[0, %arg4] [50, 16] [1, 1] : memref<50x64xf32, strided<[64, 1]>> to memref<50x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_5 = memref.subview %subview[0, 0] [18, 32] [1, 1] : memref<50x32xf32, strided<[64, 1], offset: ?>> to memref<18x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_6 = memref.subview %subview_4[0, 0] [18, 16] [1, 1] : memref<50x16xf32, strided<[64, 1], offset: ?>> to memref<18x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg5 = %c0 to %c18 step %c18 { +# CHECK-NEXT: %subview_9 = memref.subview %subview_5[%arg5, 0] [18, 32] [1, 1] : memref<18x32xf32, strided<[64, 1], offset: ?>> to memref<18x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_10 = memref.subview %subview_6[%arg5, 0] [18, 16] [1, 1] : memref<18x16xf32, strided<[64, 1], offset: ?>> to memref<18x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg6 = %c0 to %c18 step %c1 { +# CHECK-NEXT: %subview_11 = memref.subview %subview_9[%arg6, 0] [1, 32] [1, 1] : memref<18x32xf32, strided<[64, 1], offset: ?>> to memref<1x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_12 = memref.subview %subview_10[%arg6, 0] [1, 16] [1, 1] : memref<18x16xf32, strided<[64, 1], offset: ?>> to memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg7 = %c0 to %c32 step %c1 { +# CHECK-NEXT: %subview_13 = memref.subview %subview_11[0, %arg7] [1, 1] [1, 1] : memref<1x32xf32, strided<[64, 1], offset: ?>> to memref<1x1xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_14 = memref.subview %subview_3[%arg7, 0] [1, 16] [1, 1] : memref<32x16xf32, strided<[64, 1], offset: ?>> to memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %1 = vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[64, 1], offset: ?>>, vector<1x1xf32> +# CHECK-NEXT: %2 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[64, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %3 = vector.transfer_read %subview_12[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[64, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %4 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> +# CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<16xf32> +# CHECK-NEXT: %7 = vector.extract %3[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<16xf32> +# CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<16xf32> into vector<1x16xf32> +# CHECK-NEXT: vector.transfer_write %9, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: } {"C/i[0]/k0"} +# CHECK-NEXT: } {"C/i[0]/i0"} +# CHECK-NEXT: } {"C/i[0]/i"} +# CHECK-NEXT: %subview_7 = memref.subview %subview[18, 0] [32, 32] [1, 1] : memref<50x32xf32, strided<[64, 1], offset: ?>> to memref<32x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_8 = memref.subview %subview_4[18, 0] [32, 16] [1, 1] : memref<50x16xf32, strided<[64, 1], offset: ?>> to memref<32x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg5 = %c0 to %c32 step %c32 { +# CHECK-NEXT: %subview_9 = memref.subview %subview_7[%arg5, 0] [32, 32] [1, 1] : memref<32x32xf32, strided<[64, 1], offset: ?>> to memref<32x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_10 = memref.subview %subview_8[%arg5, 0] [32, 16] [1, 1] : memref<32x16xf32, strided<[64, 1], offset: ?>> to memref<32x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg6 = %c0 to %c32 step %c1 { +# CHECK-NEXT: %subview_11 = memref.subview %subview_9[%arg6, 0] [1, 32] [1, 1] : memref<32x32xf32, strided<[64, 1], offset: ?>> to memref<1x32xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_12 = memref.subview %subview_10[%arg6, 0] [1, 16] [1, 1] : memref<32x16xf32, strided<[64, 1], offset: ?>> to memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg7 = %c0 to %c32 step %c1 { +# CHECK-NEXT: %subview_13 = memref.subview %subview_11[0, %arg7] [1, 1] [1, 1] : memref<1x32xf32, strided<[64, 1], offset: ?>> to memref<1x1xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %subview_14 = memref.subview %subview_3[%arg7, 0] [1, 16] [1, 1] : memref<32x16xf32, strided<[64, 1], offset: ?>> to memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: %1 = vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[64, 1], offset: ?>>, vector<1x1xf32> +# CHECK-NEXT: %2 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[64, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %3 = vector.transfer_read %subview_12[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[64, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %4 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> +# CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<16xf32> +# CHECK-NEXT: %7 = vector.extract %3[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<16xf32> +# CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<16xf32> into vector<1x16xf32> +# CHECK-NEXT: vector.transfer_write %9, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[64, 1], offset: ?>> +# CHECK-NEXT: } {"C/i[1]/k0"} +# CHECK-NEXT: } {"C/i[1]/i0"} +# CHECK-NEXT: } {"C/i[1]/i"} +# CHECK-NEXT: } {"C/j"} +# CHECK-NEXT: } {"C/k"} +# CHECK-NEXT: return +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: +# CHECK-NEXT: graph: +# CHECK-NEXT: name: matmul +# CHECK-NEXT: inputs: +# CHECK-NEXT: - %0 : 50x64xfloat32 +# CHECK-NEXT: - %1 : 64x64xfloat32 +# CHECK-NEXT: outputs: +# CHECK-NEXT: - %2 : 50x64xfloat32 +# CHECK-NEXT: nodes: +# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [50x64xfloat32, 64x64xfloat32] -> [50x64xfloat32] +# CHECK-NEXT: +# CHECK-NEXT: CODE: 0