From 867caedcb498b2c3b760b15c4221dd914057a806 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Thu, 9 Oct 2025 10:42:25 +0200 Subject: [PATCH 01/23] Descript extension --- .gitmodules | 3 + src/sampler | 1 + src/xtc/__init__.py | 5 + src/xtc/cli/mlir_loop.py | 77 ++- src/xtc/schedules/descript_extend.py | 452 ++++++++++++++++++ src/xtc/search/strategies.py | 132 ++++- .../splitting/v_splitting_extend.mlir | 39 ++ ...test_matmul_descript_extend_mlir_sample.py | 169 +++++++ .../test_matmul_descript_extend_mlir_split.py | 215 +++++++++ ...atmul_descript_extend_mlir_split_sample.py | 212 ++++++++ .../test_matmul_descript_extend_tvm_goto.py | 162 +++++++ ...est_matmul_descript_extend_tvm_strategy.py | 172 +++++++ .../search/test_matmul_descript_3axes.py | 139 ++++++ .../search/test_matmul_descript_goto.py | 147 ++++++ .../search/test_matmul_descript_simple.py | 137 ++++++ .../search/test_matmul_descript_split.py | 146 ++++++ 16 files changed, 2199 insertions(+), 9 deletions(-) create mode 100644 .gitmodules create mode 160000 src/sampler create mode 100644 src/xtc/schedules/descript_extend.py create mode 100644 tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir create mode 100644 tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py create mode 100644 tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py create mode 100644 tests/filecheck/schedules/test_matmul_descript_extend_mlir_split_sample.py create mode 100644 tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py create mode 100644 tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py create mode 100644 tests/filecheck/search/test_matmul_descript_3axes.py create mode 100644 tests/filecheck/search/test_matmul_descript_goto.py create mode 100644 tests/filecheck/search/test_matmul_descript_simple.py create mode 100644 tests/filecheck/search/test_matmul_descript_split.py diff --git a/.gitmodules b/.gitmodules new 
file mode 100644 index 000000000..646372304 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "src/sampler"] + path = src/sampler + url = git@gitlab.inria.fr:CORSE/sampler.git diff --git a/src/sampler b/src/sampler new file mode 160000 index 000000000..896a6108c --- /dev/null +++ b/src/sampler @@ -0,0 +1 @@ +Subproject commit 896a6108c6c62147c1c57ffa36a942cdd4e11ad0 diff --git a/src/xtc/__init__.py b/src/xtc/__init__.py index 0c0fdf80b..118e14ccb 100644 --- a/src/xtc/__init__.py +++ b/src/xtc/__init__.py @@ -3,5 +3,10 @@ # Copyright (c) 2024-2026 The XTC Project Authors # import importlib.metadata +import os +import sys __version__ = importlib.metadata.version("xtc") +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../sampler")) +) diff --git a/src/xtc/cli/mlir_loop.py b/src/xtc/cli/mlir_loop.py index f28b8243e..fbc8746a2 100644 --- a/src/xtc/cli/mlir_loop.py +++ b/src/xtc/cli/mlir_loop.py @@ -13,6 +13,7 @@ from xtc.itf.schd.scheduler import Scheduler from xtc.schedules.descript import descript_scheduler +from xtc.schedules.descript_extend import descript_extend_scheduler from xtc.utils.xdsl_aux import parse_xdsl_module from xtc.backends.mlir.MlirNodeBackend import MlirNodeBackend from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend @@ -48,7 +49,8 @@ def main(): node_name, always_vectorize=args.always_vectorize, concluding_passes=args.concluding_passes, - no_alias=not args.alias, + no_alias=args.no_alias, + extend=args.extend, ) schedulers.append(sched) @@ -116,6 +118,7 @@ def build_node_scheduler( always_vectorize: bool, concluding_passes: list[str], no_alias: bool, + extend: bool, ) -> Scheduler: backend = build_mlir_node_backend( op=op, @@ -131,18 +134,70 @@ def build_node_scheduler( if "loop.schedule" in op.attributes: schedule_attribute = op.attributes.get("loop.schedule") assert isinstance(schedule_attribute, builtin.DictionaryAttr) - normal_schedule = normalize_schedule(schedule_attribute) - 
descript_scheduler( - scheduler=scheduler, - node_name=node_name, - abstract_axis=scheduler.backend.dims, - spec=normal_schedule, - ) + if extend: + normal_schedule = normalize_extend_schedule(schedule_attribute) + descript_extend_scheduler( + scheduler=scheduler, + node_name=node_name, + abstract_axis=scheduler.backend.dims, + spec=normal_schedule, + ) + else: + normal_schedule = normalize_schedule(schedule_attribute) + descript_scheduler( + scheduler=scheduler, + node_name=node_name, + abstract_axis=scheduler.backend.dims, + spec=normal_schedule, + ) op.attributes.pop("loop.schedule", None) return scheduler +def normalize_extend_schedule( + raw_schedule: builtin.DictionaryAttr, +) -> dict[str, dict]: + schedule: dict[str, Any] = {} + for declaration_, val_ in raw_schedule.data.items(): + assert isinstance(declaration_, str) + assert isinstance(val_, builtin.DictionaryAttr) + sub_schedule: dict[str, Any] = {} + for declaration, val in val_.data.items(): + assert isinstance(declaration, str) + if ":" in declaration: + if not isinstance(val, builtin.DictionaryAttr): + raise Exception( + f"The schedule within a split should be a dictionnary or void but got {declaration}" + ) + + assert isinstance(val, builtin.DictionaryAttr) + inner_schedule = normalize_extend_schedule(val) + sub_schedule[str(declaration)] = inner_schedule + else: + annotations: dict[str, int | None] = {} + if isinstance(val, builtin.DictionaryAttr): + for instr, param in val.data.items(): + assert isinstance(instr, str) + if isinstance(param, builtin.UnitAttr): + annotations[instr] = None + elif isinstance(param, builtin.IntegerAttr): + annotations[instr] = param.value.data + else: + raise Exception( + "Annotation parameter should be void or int." 
+ ) + + elif not isinstance(val, builtin.UnitAttr): + raise Exception( + f"Annotation parameter should be a dict or void but got {type(val)}" + ) + + sub_schedule[declaration] = annotations + schedule[declaration_] = sub_schedule + return schedule + + def normalize_schedule( raw_schedule: builtin.DictionaryAttr, ) -> dict[str, dict]: @@ -328,6 +383,12 @@ def parse_args() -> argparse.Namespace: default=False, help="Print debug messages.", ) + parser.add_argument( + "--extend", + action="store_true", + default=False, + help="Use descript_extend instead of default", + ) args = parser.parse_args() diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py new file mode 100644 index 000000000..ab4371c8f --- /dev/null +++ b/src/xtc/schedules/descript_extend.py @@ -0,0 +1,452 @@ +# +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024-2026 The XTC Project Authors +# +from typing import Any, Tuple +from dataclasses import dataclass +import re +from typing_extensions import override + +from xtc.itf.schd.scheduler import Scheduler + +from xtc.schedules.descript import Descript, SchedDict + + +def descript_extend_scheduler( + scheduler: Scheduler, + node_name: str, + abstract_axis: list[str], + spec: dict[str, dict], + sample: dict[str, Any] = {}, +): + descript = DescriptExtend(abstract_axis=abstract_axis) + descript.apply(node_name=node_name, spec=spec, scheduler=scheduler, sample=sample) + + +@dataclass(frozen=True) +class DescriptExtend(Descript): + @override + def apply( + self, + node_name: str, + spec: dict[str, dict], + scheduler: Scheduler, + sample: dict[str, Any] = {}, + ): + flat_schedules = self._flatten_schedule(root=node_name, spec=spec, head=[]) + variables = set() + constraints = set() + for schedule in flat_schedules: + variables.update(schedule["variables"]) + constraints.update(schedule["constraints"]) + + flat_schedules = self.apply_sample(flat_schedules, sample) + self.apply_scheduler(flat_schedules, scheduler) + + 
def flatten_schedule(self, node_name: str, spec: dict[str, dict]): + flat_schedules = self._flatten_schedule(root=node_name, spec=spec, head=[]) + variables = [] + constraints = [] + axes = {} + orders = {} + for schedule in flat_schedules: + variables += schedule["variables"] + constraints += schedule["constraints"] + for axis, order in schedule["axes"].items(): + axes[f"order_{axis}"] = order + axis_orders = schedule["axis_orders"] + for axis in axis_orders: + orders[axis] = schedule["axes"][axis] + variables = list(dict.fromkeys(variables)) + constraints = list(dict.fromkeys(constraints)) + return (flat_schedules, variables, constraints, axes, orders) + + def apply_sample( + self, flat_schedules: list[SchedDict], sample: dict[str, Any] + ) -> list[SchedDict]: + flat_schedules = flat_schedules.copy() + for schedule in flat_schedules: + for k in ["splits", "tiles"]: + for d, s in schedule[k].items(): + for d_, s_ in s.items(): + if isinstance(s_, str): + schedule[k][d][d_] = sample[s_] + for k in ["vectorize", "parallelize"]: + for i, s in enumerate(schedule[k]): + if isinstance(s, Tuple): + s, loop = s + s = sample.get(s, False) + if s is None or s: + schedule[k][i] = loop + else: + schedule[k].pop(i) + for d, s in schedule["unroll"].items(): + if isinstance(s, str): + val = sample[s] + if val is None: + for s__ in schedule["tiles"].values(): + for d_, s_ in s__.items(): + if d == d_: + val = s_ + break + if val is not None: + break + schedule["unroll"][d] = val + for d, axes in schedule["axes"].items(): + d_holder = f"order_{d}" + s = sample.get(d_holder, None) + if s: + sch = {} + for a in s: + sch[a] = axes[a] + schedule["axes"][d] = sch + return flat_schedules + + def apply_scheduler(self, flat_schedules: list[SchedDict], scheduler: Scheduler): + self._check_flattened_schedule(flat_schedules) + for schedule in flat_schedules: + root = schedule["root"] + interchange = [] + + for d, s in schedule["axes"].items(): + s = list(s.values()) + for s in s: + 
interchange += s + + p = schedule["packs"].get(d, None) + if p: + for _, input, pad in p: + scheduler.pack_at(s[-1], input, pad=pad) + + for d, s in schedule["splits"].items(): + scheduler.split(d, s, root=root) + + for d, s in schedule["tiles"].items(): + scheduler.tile(d, s, root=root) + + # print(interchange) + scheduler.interchange(interchange, root=root) + scheduler.vectorize(schedule["vectorize"], root=root) + scheduler.parallelize(schedule["parallelize"], root=root) + scheduler.unroll(schedule["unroll"], root=root) + + @override + def _flatten_schedule( + self, + root: str, + spec: dict[str, dict], + head: list[str], + tile_sizes: dict[str, int | str] | None = None, + ) -> list[SchedDict]: + recursive_scheds: list[SchedDict] = [] + sched: SchedDict = { + "root": root, + "fusions": {}, + "packs": {}, + "axis_orders": [], + "axes": {}, + "splits": {}, + "tiles": {a: {} for a in self.abstract_axis}, + "interchange": [], + "vectorize": [], + "parallelize": [], + "unroll": {}, + "variables": [], + "constraints": [], + } + # State of the schedule + if tile_sizes: + axes_sizes: dict[str, int | str] = tile_sizes + else: + axes_sizes = {a: f"[{a}]" for a in self.abstract_axis} + sizes: dict[str, int | str | None] = {} + previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} + interchange: list[str] = head + constraints: list[str] = [] + variables: list[str] = [] + # Processing the schedule + for tree_declaration, tree_val in spec.items(): + assert isinstance(tree_val, dict) + tree_interchange = {} + tree_packs = [] + tree_fusion = [] + for declaration, val in tree_val.items(): + if declaration == "fusion": + # sched["fusions"][tree_declaration] = val + tree_fusion.append(val) + continue + elif declaration == "pack": + for val_ in val: + if len(val_) != 3: + raise Exception(f"Packing {val_} should have 3 parameters.") + param, input, pack = val_ + tree_packs.append((param, input, pack)) + if isinstance(param, str): + variables.append(param) + 
constraints.append(f"0 <= {param} <= 1") + if isinstance(input, str): + raise Exception("Packing input cannot be a variable.") + if isinstance(pack, str): + variables.append(pack) + constraints.append(f"0 <= {pack} <= 1") + continue + elif declaration == "explore_axis_order": + sched["axis_orders"].append(tree_declaration) + continue + elif ":" in declaration: + axis_name, x, y = self.parse_split_declaration(declaration) + self._check_axis_existence(axis_name) + + # The only declaration where y (the cut) is None is the + # last one, so it cannot be the previous one. + cut = previous_cut[axis_name] + + # When x (the starting point of the slice), is not + # specified, it is the previous cut + if x is None: + x = cut + + # print(declaration, axis_name, cut, x, y) + lam, inner_size = self._extended_check_splitting_intervals( + declaration, axis_name, cut, x, y + ) + current_size = axes_sizes[axis_name] + # Update the previous cut + previous_cut[axis_name] = y + # Save the cutting points of the new dimensions + if axis_name not in sched["splits"]: + sched["splits"][axis_name] = {} + new_dim_index = len(sched["splits"][axis_name]) + new_dim_name = f"{axis_name}[{new_dim_index}]" + new_axes_root_name = f"{root}/{new_dim_name}" + sched["splits"][axis_name][new_dim_name] = x + if axis_name in tree_interchange: + tree_interchange[axis_name].append(new_dim_name) + else: + tree_interchange[axis_name] = [new_dim_name] + inner_size = inner_size if inner_size else f"{current_size} - {x}" + inner_size_holder = f"{axis_name}_{new_dim_index}_" + constraints.append(f"{inner_size_holder} == {inner_size}") + axes_sizes[axis_name] = inner_size_holder + + if lam: + if isinstance(y, str): + variables.append(y) + constraints.append(lam) + constraints.append(f"1 || {y} || {current_size}") + + # Fetch the schedule associated with the new dimension + next_schedule = val + assert isinstance(next_schedule, dict) + inner_scheds = self._flatten_schedule( + spec=next_schedule, + 
root=new_axes_root_name, + tile_sizes=axes_sizes.copy(), + head=[axis_name], + ) + axes_sizes[axis_name] = current_size + recursive_scheds += inner_scheds + continue + elif "#" in declaration: + axis_name, tile_size = declaration.split("#") + self._check_axis_existence(axis_name) + assert isinstance(tile_size, str) + if tile_size.isdecimal(): + loop_size = int(tile_size) + else: + loop_size = tile_size + variables.append(tile_size) + constraints.append( + f"1 || {tile_size} || {axes_sizes[axis_name]}" + ) + if not loop_size: + raise Exception( + f"Invalid tile size: '{tile_size}' in {declaration}" + ) + + axes_sizes[axis_name] = loop_size + tile_num = len(sched["tiles"][axis_name]) + loop_name = f"{axis_name}{tile_num}" + sched["tiles"][axis_name][loop_name] = loop_size + sizes[loop_name] = loop_size + if axis_name in tree_interchange: + raise Exception( + f"axis {axis_name} already is used in level {tree_declaration}." + ) + tree_interchange[axis_name] = [loop_name] + elif declaration in self.abstract_axis: + loop_name = declaration + if loop_name in tree_interchange: + raise Exception( + f""" + Axis {declaration} is scheduled twice (or more). + """ + ) + tree_interchange[loop_name] = [loop_name] + else: + raise Exception( + f""" + Axis {declaration} is not a defined axis. + Known axis are: {self.abstract_axis}") + """ + ) + + self.annotate( + loop_name=loop_name, + sizes=sizes, + annotations=val, + sched=sched, + ) + sched["axes"][tree_declaration] = tree_interchange + if len(tree_packs) > 0: + sched["packs"][tree_declaration] = tree_packs + if len(tree_fusion) > 0: + sched["fusions"][tree_declaration] = tree_fusion + for v in tree_interchange.values(): + interchange += v + + # Check if the last cut of each axis is either 0 or None. + # None correspond to "until the end of the loop". 0 is the + # default value, if it has 0 then it means the axis isn't splitted. + # Any other value means the split is let in a partial state. 
+ for axis, cut in previous_cut.items(): + if cut is not None and cut != 0: + raise Exception( + f"Splitting on axis {axis} should end but stops at {cut}" + ) + + sched["interchange"] = interchange + sched["variables"] = variables + sched["variables"] + sched["constraints"] = constraints + sched["constraints"] + return [sched] + recursive_scheds + + def _extended_check_splitting_intervals( + self, + declaration: str, + axis_name: str, + cut: int | str | None, + x: int | str | None, + y: int | str | None, + ) -> Tuple[str | None, int | str | None]: + if cut is None: + raise Exception( + f""" + {declaration} is defined on an already covered axis. + This might be caused by a missing endpoint: {axis_name} + """ + ) + + assert isinstance(x, int | str) + + if isinstance(cut, int) and isinstance(x, int): + if x > cut: + raise Exception( + f""" + Splitting doesn't cover the whole axis + (jumps from {cut} to {x} on axis {axis_name}) + """ + ) + elif x < cut: + raise Exception( + f""" + Splitting are overlapping on axis {axis_name} + (covered until {cut} but restart at {x}) + """ + ) + else: + if x != cut: + raise Exception( + f""" + Splitting should use the same variables between an end and a start + ({cut} and {x} on axis {axis_name}) + """ + ) + + if y is None: + return (None, None) + + constraint = f"{x} < {y}" + if isinstance(y, int): + if isinstance(x, int): + if x >= y: + raise Exception( + f""" + Starting point in the splitting cannot be greater or equal to + the ending point in: {declaration} + """ + ) + else: + return (None, y - x) + else: + return (constraint, f"{y} - {x}") + if isinstance(x, int) and x == 0: + return (constraint, f"{y}") + return (constraint, f"{y} - {x}") + + def annotate( + self, + loop_name: str, + sizes: dict[str, int | str | None], + annotations: dict[str, Any], + sched: dict[str, Any], + ): + for instr, param in annotations.items(): + assert isinstance(instr, str) + match instr: + case "unroll": + if param is None and loop_name in sizes: + 
ufactor = sizes[loop_name] + else: + ufactor = param + if isinstance(param, str): + sched["variables"].append(param) + sched["constraints"].append( + f"1 || {param} || {sizes[loop_name]}" + ) + sched["unroll"][loop_name] = ufactor + + case "vectorize": + if isinstance(param, str): + sched["variables"].append(param) + sched["constraints"].append(f"0 <= {param} <= 1") + sched["vectorize"].append((param, loop_name)) + continue + if param is None: + sched["vectorize"].append(loop_name) + continue + raise Exception( + "Vectorize should not have a parameter (Feature not implemented)" + ) + + case "parallelize": + if isinstance(param, str): + sched["variables"].append(param) + sched["constraints"].append(f"0 <= {param} <= 1") + sched["parallelize"].append((param, loop_name)) + continue + if param is None: + sched["parallelize"].append(loop_name) + continue + if param is not None: + raise Exception( + "Parallelize should not have a parameter (Feature not implemented)" + ) + + case _: + raise Exception(f"Unknown annotation on {loop_name}: {instr}") + + def parse_split_declaration( + self, + declaration: str, + ) -> Tuple[str, int | str | None, int | str | None]: + pattern = r"^(.*)\[(?:(-\w+|\w*)?):(?:(-\w+|\w*)?)\]$" + match = re.match(pattern, declaration) + if not match: + raise Exception(f"Wrong format {declaration}") + + prefix, x_str, y_str = match.groups() + x = int(x_str) if x_str.isnumeric() else x_str + y = int(y_str) if y_str.isnumeric() else y_str + x = x if x else None + y = y if y else None + return prefix, x, y diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index 72d10bf9d..55e2bb5d1 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -9,9 +9,14 @@ import itertools import numpy as np +from properties import constraints_from_str, hypergraph +from properties import variables as sampler_variables +from strategy import execute_dynamic, execute_static, solve_with_z3 from xtc.itf.graph import Graph from 
xtc.itf.schd import Scheduler +from xtc.itf.schd.scheduler import DEFAULT_ROOT from xtc.itf.search import Sample, Strategy +from xtc.schedules.descript_extend import DescriptExtend from xtc.utils.math import ( factors_to_sizes, factors_enumeration, @@ -20,7 +25,6 @@ sample_uniques, ) - __all__ = [ "Strategies", ] @@ -941,6 +945,132 @@ def _filter(self, samples: Iterator[VecSample]) -> Iterator[VecSample]: yield x +class Strategy_Descript(Strategy): + def __init__( + self, + graph: Graph, + spec: dict[str, dict], + constraints: list[str] = [], + vec_size: int = 16, + max_unroll: int = 256, + threads: int = 1, + max_parallelize: int = 1, + **kwargs: Any, + ) -> None: + self._graph = graph + self._vec_size = vec_size + self._max_unroll = max_unroll + self._threads = threads + # Schedule output operation + self._op = graph.outputs_nodes[0].operation + self._stats: dict[str, int] = {} + self._parallelize = self._threads > 1 + self._max_parallelize = max_parallelize + self._vectorize = self._vec_size > 1 + self._unroll = self._max_unroll != 0 + # TODO: should go into some machine description + self._arch_vreg_num = kwargs.get("vreg_num", 32) + self._arch_l1_size = kwargs.get("l1_size", 32 * 1024) + self._arch_l2_size = kwargs.get("l2_size", 1024 * 1024) + self._axes = list(self._op.dims) + self._sizes = self._constant_sizes() + descript = DescriptExtend(abstract_axis=self._axes) + self._descript = descript + input_constraints = constraints + self._flat_schedules, self._sample_names, constraints, axes, orders = ( + descript.flatten_schedule(node_name=DEFAULT_ROOT, spec=spec) + ) + for a, v in self._sizes.items(): + for i, s in enumerate(constraints): + assert isinstance(s, str) + constraints[i] = s.replace(f"[{a}]", str(v)) + self._axes_names = {} + for a, v in axes.items(): + self._axes_names[a] = v + self._orders = {} + order_constraints = [] + for a, v in orders.items(): + permutation = list(itertools.permutations(v)) + a_holder = f"order_{a}" + self._orders[a_holder] 
= permutation + order_constraints.append(f"0 <= {a_holder} <= {len(permutation) - 1}") + constraints = constraints + input_constraints + order_constraints + # print(constraints) + constraints = constraints_from_str(constraints, silent=True) + # print(constraints) + properties, constraints = hypergraph(constraints, silent=True) + # print(properties, constraints) + methods = solve_with_z3( + sampler_variables.keys(), properties, constraints, silent=True + ) + # print(methods) + enumerations = execute_static(methods, properties, constraints, silent=True) + # print(enumerations) + self._properties = properties + self._constraints = constraints + self._methods = methods + self._enumerations = enumerations + + @property + @override + def graph(self) -> Graph: + return self._graph + + @override + def generate(self, scheduler: Scheduler, sample: Sample) -> None: + descript = self._descript + flat_schedules = self._flat_schedules + for a, p in self._orders.items(): + if a in sample: + if isinstance(sample[a], int): + sample[a] = p[sample[a]] + flat_schedules = descript.apply_sample( + flat_schedules=flat_schedules, sample=sample + ) + descript.apply_scheduler(flat_schedules, scheduler) + + @override + def sample(self, num: int, seed: int | None = 0) -> Iterator[Sample]: + draw = execute_dynamic( + self._methods, + self._properties, + self._constraints, + self._enumerations, + k=num, + silent=True, + ) + # print(list(draw.values())[0][0]) + return iter(list(draw.values())[0]) + + @override + def exhaustive(self) -> Iterator[Sample]: + return self.sample(1000) + + @override + def default_schedule(self, opt_level: int = 2) -> Sample: + while True: + for x in self.sample(1): + if x: + return x + + def _constant_sizes(self) -> Mapping[str, int]: + sizes = {a: v for a, v in self._op.dims.items() if isinstance(v, int)} + return sizes + + @override + def dict_to_sample(self, sample: dict[str, Any]) -> Sample: + return sample + + @override + def sample_to_dict(self, sample: Sample) 
-> dict[str, int]: + return sample + + @property + @override + def sample_names(self) -> list[str]: + return self._sample_names + + class Strategies: @classmethod def names(cls) -> Sequence[str]: diff --git a/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir b/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir new file mode 100644 index 000000000..3a644a0b3 --- /dev/null +++ b/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-loop --extend --no-alias --print-source-ir %s 2>&1 | filecheck %s + +func.func @matmul(%A: memref<256x512xf64>, %B: memref<512x256xf64>, %C: memref<256x256xf64>){ + linalg.matmul { + loop.dims = ["i", "j"], + loop.schedule = { + "One" = { + "i[:5]" = { "Two" = {"j"} }, + "i[5:]" = { "Two" = {"j"} }, + "fusion" + } + } + } + ins(%A, %B : memref<256x512xf64>, memref<512x256xf64>) + outs(%C: memref<256x256xf64>) + return +}// CHECK: // -----// IR Dump Before transform //----- // +// CHECK-NEXT: module attributes {transform.with_named_sequence} { +// CHECK-NEXT: func.func @matmul(%arg0: memref<256x512xf64> {llvm.noalias}, %arg1: memref<512x256xf64> {llvm.noalias}, %arg2: memref<256x256xf64> {llvm.noalias}) { +// CHECK-NEXT: linalg.matmul {__node0__} ins(%arg0, %arg1 : memref<256x512xf64>, memref<512x256xf64>) outs(%arg2 : memref<256x256xf64>) +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +// CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +// CHECK-NEXT: transform.yield +// CHECK-NEXT: } +// CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +// CHECK-NEXT: %0 = transform.structured.match attributes {__node0__} in %arg0 : (!transform.any_op) -> !transform.any_op +// CHECK-NEXT: %first, %second = transform.structured.split %0 after 5 {dimension = 0 : i64} : !transform.any_op +// 
CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %first tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: transform.annotate %loops "__node0__/i[0]/j" : !transform.any_op +// CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op +// CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %second tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: transform.annotate %loops_1 "__node0__/i[1]/j" : !transform.any_op +// CHECK-NEXT: %2 = transform.get_parent_op %loops_1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +// CHECK-NEXT: %3 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op +// CHECK-NEXT: transform.yield +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py new file mode 100644 index 000000000..7240455cf --- /dev/null +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py @@ -0,0 +1,169 @@ +# RUN: python %s 2>&1 | filecheck %s + +import xtc.graphs.xtc.op as O +from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend +from xtc.schedules.descript_extend import descript_extend_scheduler + +I, J, K, dtype = 4, 32, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph, always_vectorize=False, no_alias=True) + +sch = impl.get_scheduler() +descript_extend_scheduler( + scheduler=sch, + node_name="C", + abstract_axis=["i", "j", "k"], + spec={ + "DDR": { + "k": {}, + "i": {}, + "j": {}, + }, + "R": { + "i#i_inner": {"unroll": "i_unroll"}, + "j#j_inner": {"vectorize": 
"j_vectorize"}, + }, + }, + sample={"i_inner": 2, "j_inner": 16, "i_unroll": None, "j_vectorize": None}, +) + +sched = sch.schedule() + +comp = impl.get_compiler( + shared_lib=True, + dump_file="matmul_descript_extend_mlir_sample", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sched) +executor = module.get_executor(validate=True) +res = executor.execute() +print(f"CODE: {res}") + +# CHECK: // -----// IR Dump Before transform //----- // +# CHECK-NEXT: module attributes {transform.with_named_sequence} { +# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { +# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<4x32xf32>) +# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<4x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<4x32xf32>) +# CHECK-NEXT: return +# CHECK-NEXT: } +# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +# CHECK-NEXT: transform.yield +# CHECK-NEXT: } +# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op +# CHECK-NEXT: %1 = transform.get_parent_op %loops 
{isolated_from_above} : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_3 "C/k" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_5 "C/i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_7 "C/j" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_9 "C/i0" : !transform.any_op +# CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_8) : (!transform.any_op) -> () +# CHECK-NEXT: %3 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %3 { +# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %3 { +# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: %4 = transform.structured.match attributes {"C/i0"} in %3 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: 
transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op +# CHECK-NEXT: transform.yield +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: +# CHECK-NEXT: // -----// IR Dump After transform //----- // +# CHECK-NEXT: module attributes {transform.with_named_sequence} { +# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { +# CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> +# CHECK-NEXT: %c16 = arith.constant 16 : index +# CHECK-NEXT: %c2 = arith.constant 2 : index +# CHECK-NEXT: %c512 = arith.constant 512 : index +# CHECK-NEXT: %c32 = arith.constant 32 : index +# CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +# CHECK-NEXT: %c0 = arith.constant 0 : index +# CHECK-NEXT: %c4 = arith.constant 4 : index +# CHECK-NEXT: %c1 = arith.constant 1 : index +# CHECK-NEXT: scf.for %arg3 = %c0 to %c4 step %c1 { +# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<4x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { +# CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) +# CHECK-NEXT: } {"./j"} +# CHECK-NEXT: } {"./i"} +# CHECK-NEXT: scf.for %arg3 = %c0 to %c512 step %c1 { +# CHECK-NEXT: %subview = memref.subview %arg0[0, %arg3] [4, 1] [1, 1] : memref<4x512xf32> to memref<4x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_1 = memref.subview %arg1[%arg3, 0] [1, 32] [1, 1] : memref<512x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %subview_2 = memref.subview %arg2[0, 0] [4, 32] [1, 1] : memref<4x32xf32> to memref<4x32xf32, strided<[32, 1]>> +# CHECK-NEXT: scf.for %arg4 = %c0 to 
%c4 step %c2 { +# CHECK-NEXT: %subview_3 = memref.subview %subview[%arg4, 0] [2, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_4 = memref.subview %subview_2[%arg4, 0] [2, 32] [1, 1] : memref<4x32xf32, strided<[32, 1]>> to memref<2x32xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg5 = %c0 to %c32 step %c16 { +# CHECK-NEXT: %subview_5 = memref.subview %subview_1[0, %arg5] [1, 16] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %subview_6 = memref.subview %subview_4[0, %arg5] [2, 16] [1, 1] : memref<2x32xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %c2_7 = arith.constant 2 : index +# CHECK-NEXT: %subview_8 = memref.subview %subview_3[%c0, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_9 = memref.subview %subview_6[%c0, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %0 = vector.transfer_read %subview_8[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +# CHECK-NEXT: %1 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %2 = vector.transfer_read %subview_9[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %3 = vector.extract %1[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> +# CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<16xf32> +# CHECK-NEXT: %6 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<16xf32> +# 
CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<16xf32> into vector<1x16xf32> +# CHECK-NEXT: vector.transfer_write %8, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %c1_10 = arith.constant 1 : index +# CHECK-NEXT: %9 = arith.muli %c1, %c1_10 : index +# CHECK-NEXT: %10 = arith.addi %c0, %9 : index +# CHECK-NEXT: %subview_11 = memref.subview %subview_3[%10, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_12 = memref.subview %subview_6[%10, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %11 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +# CHECK-NEXT: %12 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %13 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %14 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %15 = vector.extract %11[0, 0] : f32 from vector<1x1xf32> +# CHECK-NEXT: %16 = vector.broadcast %15 : f32 to vector<16xf32> +# CHECK-NEXT: %17 = vector.extract %13[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %18 = vector.fma %16, %14, %17 : vector<16xf32> +# CHECK-NEXT: %19 = vector.insert %18, %cst [0] : vector<16xf32> into vector<1x16xf32> +# CHECK-NEXT: vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: } {"C/j"} +# CHECK-NEXT: } {"C/i"} +# CHECK-NEXT: } {"C/k"} +# CHECK-NEXT: return +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: +# CHECK-NEXT: 
graph: +# CHECK-NEXT: name: matmul +# CHECK-NEXT: inputs: +# CHECK-NEXT: - %0 : 4x512xfloat32 +# CHECK-NEXT: - %1 : 512x32xfloat32 +# CHECK-NEXT: outputs: +# CHECK-NEXT: - %2 : 4x32xfloat32 +# CHECK-NEXT: nodes: +# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] +# CHECK-NEXT: +# CHECK-NEXT: CODE: 0 diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py new file mode 100644 index 000000000..21cd7a391 --- /dev/null +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py @@ -0,0 +1,215 @@ +# RUN: python %s 2>&1 | filecheck %s + +import xtc.graphs.xtc.op as O +from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend +from xtc.schedules.descript_extend import descript_extend_scheduler + +I, J, K, dtype = 16, 32, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph, always_vectorize=False, no_alias=True) + +sch = impl.get_scheduler() +descript_extend_scheduler( + scheduler=sch, + node_name="C", + abstract_axis=["i", "j", "k"], + spec={ + "DDR": { + "j": {}, + "k": {}, + }, + "L2": { + "j#jDDR": {}, + "i[:iT1]": { + "R": { + "i#iR1": {"unroll": None}, + "j#jR": {"vectorize": None}, + }, + }, + "i[iT1:]": { + "R": { + "i#iR2": {"unroll": None}, + "j#jR": {"vectorize": None}, + }, + }, + }, + }, + sample={"jDDR": 16, "jR": 4, "iR1": 2, "iR2": 4, "iT1": 4}, +) + +sched = sch.schedule() + +comp = impl.get_compiler( + shared_lib=True, + dump_file="matmul_descript_extend_mlir_split", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sched) +executor = module.get_executor(validate=True) +res = executor.execute() +print(f"CODE: {res}") + +# CHECK: // -----// IR Dump Before transform //----- // +# CHECK-NEXT: 
module attributes {transform.with_named_sequence} { +# CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { +# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<16x32xf32>) +# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<16x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<16x32xf32>) +# CHECK-NEXT: return +# CHECK-NEXT: } +# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +# CHECK-NEXT: transform.yield +# CHECK-NEXT: } +# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op +# CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_3 "C/j" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = 
transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_5 "C/k" : !transform.any_op +# CHECK-NEXT: %first, %second = transform.structured.split %tiled_linalg_op_4 after 8 {dimension = 0 : i64} : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %first tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_7 "C/i[0]/i0" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/j0" : !transform.any_op +# CHECK-NEXT: %3 = transform.get_parent_op %loops_7 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %3 { +# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %3 { +# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: %4 = transform.structured.match attributes {"C/i[0]/i0"} in %3 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: transform.loop.unroll %loops_7 {factor = 2 : i64} : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %second tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_11 "C/i[1]/i0" : !transform.any_op +# CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_10) : (!transform.any_op) -> () +# CHECK-NEXT: 
%5 = transform.get_parent_op %loops_11 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %5 { +# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %5 { +# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: %6 = transform.structured.match attributes {"C/i[1]/i0"} in %5 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: transform.loop.unroll %loops_11 {factor = 2 : i64} : !transform.any_op +# CHECK-NEXT: %7 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %7 { +# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %7 { +# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: transform.yield +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: +# CHECK-NEXT: // -----// IR Dump After transform //----- // +# CHECK-NEXT: module attributes {transform.with_named_sequence} { +# CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { +# CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> +# CHECK-NEXT: %c4 = arith.constant 4 : index +# CHECK-NEXT: %c2 = arith.constant 2 : index +# CHECK-NEXT: %c8 = arith.constant 8 : index +# CHECK-NEXT: %c512 = arith.constant 512 : index +# CHECK-NEXT: %c32 = arith.constant 32 : 
index +# CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +# CHECK-NEXT: %c0 = arith.constant 0 : index +# CHECK-NEXT: %c16 = arith.constant 16 : index +# CHECK-NEXT: %c1 = arith.constant 1 : index +# CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c1 { +# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<16x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { +# CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) +# CHECK-NEXT: } {"./j"} +# CHECK-NEXT: } {"./i"} +# CHECK-NEXT: scf.for %arg3 = %c0 to %c32 step %c16 { +# CHECK-NEXT: %subview = memref.subview %arg0[0, 0] [16, 512] [1, 1] : memref<16x512xf32> to memref<16x512xf32, strided<[512, 1]>> +# CHECK-NEXT: %subview_1 = memref.subview %arg1[0, %arg3] [512, 16] [1, 1] : memref<512x32xf32> to memref<512x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %subview_2 = memref.subview %arg2[0, %arg3] [16, 16] [1, 1] : memref<16x32xf32> to memref<16x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg4 = %c0 to %c512 step %c1 { +# CHECK-NEXT: %subview_3 = memref.subview %subview[0, %arg4] [16, 1] [1, 1] : memref<16x512xf32, strided<[512, 1]>> to memref<16x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_4 = memref.subview %subview_1[%arg4, 0] [1, 16] [1, 1] : memref<512x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %subview_5 = memref.subview %subview_3[0, 0] [8, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<8x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_6 = memref.subview %subview_2[0, 0] [8, 16] [1, 1] : memref<16x16xf32, strided<[32, 1], offset: ?>> to memref<8x16xf32, 
strided<[32, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg5 = %c0 to %c8 step %c4 { +# CHECK-NEXT: %subview_9 = memref.subview %subview_5[%arg5, 0] [2, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_10 = memref.subview %subview_6[%arg5, 0] [2, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg6 = %c0 to %c16 step %c16 { +# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_9, %subview_4 : memref<2x1xf32, strided<[512, 1], offset: ?>>, memref<1x16xf32, strided<[32, 1], offset: ?>>) outs(%subview_10 : memref<2x16xf32, strided<[32, 1], offset: ?>>) +# CHECK-NEXT: } {"C/i[0]/j0"} +# CHECK-NEXT: %0 = arith.addi %arg5, %c2 : index +# CHECK-NEXT: %subview_11 = memref.subview %subview_5[%0, 0] [2, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_12 = memref.subview %subview_6[%0, 0] [2, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg6 = %c0 to %c16 step %c16 { +# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_11, %subview_4 : memref<2x1xf32, strided<[512, 1], offset: ?>>, memref<1x16xf32, strided<[32, 1], offset: ?>>) outs(%subview_12 : memref<2x16xf32, strided<[32, 1], offset: ?>>) +# CHECK-NEXT: } {"C/i[0]/j0"} +# CHECK-NEXT: } {"C/i[0]/i0"} +# CHECK-NEXT: %subview_7 = memref.subview %subview_3[8, 0] [8, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<8x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_8 = memref.subview %subview_2[8, 0] [8, 16] [1, 1] : memref<16x16xf32, strided<[32, 1], offset: ?>> to memref<8x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg5 = %c0 to %c8 step %c2 { +# CHECK-NEXT: %subview_9 = memref.subview %subview_7[%arg5, 0] [1, 1] [1, 1] : memref<8x1xf32, 
strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_10 = memref.subview %subview_8[%arg5, 0] [1, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %0 = vector.transfer_read %subview_9[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +# CHECK-NEXT: %1 = vector.transfer_read %subview_4[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %2 = vector.transfer_read %subview_10[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %3 = vector.extract %1[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> +# CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<16xf32> +# CHECK-NEXT: %6 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<16xf32> +# CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<16xf32> into vector<1x16xf32> +# CHECK-NEXT: vector.transfer_write %8, %subview_10[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %9 = arith.addi %arg5, %c1 : index +# CHECK-NEXT: %subview_11 = memref.subview %subview_7[%9, 0] [1, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_12 = memref.subview %subview_8[%9, 0] [1, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %10 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +# CHECK-NEXT: %11 = vector.transfer_read %subview_4[%c0, %c0], %cst_0 {in_bounds = [true, true]} 
: memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %12 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %13 = vector.extract %11[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> +# CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<16xf32> +# CHECK-NEXT: %16 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<16xf32> +# CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<16xf32> into vector<1x16xf32> +# CHECK-NEXT: vector.transfer_write %18, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: } {"C/i[1]/i0"} +# CHECK-NEXT: } {"C/k"} +# CHECK-NEXT: } {"C/j"} +# CHECK-NEXT: return +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: +# CHECK-NEXT: graph: +# CHECK-NEXT: name: matmul +# CHECK-NEXT: inputs: +# CHECK-NEXT: - %0 : 16x512xfloat32 +# CHECK-NEXT: - %1 : 512x32xfloat32 +# CHECK-NEXT: outputs: +# CHECK-NEXT: - %2 : 16x32xfloat32 +# CHECK-NEXT: nodes: +# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [16x512xfloat32, 512x32xfloat32] -> [16x32xfloat32] +# CHECK-NEXT: +# CHECK-NEXT: CODE: 0 diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split_sample.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split_sample.py new file mode 100644 index 000000000..db11dfa29 --- /dev/null +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split_sample.py @@ -0,0 +1,212 @@ +# RUN: python %s 2>&1 | filecheck %s + +import xtc.graphs.xtc.op as O +from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend +from xtc.schedules.descript_extend import descript_extend_scheduler + +I, J, K, dtype = 16, 32, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = 
O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph, always_vectorize=False, no_alias=True) + +sch = impl.get_scheduler() +descript_extend_scheduler( + scheduler=sch, + node_name="C", + abstract_axis=["i", "j", "k"], + spec={ + "DDR": { + "j": {}, + "k": {}, + "i[:i_split]": { + "Rr": { + "i#2": {"unroll": None}, + "j#16": {"vectorize": None}, + }, + }, + "i[i_split:]": { + "Rl": { + "i#2": {"unroll": None}, + "j#16": {"vectorize": None}, + }, + }, + }, + }, + sample={"i_split": 8}, +) + +sched = sch.schedule() + +comp = impl.get_compiler( + shared_lib=True, + dump_file="matmul_descript_extend_mlir_split_sample", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sched) +executor = module.get_executor(validate=True) +res = executor.execute() +print(f"CODE: {res}") + +#CHECK: // -----// IR Dump Before transform //----- // +#CHECK-NEXT: module attributes {transform.with_named_sequence} { +#CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<16x32xf32>) +#CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<16x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<16x32xf32>) +#CHECK-NEXT: return +#CHECK-NEXT: } +#CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +#CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +#CHECK-NEXT: transform.yield +#CHECK-NEXT: } +#CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +#CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op, %loops = 
transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op +#CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_3 "C/j" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_5 "C/k" : !transform.any_op +#CHECK-NEXT: %first, %second = transform.structured.split %tiled_linalg_op_4 after 8 {dimension = 0 : i64} : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %first tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_7 "C/i[0]/i0" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/j0" : !transform.any_op +#CHECK-NEXT: %3 = transform.get_parent_op %loops_7 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %3 { +#CHECK-NEXT: 
transform.apply_patterns.vector.reduction_to_contract +#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %3 { +#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: %4 = transform.structured.match attributes {"C/i[0]/i0"} in %3 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.loop.unroll %loops_7 {factor = 2 : i64} : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %second tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_11 "C/i[1]/i0" : !transform.any_op +#CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_10) : (!transform.any_op) -> () +#CHECK-NEXT: %5 = transform.get_parent_op %loops_11 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %5 { +#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %5 { +#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: %6 = transform.structured.match attributes {"C/i[1]/i0"} in %5 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.loop.unroll %loops_11 {factor = 2 : i64} : !transform.any_op +#CHECK-NEXT: %7 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %7 { +#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +#CHECK-NEXT: 
transform.apply_patterns.vector.transfer_permutation_patterns +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %7 { +#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.yield +#CHECK-NEXT: } +#CHECK-NEXT: } +#CHECK-NEXT: +#CHECK-NEXT: // -----// IR Dump After transform //----- // +#CHECK-NEXT: module attributes {transform.with_named_sequence} { +#CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> +#CHECK-NEXT: %c4 = arith.constant 4 : index +#CHECK-NEXT: %c2 = arith.constant 2 : index +#CHECK-NEXT: %c8 = arith.constant 8 : index +#CHECK-NEXT: %c512 = arith.constant 512 : index +#CHECK-NEXT: %c32 = arith.constant 32 : index +#CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +#CHECK-NEXT: %c0 = arith.constant 0 : index +#CHECK-NEXT: %c16 = arith.constant 16 : index +#CHECK-NEXT: %c1 = arith.constant 1 : index +#CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c1 { +#CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<16x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { +#CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) +#CHECK-NEXT: } {"./j"} +#CHECK-NEXT: } {"./i"} +#CHECK-NEXT: scf.for %arg3 = %c0 to %c32 step %c16 { +#CHECK-NEXT: %subview = memref.subview %arg0[0, 0] [16, 512] [1, 1] : memref<16x512xf32> to memref<16x512xf32, strided<[512, 1]>> +#CHECK-NEXT: %subview_1 = memref.subview %arg1[0, 
%arg3] [512, 16] [1, 1] : memref<512x32xf32> to memref<512x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_2 = memref.subview %arg2[0, %arg3] [16, 16] [1, 1] : memref<16x32xf32> to memref<16x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg4 = %c0 to %c512 step %c1 { +#CHECK-NEXT: %subview_3 = memref.subview %subview[0, %arg4] [16, 1] [1, 1] : memref<16x512xf32, strided<[512, 1]>> to memref<16x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_4 = memref.subview %subview_1[%arg4, 0] [1, 16] [1, 1] : memref<512x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_5 = memref.subview %subview_3[0, 0] [8, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<8x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_6 = memref.subview %subview_2[0, 0] [8, 16] [1, 1] : memref<16x16xf32, strided<[32, 1], offset: ?>> to memref<8x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg5 = %c0 to %c8 step %c4 { +#CHECK-NEXT: %subview_9 = memref.subview %subview_5[%arg5, 0] [2, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_10 = memref.subview %subview_6[%arg5, 0] [2, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg6 = %c0 to %c16 step %c16 { +#CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_9, %subview_4 : memref<2x1xf32, strided<[512, 1], offset: ?>>, memref<1x16xf32, strided<[32, 1], offset: ?>>) outs(%subview_10 : memref<2x16xf32, strided<[32, 1], offset: ?>>) +#CHECK-NEXT: } {"C/i[0]/j0"} +#CHECK-NEXT: %0 = arith.addi %arg5, %c2 : index +#CHECK-NEXT: %subview_11 = memref.subview %subview_5[%0, 0] [2, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_12 = memref.subview %subview_6[%0, 0] [2, 16] [1, 1] : 
memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg6 = %c0 to %c16 step %c16 { +#CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_11, %subview_4 : memref<2x1xf32, strided<[512, 1], offset: ?>>, memref<1x16xf32, strided<[32, 1], offset: ?>>) outs(%subview_12 : memref<2x16xf32, strided<[32, 1], offset: ?>>) +#CHECK-NEXT: } {"C/i[0]/j0"} +#CHECK-NEXT: } {"C/i[0]/i0"} +#CHECK-NEXT: %subview_7 = memref.subview %subview_3[8, 0] [8, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<8x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_8 = memref.subview %subview_2[8, 0] [8, 16] [1, 1] : memref<16x16xf32, strided<[32, 1], offset: ?>> to memref<8x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg5 = %c0 to %c8 step %c2 { +#CHECK-NEXT: %subview_9 = memref.subview %subview_7[%arg5, 0] [1, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_10 = memref.subview %subview_8[%arg5, 0] [1, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %0 = vector.transfer_read %subview_9[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %1 = vector.transfer_read %subview_4[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %2 = vector.transfer_read %subview_10[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %3 = vector.extract %1[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<16xf32> +#CHECK-NEXT: %6 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %7 = vector.fma %5, 
%3, %6 : vector<16xf32> +#CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<16xf32> into vector<1x16xf32> +#CHECK-NEXT: vector.transfer_write %8, %subview_10[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %9 = arith.addi %arg5, %c1 : index +#CHECK-NEXT: %subview_11 = memref.subview %subview_7[%9, 0] [1, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_12 = memref.subview %subview_8[%9, 0] [1, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %11 = vector.transfer_read %subview_4[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %13 = vector.extract %11[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<16xf32> +#CHECK-NEXT: %16 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<16xf32> +#CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<16xf32> into vector<1x16xf32> +#CHECK-NEXT: vector.transfer_write %18, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: } {"C/i[1]/i0"} +#CHECK-NEXT: } {"C/k"} +#CHECK-NEXT: } {"C/j"} +#CHECK-NEXT: return +#CHECK-NEXT: } +#CHECK-NEXT: } +#CHECK-NEXT: +#CHECK-NEXT: graph: +#CHECK-NEXT: name: matmul +#CHECK-NEXT: inputs: +#CHECK-NEXT: - %0 : 
16x512xfloat32 +#CHECK-NEXT: - %1 : 512x32xfloat32 +#CHECK-NEXT: outputs: +#CHECK-NEXT: - %2 : 16x32xfloat32 +#CHECK-NEXT: nodes: +#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [16x512xfloat32, 512x32xfloat32] -> [16x32xfloat32] +#CHECK-NEXT: +#CHECK-NEXT: CODE: 0 diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py new file mode 100644 index 000000000..2c6b57231 --- /dev/null +++ b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py @@ -0,0 +1,162 @@ +# RUN: python %s 2>&1 | filecheck %s +# REQUIRES: module_tvm + +import xtc.graphs.xtc.op as O +from xtc.backends.tvm import Backend +from xtc.schedules.descript_extend import descript_extend_scheduler + +I, J, K, dtype = 512, 512, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph, always_vectorize=False, no_alias=True) + +sch = impl.get_scheduler() +descript_extend_scheduler( + scheduler=sch, + node_name="C", + abstract_axis=["i", "j", "k"], + spec={ + "DDR": { + "j": {"parallelize": "par"}, + "k": {}, + "i": {}, + # "explore_axis_order": True, + "pack": [("pack_B", 1, True), ("pack_A", 0, True)], + }, + # "DDRk": { + # }, + # "DDRi": { + # }, + "L3": { + "j#jL3": {}, + }, + "L2": { + "i#iL2": {}, + }, + "L1": { + "k#kL1": {"unroll": "k_unroll"}, + }, + "R": { + "i#iR": {"unroll": None}, + "j#jR": {"vectorize": None} + }, + }, + sample={"par": None, "jL3": 36, "iL2": 128, "kL1": 16, "k_unroll": 2, "iR": 2, "jR": 6, "pack_B": None, "pack_A": None}, +) + +sched = sch.schedule() + +comp = impl.get_compiler( + shared_lib=True, + dump_file="matmul_descript_extend_tvm_goto", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sched) +executor = module.get_executor(validate=True) +res = executor.execute() +print(f"CODE: 
{res}") + +#CHECK: graph: +#CHECK-NEXT: name: matmul +#CHECK-NEXT: inputs: +#CHECK-NEXT: - %0 : 512x512xfloat32 +#CHECK-NEXT: - %1 : 512x512xfloat32 +#CHECK-NEXT: outputs: +#CHECK-NEXT: - %2 : 512x512xfloat32 +#CHECK-NEXT: nodes: +#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [512x512xfloat32, 512x512xfloat32] -> [512x512xfloat32] +#CHECK-NEXT: +#CHECK-NEXT:# from tvm.script import ir as I +#CHECK-NEXT:# from tvm.script import tir as T +#CHECK-NEXT: +#CHECK-NEXT:@I.ir_module +#CHECK-NEXT:class Module: +#CHECK-NEXT: @T.prim_func +#CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): +#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +#CHECK-NEXT: for i, j in T.grid(512, 512): +#CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) +#CHECK-NEXT: C_1[i * 512 + j] = T.float32(0.0) +#CHECK-NEXT: for k in range(512): +#CHECK-NEXT: cse_var_2: T.int32 = i * 512 +#CHECK-NEXT: cse_var_1: T.int32 = cse_var_2 + j +#CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) +#CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) +#CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_1[cse_var_2 + k] * _1_1[k * 512 + j] +#CHECK-NEXT:INPS = list(obj.values())[:-1] +#CHECK-NEXT:O = obj['C'] +#CHECK-NEXT:I_R0 = sch.cache_read(INPS[0], "local", [O]) +#CHECK-NEXT:i, j, = O.op.axis +#CHECK-NEXT:k, = O.op.reduce_axis +#CHECK-NEXT:j, j0 = sch[O].split(j, factor=36) +#CHECK-NEXT:i, i0 = sch[O].split(i, factor=128) +#CHECK-NEXT:k, k0 = sch[O].split(k, factor=16) +#CHECK-NEXT:k0, __u_k0 = sch[O].split(k0, factor=2) +#CHECK-NEXT:i0, i1 = sch[O].split(i0, factor=2) +#CHECK-NEXT:j0, j1 = sch[O].split(j0, factor=6) +#CHECK-NEXT:sch[O].reorder(j, k, i, j0, i0, k0, __u_k0, i1, j1) +#CHECK-NEXT:sch[I_R0].compute_at(sch[O], i) +#CHECK-NEXT:sch[I_R0].storage_align(I_R0.op.axis[-2], factor=1024, offset=16) +#CHECK-NEXT:sch[O].unroll(__u_k0) +#CHECK-NEXT:sch[O].unroll(i1) 
+#CHECK-NEXT:sch[O].vectorize(j1) +#CHECK-NEXT:sch[O].parallel(j) +#CHECK-NEXT: +#CHECK-NEXT:# from tvm.script import ir as I +#CHECK-NEXT:# from tvm.script import tir as T +#CHECK-NEXT: +#CHECK-NEXT:@I.ir_module +#CHECK-NEXT:class Module: +#CHECK-NEXT: @T.prim_func +#CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): +#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +#CHECK-NEXT: for j_outer in T.parallel(15): +#CHECK-NEXT: _0_local = T.allocate([2048], "float32", "local") +#CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) +#CHECK-NEXT: for i_outer_init, j_inner_outer_init, i_inner_outer_init in T.grid(4, 6, 64): +#CHECK-NEXT: for j_inner_inner_init_s in range(6): +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + j_inner_inner_init_s // 2) // 2 < 128): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + j_inner_inner_init_s] = T.float32(0.0) +#CHECK-NEXT: for j_inner_inner_init_s in range(6): +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + j_inner_inner_init_s // 2) // 2 < 128): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + j_inner_inner_init_s + 512] = T.float32(0.0) +#CHECK-NEXT: for k_outer, i_outer in T.grid(32, 4): +#CHECK-NEXT: _0_local_1 = T.Buffer((2048,), data=_0_local, scope="local") +#CHECK-NEXT: for ax0, ax1 in T.grid(128, 16): +#CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) +#CHECK-NEXT: _0_local_1[ax0 * 16 + ax1] = _0_1[i_outer * 65536 + ax0 * 512 + k_outer * 16 + ax1] +#CHECK-NEXT: for j_inner_outer, i_inner_outer, k_inner_outer in T.grid(6, 64, 8): +#CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) +#CHECK-NEXT: for j_inner_inner_s in range(6): +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): +#CHECK-NEXT: 
cse_var_3: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_2: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_1: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_3 + cse_var_2 + j_inner_inner_s +#CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_3 + cse_var_2 + j_inner_inner_s] +#CHECK-NEXT: for j_inner_inner_s in range(6): +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): +#CHECK-NEXT: cse_var_6: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_5: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_4: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_6 + cse_var_5 + j_inner_inner_s + 512 +#CHECK-NEXT: C_1[cse_var_4] = C_1[cse_var_4] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 16] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_6 + cse_var_5 + j_inner_inner_s] +#CHECK-NEXT: for j_inner_inner_s in range(6): +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): +#CHECK-NEXT: cse_var_9: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_8: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_7: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_9 + cse_var_8 + j_inner_inner_s +#CHECK-NEXT: C_1[cse_var_7] = C_1[cse_var_7] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 1] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_9 + cse_var_8 + j_inner_inner_s + 512] +#CHECK-NEXT: for j_inner_inner_s in range(6): +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): +#CHECK-NEXT: cse_var_12: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_11: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_10: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_12 + cse_var_11 + j_inner_inner_s + 512 +#CHECK-NEXT: C_1[cse_var_10] = C_1[cse_var_10] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 17] * _1_1[k_outer * 
8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11 + j_inner_inner_s + 512] +#CHECK-NEXT:CODE: 0 diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py new file mode 100644 index 000000000..afbb47615 --- /dev/null +++ b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py @@ -0,0 +1,172 @@ +# RUN: python %s 2>&1 | filecheck %s + +import xtc.graphs.xtc.op as O +from xtc.backends.tvm import Backend +# from xtc.itf.search import strategy +# from xtc.schedules.descript import descript_tree_scheduler +from xtc.search.strategies import Strategy_Descript as Strategy + +I, J, K, dtype = 4, 32, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph, always_vectorize=False, no_alias=True) + +sch = impl.get_scheduler() + +spec = { + "DDR": { + "k": {}, + "i": {}, + "j": {}, + }, + "R": { + "i#i_inner": {"unroll": "i_unroll"}, + "j#j_inner": {"vectorize": "j_vectorize"}, + }, +} + +sample = {"i_inner": 2, "j_inner": 16, "i_unroll": 2, "j_vectorize": None, "order_R": ["j", "i"]} + +strategy = Strategy(graph, spec) + +strategy.generate(sch, sample) + +sched = sch.schedule() + +comp = impl.get_compiler( + shared_lib=True, + dump_file="matmul_descript_extend_tvm_strategy", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sched) +executor = module.get_executor(validate=True) +res = executor.execute() +print(f"CODE: {res}") + +# CHECK: // -----// IR Dump Before transform //----- // +# CHECK-NEXT: module attributes {transform.with_named_sequence} { +# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { +# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +# CHECK-NEXT: 
linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<4x32xf32>) +# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<4x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<4x32xf32>) +# CHECK-NEXT: return +# CHECK-NEXT: } +# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +# CHECK-NEXT: transform.yield +# CHECK-NEXT: } +# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op +# CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_3 "C/k" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_5 "C/i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for 
%tiled_linalg_op_4 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_7 "C/j" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %loops_9 "C/i0" : !transform.any_op +# CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_8) : (!transform.any_op) -> () +# CHECK-NEXT: %3 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %3 { +# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: transform.apply_patterns to %3 { +# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +# CHECK-NEXT: } : !transform.any_op +# CHECK-NEXT: %4 = transform.structured.match attributes {"C/i0"} in %3 : (!transform.any_op) -> !transform.any_op +# CHECK-NEXT: transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op +# CHECK-NEXT: transform.yield +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: +# CHECK-NEXT: // -----// IR Dump After transform //----- // +# CHECK-NEXT: module attributes {transform.with_named_sequence} { +# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { +# CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> +# CHECK-NEXT: %c16 = arith.constant 16 : index +# CHECK-NEXT: %c2 = arith.constant 2 : index +# CHECK-NEXT: %c512 = arith.constant 512 : index +# CHECK-NEXT: %c32 = arith.constant 32 : index +# CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +# 
CHECK-NEXT: %c0 = arith.constant 0 : index +# CHECK-NEXT: %c4 = arith.constant 4 : index +# CHECK-NEXT: %c1 = arith.constant 1 : index +# CHECK-NEXT: scf.for %arg3 = %c0 to %c4 step %c1 { +# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<4x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { +# CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) +# CHECK-NEXT: } {"./j"} +# CHECK-NEXT: } {"./i"} +# CHECK-NEXT: scf.for %arg3 = %c0 to %c512 step %c1 { +# CHECK-NEXT: %subview = memref.subview %arg0[0, %arg3] [4, 1] [1, 1] : memref<4x512xf32> to memref<4x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_1 = memref.subview %arg1[%arg3, 0] [1, 32] [1, 1] : memref<512x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %subview_2 = memref.subview %arg2[0, 0] [4, 32] [1, 1] : memref<4x32xf32> to memref<4x32xf32, strided<[32, 1]>> +# CHECK-NEXT: scf.for %arg4 = %c0 to %c4 step %c2 { +# CHECK-NEXT: %subview_3 = memref.subview %subview[%arg4, 0] [2, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_4 = memref.subview %subview_2[%arg4, 0] [2, 32] [1, 1] : memref<4x32xf32, strided<[32, 1]>> to memref<2x32xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: scf.for %arg5 = %c0 to %c32 step %c16 { +# CHECK-NEXT: %subview_5 = memref.subview %subview_1[0, %arg5] [1, 16] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %subview_6 = memref.subview %subview_4[0, %arg5] [2, 16] [1, 1] : memref<2x32xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> +# 
CHECK-NEXT: %c2_7 = arith.constant 2 : index +# CHECK-NEXT: %subview_8 = memref.subview %subview_3[%c0, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_9 = memref.subview %subview_6[%c0, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %0 = vector.transfer_read %subview_8[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +# CHECK-NEXT: %1 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %2 = vector.transfer_read %subview_9[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %3 = vector.extract %1[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> +# CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<16xf32> +# CHECK-NEXT: %6 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<16xf32> +# CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<16xf32> into vector<1x16xf32> +# CHECK-NEXT: vector.transfer_write %8, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %c1_10 = arith.constant 1 : index +# CHECK-NEXT: %9 = arith.muli %c1, %c1_10 : index +# CHECK-NEXT: %10 = arith.addi %c0, %9 : index +# CHECK-NEXT: %subview_11 = memref.subview %subview_3[%10, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +# CHECK-NEXT: %subview_12 = memref.subview %subview_6[%10, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: %11 = 
vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +# CHECK-NEXT: %12 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %13 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +# CHECK-NEXT: %14 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %15 = vector.extract %11[0, 0] : f32 from vector<1x1xf32> +# CHECK-NEXT: %16 = vector.broadcast %15 : f32 to vector<16xf32> +# CHECK-NEXT: %17 = vector.extract %13[0] : vector<16xf32> from vector<1x16xf32> +# CHECK-NEXT: %18 = vector.fma %16, %14, %17 : vector<16xf32> +# CHECK-NEXT: %19 = vector.insert %18, %cst [0] : vector<16xf32> into vector<1x16xf32> +# CHECK-NEXT: vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +# CHECK-NEXT: } {"C/j"} +# CHECK-NEXT: } {"C/i"} +# CHECK-NEXT: } {"C/k"} +# CHECK-NEXT: return +# CHECK-NEXT: } +# CHECK-NEXT: } +# CHECK-NEXT: +# CHECK-NEXT: graph: +# CHECK-NEXT: name: matmul +# CHECK-NEXT: inputs: +# CHECK-NEXT: - %0 : 4x512xfloat32 +# CHECK-NEXT: - %1 : 512x32xfloat32 +# CHECK-NEXT: outputs: +# CHECK-NEXT: - %2 : 4x32xfloat32 +# CHECK-NEXT: nodes: +# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] +# CHECK-NEXT: +# CHECK-NEXT: CODE: 0 diff --git a/tests/filecheck/search/test_matmul_descript_3axes.py b/tests/filecheck/search/test_matmul_descript_3axes.py new file mode 100644 index 000000000..1df85036c --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_3axes.py @@ -0,0 +1,139 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test strategy Goto on matmul +""" + +import utils +from xtc.search.strategies import 
Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph, backend="tvm") +spec = { + "DDR": { + "j": {}, + "k": {}, + "i": {}, + "explore_axis_order": None, + }, + "R": { + "j#jR": {}, + "k#kR": {}, + "i#iR": {}, + "explore_axis_order": None, + }, +} +strategy = Strategy(graph, spec, max_unroll=8) + +utils.print_all_opt_schedules(backend, strategy) +utils.print_exhaustive_samples(backend, strategy, 100) + +# CHECK: schedule O0: [1, 1, 1, 0] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] +# CHECK-NEXT: schedule O1: [1, 1, 1, 0] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] +# CHECK-NEXT: schedule O2: [1, 1, 1, 1] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 
'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] +# CHECK-NEXT: schedule O3: [1, 1, 1, 1] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] +# CHECK-NEXT: sample 0: [1, 1, 1, 0] +# CHECK-NEXT: sample 1: [1, 1, 1, 1] +# CHECK-NEXT: sample 2: [1, 1, 1, 2] +# CHECK-NEXT: sample 3: [1, 1, 1, 3] +# CHECK-NEXT: sample 4: [1, 1, 1, 4] +# CHECK-NEXT: sample 5: [1, 1, 1, 5] +# CHECK-NEXT: sample 6: [1, 1, 2, 0] +# CHECK-NEXT: sample 7: [1, 1, 2, 1] +# CHECK-NEXT: sample 8: [1, 1, 2, 2] +# CHECK-NEXT: sample 9: [1, 1, 2, 3] +# CHECK-NEXT: sample 10: [1, 1, 2, 4] +# CHECK-NEXT: sample 11: [1, 1, 2, 5] +# CHECK-NEXT: sample 12: [1, 1, 3, 0] +# CHECK-NEXT: sample 13: [1, 1, 3, 1] +# CHECK-NEXT: sample 14: [1, 1, 3, 2] +# CHECK-NEXT: sample 15: [1, 1, 3, 3] +# CHECK-NEXT: sample 16: [1, 1, 3, 4] +# CHECK-NEXT: sample 17: [1, 1, 3, 5] +# CHECK-NEXT: sample 18: [1, 1, 4, 0] +# CHECK-NEXT: sample 19: [1, 1, 4, 1] +# CHECK-NEXT: sample 20: [1, 1, 4, 2] +# CHECK-NEXT: sample 21: [1, 1, 4, 3] +# CHECK-NEXT: sample 22: [1, 1, 4, 4] +# CHECK-NEXT: sample 23: [1, 1, 4, 5] +# CHECK-NEXT: sample 24: [1, 1, 6, 0] +# CHECK-NEXT: sample 25: [1, 1, 6, 1] +# CHECK-NEXT: sample 26: [1, 1, 6, 2] +# CHECK-NEXT: sample 27: [1, 1, 6, 3] +# CHECK-NEXT: sample 28: [1, 1, 6, 4] +# CHECK-NEXT: 
sample 29: [1, 1, 6, 5] +# CHECK-NEXT: sample 30: [1, 2, 1, 0] +# CHECK-NEXT: sample 31: [1, 2, 1, 1] +# CHECK-NEXT: sample 32: [1, 2, 1, 2] +# CHECK-NEXT: sample 33: [1, 2, 1, 3] +# CHECK-NEXT: sample 34: [1, 2, 1, 4] +# CHECK-NEXT: sample 35: [1, 2, 1, 5] +# CHECK-NEXT: sample 36: [1, 2, 2, 0] +# CHECK-NEXT: sample 37: [1, 2, 2, 1] +# CHECK-NEXT: sample 38: [1, 2, 2, 2] +# CHECK-NEXT: sample 39: [1, 2, 2, 3] +# CHECK-NEXT: sample 40: [1, 2, 2, 4] +# CHECK-NEXT: sample 41: [1, 2, 2, 5] +# CHECK-NEXT: sample 42: [1, 2, 3, 0] +# CHECK-NEXT: sample 43: [1, 2, 3, 1] +# CHECK-NEXT: sample 44: [1, 2, 3, 2] +# CHECK-NEXT: sample 45: [1, 2, 3, 3] +# CHECK-NEXT: sample 46: [1, 2, 3, 4] +# CHECK-NEXT: sample 47: [1, 2, 3, 5] +# CHECK-NEXT: sample 48: [1, 2, 4, 0] +# CHECK-NEXT: sample 49: [1, 2, 4, 1] +# CHECK-NEXT: sample 50: [1, 2, 4, 2] +# CHECK-NEXT: sample 51: [1, 2, 4, 3] +# CHECK-NEXT: sample 52: [1, 2, 4, 4] +# CHECK-NEXT: sample 53: [1, 2, 4, 5] +# CHECK-NEXT: sample 54: [1, 2, 6, 1] +# CHECK-NEXT: sample 55: [1, 2, 6, 4] +# CHECK-NEXT: sample 56: [1, 4, 1, 0] +# CHECK-NEXT: sample 57: [1, 4, 1, 1] +# CHECK-NEXT: sample 58: [1, 4, 1, 2] +# CHECK-NEXT: sample 59: [1, 4, 1, 3] +# CHECK-NEXT: sample 60: [1, 4, 1, 4] +# CHECK-NEXT: sample 61: [1, 4, 1, 5] +# CHECK-NEXT: sample 62: [1, 4, 2, 0] +# CHECK-NEXT: sample 63: [1, 4, 2, 1] +# CHECK-NEXT: sample 64: [1, 4, 2, 2] +# CHECK-NEXT: sample 65: [1, 4, 2, 3] +# CHECK-NEXT: sample 66: [1, 4, 2, 4] +# CHECK-NEXT: sample 67: [1, 4, 2, 5] +# CHECK-NEXT: sample 68: [1, 4, 3, 1] +# CHECK-NEXT: sample 69: [1, 4, 3, 4] +# CHECK-NEXT: sample 70: [1, 4, 4, 1] +# CHECK-NEXT: sample 71: [1, 4, 4, 4] +# CHECK-NEXT: sample 72: [1, 4, 6, 1] +# CHECK-NEXT: sample 73: [1, 4, 6, 4] +# CHECK-NEXT: sample 74: [1, 8, 1, 0] +# CHECK-NEXT: sample 75: [1, 8, 1, 1] +# CHECK-NEXT: sample 76: [1, 8, 1, 2] +# CHECK-NEXT: sample 77: [1, 8, 1, 3] +# CHECK-NEXT: sample 78: [1, 8, 1, 4] +# CHECK-NEXT: sample 79: [1, 8, 1, 5] +# CHECK-NEXT: sample 80: 
[1, 8, 2, 1] +# CHECK-NEXT: sample 81: [1, 8, 2, 4] +# CHECK-NEXT: sample 82: [1, 8, 3, 1] +# CHECK-NEXT: sample 83: [1, 8, 3, 4] +# CHECK-NEXT: sample 84: [1, 8, 4, 1] +# CHECK-NEXT: sample 85: [1, 8, 4, 4] +# CHECK-NEXT: sample 86: [1, 8, 6, 1] +# CHECK-NEXT: sample 87: [1, 8, 6, 4] +# CHECK-NEXT: sample 88: [1, 16, 1, 1] +# CHECK-NEXT: sample 89: [1, 16, 1, 4] +# CHECK-NEXT: sample 90: [1, 16, 2, 1] +# CHECK-NEXT: sample 91: [1, 16, 2, 4] +# CHECK-NEXT: sample 92: [1, 16, 3, 1] +# CHECK-NEXT: sample 93: [1, 16, 3, 4] +# CHECK-NEXT: sample 94: [1, 16, 4, 1] +# CHECK-NEXT: sample 95: [1, 16, 4, 4] +# CHECK-NEXT: sample 96: [1, 16, 6, 1] +# CHECK-NEXT: sample 97: [1, 16, 6, 4] +# CHECK-NEXT: sample 98: [1, 32, 1, 1] +# CHECK-NEXT: sample 99: [1, 32, 1, 4] +# CHECK-NEXT: stats {'filtered': 100, 'all': 185} +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 32}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './k1', './i1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 32, './i1': 1, './k1': 1})] diff --git a/tests/filecheck/search/test_matmul_descript_goto.py b/tests/filecheck/search/test_matmul_descript_goto.py new file mode 100644 index 000000000..ba735de3b --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_goto.py @@ -0,0 +1,147 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test strategy Goto on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = { + "DDRj": { + "j": {"parallelize": "j_parallel"}, + }, + "DDR": { + "k": {}, + "i": {}, + "explore_axis_order": None, 
+ "pack": [("pack_B", 1, True), ("pack_A", 0, True)], + }, + "L3": { + "j#jL3": {}, + }, + "L2": { + "i#iL2": {}, + }, + "L1": { + "k#kL1": {"unroll": "k_unroll"}, + }, + "R": {"i#iR": {"unroll": None}, "j#jR": {"vectorize": "j_vectorise"}}, +} +constraint = ["iR * jR <= 56"] +strategy = Strategy(graph, spec, constraints=constraint, max_unroll=8) + +utils.print_all_opt_schedules(backend, strategy) +utils.print_exhaustive_samples(backend, strategy, 100) + +# CHECK: schedule O0: [1, 1, 1, 0] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] +# CHECK-NEXT: schedule O1: [1, 1, 1, 0] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] +# CHECK-NEXT: schedule O2: [1, 1, 1, 1] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', 
dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] +# CHECK-NEXT: schedule O3: [1, 1, 1, 1] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] +# CHECK-NEXT: sample 0: [1, 1, 1, 0] +# CHECK-NEXT: sample 1: [1, 1, 1, 1] +# CHECK-NEXT: sample 2: [1, 1, 1, 2] +# CHECK-NEXT: sample 3: [1, 1, 1, 3] +# CHECK-NEXT: sample 4: [1, 1, 1, 4] +# CHECK-NEXT: sample 5: [1, 1, 1, 5] +# CHECK-NEXT: sample 6: [1, 1, 2, 0] +# CHECK-NEXT: sample 7: [1, 1, 2, 1] +# CHECK-NEXT: sample 8: [1, 1, 2, 2] +# CHECK-NEXT: sample 9: [1, 1, 2, 3] +# CHECK-NEXT: sample 10: [1, 1, 2, 4] +# CHECK-NEXT: sample 11: [1, 1, 2, 5] +# CHECK-NEXT: sample 12: [1, 1, 3, 0] +# CHECK-NEXT: sample 13: [1, 1, 3, 1] +# CHECK-NEXT: sample 14: [1, 1, 3, 2] +# CHECK-NEXT: sample 15: [1, 1, 3, 3] +# CHECK-NEXT: sample 16: [1, 1, 3, 4] +# CHECK-NEXT: sample 17: [1, 1, 3, 5] +# CHECK-NEXT: sample 18: [1, 1, 4, 0] +# CHECK-NEXT: sample 19: [1, 1, 4, 1] +# CHECK-NEXT: sample 20: [1, 1, 4, 2] +# CHECK-NEXT: sample 21: [1, 1, 4, 3] +# CHECK-NEXT: sample 22: [1, 1, 4, 4] +# CHECK-NEXT: sample 23: [1, 1, 4, 5] +# CHECK-NEXT: sample 24: [1, 1, 6, 0] +# CHECK-NEXT: sample 25: [1, 1, 6, 1] +# CHECK-NEXT: sample 26: [1, 1, 6, 2] +# CHECK-NEXT: sample 27: [1, 1, 6, 3] +# CHECK-NEXT: sample 28: [1, 1, 6, 4] +# 
CHECK-NEXT: sample 29: [1, 1, 6, 5] +# CHECK-NEXT: sample 30: [1, 2, 1, 0] +# CHECK-NEXT: sample 31: [1, 2, 1, 1] +# CHECK-NEXT: sample 32: [1, 2, 1, 2] +# CHECK-NEXT: sample 33: [1, 2, 1, 3] +# CHECK-NEXT: sample 34: [1, 2, 1, 4] +# CHECK-NEXT: sample 35: [1, 2, 1, 5] +# CHECK-NEXT: sample 36: [1, 2, 2, 0] +# CHECK-NEXT: sample 37: [1, 2, 2, 1] +# CHECK-NEXT: sample 38: [1, 2, 2, 2] +# CHECK-NEXT: sample 39: [1, 2, 2, 3] +# CHECK-NEXT: sample 40: [1, 2, 2, 4] +# CHECK-NEXT: sample 41: [1, 2, 2, 5] +# CHECK-NEXT: sample 42: [1, 2, 3, 0] +# CHECK-NEXT: sample 43: [1, 2, 3, 1] +# CHECK-NEXT: sample 44: [1, 2, 3, 2] +# CHECK-NEXT: sample 45: [1, 2, 3, 3] +# CHECK-NEXT: sample 46: [1, 2, 3, 4] +# CHECK-NEXT: sample 47: [1, 2, 3, 5] +# CHECK-NEXT: sample 48: [1, 2, 4, 0] +# CHECK-NEXT: sample 49: [1, 2, 4, 1] +# CHECK-NEXT: sample 50: [1, 2, 4, 2] +# CHECK-NEXT: sample 51: [1, 2, 4, 3] +# CHECK-NEXT: sample 52: [1, 2, 4, 4] +# CHECK-NEXT: sample 53: [1, 2, 4, 5] +# CHECK-NEXT: sample 54: [1, 2, 6, 1] +# CHECK-NEXT: sample 55: [1, 2, 6, 4] +# CHECK-NEXT: sample 56: [1, 4, 1, 0] +# CHECK-NEXT: sample 57: [1, 4, 1, 1] +# CHECK-NEXT: sample 58: [1, 4, 1, 2] +# CHECK-NEXT: sample 59: [1, 4, 1, 3] +# CHECK-NEXT: sample 60: [1, 4, 1, 4] +# CHECK-NEXT: sample 61: [1, 4, 1, 5] +# CHECK-NEXT: sample 62: [1, 4, 2, 0] +# CHECK-NEXT: sample 63: [1, 4, 2, 1] +# CHECK-NEXT: sample 64: [1, 4, 2, 2] +# CHECK-NEXT: sample 65: [1, 4, 2, 3] +# CHECK-NEXT: sample 66: [1, 4, 2, 4] +# CHECK-NEXT: sample 67: [1, 4, 2, 5] +# CHECK-NEXT: sample 68: [1, 4, 3, 1] +# CHECK-NEXT: sample 69: [1, 4, 3, 4] +# CHECK-NEXT: sample 70: [1, 4, 4, 1] +# CHECK-NEXT: sample 71: [1, 4, 4, 4] +# CHECK-NEXT: sample 72: [1, 4, 6, 1] +# CHECK-NEXT: sample 73: [1, 4, 6, 4] +# CHECK-NEXT: sample 74: [1, 8, 1, 0] +# CHECK-NEXT: sample 75: [1, 8, 1, 1] +# CHECK-NEXT: sample 76: [1, 8, 1, 2] +# CHECK-NEXT: sample 77: [1, 8, 1, 3] +# CHECK-NEXT: sample 78: [1, 8, 1, 4] +# CHECK-NEXT: sample 79: [1, 8, 1, 5] +# 
CHECK-NEXT: sample 80: [1, 8, 2, 1] +# CHECK-NEXT: sample 81: [1, 8, 2, 4] +# CHECK-NEXT: sample 82: [1, 8, 3, 1] +# CHECK-NEXT: sample 83: [1, 8, 3, 4] +# CHECK-NEXT: sample 84: [1, 8, 4, 1] +# CHECK-NEXT: sample 85: [1, 8, 4, 4] +# CHECK-NEXT: sample 86: [1, 8, 6, 1] +# CHECK-NEXT: sample 87: [1, 8, 6, 4] +# CHECK-NEXT: sample 88: [1, 16, 1, 1] +# CHECK-NEXT: sample 89: [1, 16, 1, 4] +# CHECK-NEXT: sample 90: [1, 16, 2, 1] +# CHECK-NEXT: sample 91: [1, 16, 2, 4] +# CHECK-NEXT: sample 92: [1, 16, 3, 1] +# CHECK-NEXT: sample 93: [1, 16, 3, 4] +# CHECK-NEXT: sample 94: [1, 16, 4, 1] +# CHECK-NEXT: sample 95: [1, 16, 4, 4] +# CHECK-NEXT: sample 96: [1, 16, 6, 1] +# CHECK-NEXT: sample 97: [1, 16, 6, 4] +# CHECK-NEXT: sample 98: [1, 32, 1, 1] +# CHECK-NEXT: sample 99: [1, 32, 1, 4] +# CHECK-NEXT: stats {'filtered': 100, 'all': 185} +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 32}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './k1', './i1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 32, './i1': 1, './k1': 1})] diff --git a/tests/filecheck/search/test_matmul_descript_simple.py b/tests/filecheck/search/test_matmul_descript_simple.py new file mode 100644 index 000000000..86d7e5ae1 --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_simple.py @@ -0,0 +1,137 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test strategy Descript (simple spec) on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = { + "L3": { + "k": {}, + "i": {}, + "j": {}, + }, + "L2": { + "i#i1": {}, + "j#j1": {}, + 
}, + "L1": {"j#j2": {}}, +} +strategy = Strategy(graph, spec, max_unroll=8) + +utils.print_all_opt_schedules(backend, strategy) +utils.print_exhaustive_samples(backend, strategy, 100) + +# CHECK: schedule O0: [1, 1, 1, 0] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] +# CHECK-NEXT: schedule O1: [1, 1, 1, 0] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] +# CHECK-NEXT: schedule O2: [1, 1, 1, 1] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] +# 
CHECK-NEXT: schedule O3: [1, 1, 1, 1] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] +# CHECK-NEXT: sample 0: [1, 1, 1, 0] +# CHECK-NEXT: sample 1: [1, 1, 1, 1] +# CHECK-NEXT: sample 2: [1, 1, 1, 2] +# CHECK-NEXT: sample 3: [1, 1, 1, 3] +# CHECK-NEXT: sample 4: [1, 1, 1, 4] +# CHECK-NEXT: sample 5: [1, 1, 1, 5] +# CHECK-NEXT: sample 6: [1, 1, 2, 0] +# CHECK-NEXT: sample 7: [1, 1, 2, 1] +# CHECK-NEXT: sample 8: [1, 1, 2, 2] +# CHECK-NEXT: sample 9: [1, 1, 2, 3] +# CHECK-NEXT: sample 10: [1, 1, 2, 4] +# CHECK-NEXT: sample 11: [1, 1, 2, 5] +# CHECK-NEXT: sample 12: [1, 1, 3, 0] +# CHECK-NEXT: sample 13: [1, 1, 3, 1] +# CHECK-NEXT: sample 14: [1, 1, 3, 2] +# CHECK-NEXT: sample 15: [1, 1, 3, 3] +# CHECK-NEXT: sample 16: [1, 1, 3, 4] +# CHECK-NEXT: sample 17: [1, 1, 3, 5] +# CHECK-NEXT: sample 18: [1, 1, 4, 0] +# CHECK-NEXT: sample 19: [1, 1, 4, 1] +# CHECK-NEXT: sample 20: [1, 1, 4, 2] +# CHECK-NEXT: sample 21: [1, 1, 4, 3] +# CHECK-NEXT: sample 22: [1, 1, 4, 4] +# CHECK-NEXT: sample 23: [1, 1, 4, 5] +# CHECK-NEXT: sample 24: [1, 1, 6, 0] +# CHECK-NEXT: sample 25: [1, 1, 6, 1] +# CHECK-NEXT: sample 26: [1, 1, 6, 2] +# CHECK-NEXT: sample 27: [1, 1, 6, 3] +# CHECK-NEXT: sample 28: [1, 1, 6, 4] +# CHECK-NEXT: sample 29: [1, 1, 6, 5] +# CHECK-NEXT: sample 30: [1, 2, 1, 0] +# CHECK-NEXT: sample 31: [1, 2, 1, 1] +# CHECK-NEXT: sample 32: [1, 2, 1, 2] +# CHECK-NEXT: sample 33: [1, 2, 1, 3] +# CHECK-NEXT: sample 34: [1, 2, 1, 4] +# CHECK-NEXT: sample 35: [1, 2, 1, 5] +# 
CHECK-NEXT: sample 36: [1, 2, 2, 0] +# CHECK-NEXT: sample 37: [1, 2, 2, 1] +# CHECK-NEXT: sample 38: [1, 2, 2, 2] +# CHECK-NEXT: sample 39: [1, 2, 2, 3] +# CHECK-NEXT: sample 40: [1, 2, 2, 4] +# CHECK-NEXT: sample 41: [1, 2, 2, 5] +# CHECK-NEXT: sample 42: [1, 2, 3, 0] +# CHECK-NEXT: sample 43: [1, 2, 3, 1] +# CHECK-NEXT: sample 44: [1, 2, 3, 2] +# CHECK-NEXT: sample 45: [1, 2, 3, 3] +# CHECK-NEXT: sample 46: [1, 2, 3, 4] +# CHECK-NEXT: sample 47: [1, 2, 3, 5] +# CHECK-NEXT: sample 48: [1, 2, 4, 0] +# CHECK-NEXT: sample 49: [1, 2, 4, 1] +# CHECK-NEXT: sample 50: [1, 2, 4, 2] +# CHECK-NEXT: sample 51: [1, 2, 4, 3] +# CHECK-NEXT: sample 52: [1, 2, 4, 4] +# CHECK-NEXT: sample 53: [1, 2, 4, 5] +# CHECK-NEXT: sample 54: [1, 2, 6, 1] +# CHECK-NEXT: sample 55: [1, 2, 6, 4] +# CHECK-NEXT: sample 56: [1, 4, 1, 0] +# CHECK-NEXT: sample 57: [1, 4, 1, 1] +# CHECK-NEXT: sample 58: [1, 4, 1, 2] +# CHECK-NEXT: sample 59: [1, 4, 1, 3] +# CHECK-NEXT: sample 60: [1, 4, 1, 4] +# CHECK-NEXT: sample 61: [1, 4, 1, 5] +# CHECK-NEXT: sample 62: [1, 4, 2, 0] +# CHECK-NEXT: sample 63: [1, 4, 2, 1] +# CHECK-NEXT: sample 64: [1, 4, 2, 2] +# CHECK-NEXT: sample 65: [1, 4, 2, 3] +# CHECK-NEXT: sample 66: [1, 4, 2, 4] +# CHECK-NEXT: sample 67: [1, 4, 2, 5] +# CHECK-NEXT: sample 68: [1, 4, 3, 1] +# CHECK-NEXT: sample 69: [1, 4, 3, 4] +# CHECK-NEXT: sample 70: [1, 4, 4, 1] +# CHECK-NEXT: sample 71: [1, 4, 4, 4] +# CHECK-NEXT: sample 72: [1, 4, 6, 1] +# CHECK-NEXT: sample 73: [1, 4, 6, 4] +# CHECK-NEXT: sample 74: [1, 8, 1, 0] +# CHECK-NEXT: sample 75: [1, 8, 1, 1] +# CHECK-NEXT: sample 76: [1, 8, 1, 2] +# CHECK-NEXT: sample 77: [1, 8, 1, 3] +# CHECK-NEXT: sample 78: [1, 8, 1, 4] +# CHECK-NEXT: sample 79: [1, 8, 1, 5] +# CHECK-NEXT: sample 80: [1, 8, 2, 1] +# CHECK-NEXT: sample 81: [1, 8, 2, 4] +# CHECK-NEXT: sample 82: [1, 8, 3, 1] +# CHECK-NEXT: sample 83: [1, 8, 3, 4] +# CHECK-NEXT: sample 84: [1, 8, 4, 1] +# CHECK-NEXT: sample 85: [1, 8, 4, 4] +# CHECK-NEXT: sample 86: [1, 8, 6, 1] +# 
CHECK-NEXT: sample 87: [1, 8, 6, 4] +# CHECK-NEXT: sample 88: [1, 16, 1, 1] +# CHECK-NEXT: sample 89: [1, 16, 1, 4] +# CHECK-NEXT: sample 90: [1, 16, 2, 1] +# CHECK-NEXT: sample 91: [1, 16, 2, 4] +# CHECK-NEXT: sample 92: [1, 16, 3, 1] +# CHECK-NEXT: sample 93: [1, 16, 3, 4] +# CHECK-NEXT: sample 94: [1, 16, 4, 1] +# CHECK-NEXT: sample 95: [1, 16, 4, 4] +# CHECK-NEXT: sample 96: [1, 16, 6, 1] +# CHECK-NEXT: sample 97: [1, 16, 6, 4] +# CHECK-NEXT: sample 98: [1, 32, 1, 1] +# CHECK-NEXT: sample 99: [1, 32, 1, 4] +# CHECK-NEXT: stats {'filtered': 100, 'all': 185} +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 32}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './k1', './i1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 32, './i1': 1, './k1': 1})] diff --git a/tests/filecheck/search/test_matmul_descript_split.py b/tests/filecheck/search/test_matmul_descript_split.py new file mode 100644 index 000000000..7e446ac71 --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_split.py @@ -0,0 +1,146 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test strategy Descript (split spec) on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = { + "DDR": { + "j": {}, + "k": {}, + }, + "L1": { + "j#jDDR": {}, + "i[:iT1]": { + "R": { + "i#iR1": {"unroll": None}, + "j#jR": {"vectorize": None}, + }, + }, + "i[iT1:]": { + "R": { + "i#iR2": {"unroll": None}, + "j#jR": {"vectorize": None}, + }, + }, + }, +} +strategy = Strategy(graph, spec, max_unroll=8) + +utils.print_all_opt_schedules(backend, strategy)
+utils.print_exhaustive_samples(backend, strategy, 100) + +# CHECK: schedule O0: [1, 1, 1, 0] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] +# CHECK-NEXT: schedule O1: [1, 1, 1, 0] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] +# CHECK-NEXT: schedule O2: [1, 1, 1, 1] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] +# CHECK-NEXT: schedule O3: [1, 1, 1, 1] +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], 
loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] +# CHECK-NEXT: sample 0: [1, 1, 1, 0] +# CHECK-NEXT: sample 1: [1, 1, 1, 1] +# CHECK-NEXT: sample 2: [1, 1, 1, 2] +# CHECK-NEXT: sample 3: [1, 1, 1, 3] +# CHECK-NEXT: sample 4: [1, 1, 1, 4] +# CHECK-NEXT: sample 5: [1, 1, 1, 5] +# CHECK-NEXT: sample 6: [1, 1, 2, 0] +# CHECK-NEXT: sample 7: [1, 1, 2, 1] +# CHECK-NEXT: sample 8: [1, 1, 2, 2] +# CHECK-NEXT: sample 9: [1, 1, 2, 3] +# CHECK-NEXT: sample 10: [1, 1, 2, 4] +# CHECK-NEXT: sample 11: [1, 1, 2, 5] +# CHECK-NEXT: sample 12: [1, 1, 3, 0] +# CHECK-NEXT: sample 13: [1, 1, 3, 1] +# CHECK-NEXT: sample 14: [1, 1, 3, 2] +# CHECK-NEXT: sample 15: [1, 1, 3, 3] +# CHECK-NEXT: sample 16: [1, 1, 3, 4] +# CHECK-NEXT: sample 17: [1, 1, 3, 5] +# CHECK-NEXT: sample 18: [1, 1, 4, 0] +# CHECK-NEXT: sample 19: [1, 1, 4, 1] +# CHECK-NEXT: sample 20: [1, 1, 4, 2] +# CHECK-NEXT: sample 21: [1, 1, 4, 3] +# CHECK-NEXT: sample 22: [1, 1, 4, 4] +# CHECK-NEXT: sample 23: [1, 1, 4, 5] +# CHECK-NEXT: sample 24: [1, 1, 6, 0] +# CHECK-NEXT: sample 25: [1, 1, 6, 1] +# CHECK-NEXT: sample 26: [1, 1, 6, 2] +# CHECK-NEXT: sample 27: [1, 1, 6, 3] +# CHECK-NEXT: sample 28: [1, 1, 6, 4] +# CHECK-NEXT: sample 29: [1, 1, 6, 5] +# CHECK-NEXT: sample 30: [1, 2, 1, 0] +# CHECK-NEXT: sample 31: [1, 2, 1, 1] +# CHECK-NEXT: sample 32: [1, 2, 1, 2] +# CHECK-NEXT: sample 33: [1, 2, 1, 3] +# CHECK-NEXT: sample 34: [1, 2, 1, 4] +# CHECK-NEXT: sample 35: [1, 2, 1, 5] +# CHECK-NEXT: sample 36: [1, 2, 2, 0] +# CHECK-NEXT: sample 37: [1, 2, 2, 1] +# CHECK-NEXT: sample 38: [1, 2, 2, 2] +# CHECK-NEXT: 
sample 39: [1, 2, 2, 3] +# CHECK-NEXT: sample 40: [1, 2, 2, 4] +# CHECK-NEXT: sample 41: [1, 2, 2, 5] +# CHECK-NEXT: sample 42: [1, 2, 3, 0] +# CHECK-NEXT: sample 43: [1, 2, 3, 1] +# CHECK-NEXT: sample 44: [1, 2, 3, 2] +# CHECK-NEXT: sample 45: [1, 2, 3, 3] +# CHECK-NEXT: sample 46: [1, 2, 3, 4] +# CHECK-NEXT: sample 47: [1, 2, 3, 5] +# CHECK-NEXT: sample 48: [1, 2, 4, 0] +# CHECK-NEXT: sample 49: [1, 2, 4, 1] +# CHECK-NEXT: sample 50: [1, 2, 4, 2] +# CHECK-NEXT: sample 51: [1, 2, 4, 3] +# CHECK-NEXT: sample 52: [1, 2, 4, 4] +# CHECK-NEXT: sample 53: [1, 2, 4, 5] +# CHECK-NEXT: sample 54: [1, 2, 6, 1] +# CHECK-NEXT: sample 55: [1, 2, 6, 4] +# CHECK-NEXT: sample 56: [1, 4, 1, 0] +# CHECK-NEXT: sample 57: [1, 4, 1, 1] +# CHECK-NEXT: sample 58: [1, 4, 1, 2] +# CHECK-NEXT: sample 59: [1, 4, 1, 3] +# CHECK-NEXT: sample 60: [1, 4, 1, 4] +# CHECK-NEXT: sample 61: [1, 4, 1, 5] +# CHECK-NEXT: sample 62: [1, 4, 2, 0] +# CHECK-NEXT: sample 63: [1, 4, 2, 1] +# CHECK-NEXT: sample 64: [1, 4, 2, 2] +# CHECK-NEXT: sample 65: [1, 4, 2, 3] +# CHECK-NEXT: sample 66: [1, 4, 2, 4] +# CHECK-NEXT: sample 67: [1, 4, 2, 5] +# CHECK-NEXT: sample 68: [1, 4, 3, 1] +# CHECK-NEXT: sample 69: [1, 4, 3, 4] +# CHECK-NEXT: sample 70: [1, 4, 4, 1] +# CHECK-NEXT: sample 71: [1, 4, 4, 4] +# CHECK-NEXT: sample 72: [1, 4, 6, 1] +# CHECK-NEXT: sample 73: [1, 4, 6, 4] +# CHECK-NEXT: sample 74: [1, 8, 1, 0] +# CHECK-NEXT: sample 75: [1, 8, 1, 1] +# CHECK-NEXT: sample 76: [1, 8, 1, 2] +# CHECK-NEXT: sample 77: [1, 8, 1, 3] +# CHECK-NEXT: sample 78: [1, 8, 1, 4] +# CHECK-NEXT: sample 79: [1, 8, 1, 5] +# CHECK-NEXT: sample 80: [1, 8, 2, 1] +# CHECK-NEXT: sample 81: [1, 8, 2, 4] +# CHECK-NEXT: sample 82: [1, 8, 3, 1] +# CHECK-NEXT: sample 83: [1, 8, 3, 4] +# CHECK-NEXT: sample 84: [1, 8, 4, 1] +# CHECK-NEXT: sample 85: [1, 8, 4, 4] +# CHECK-NEXT: sample 86: [1, 8, 6, 1] +# CHECK-NEXT: sample 87: [1, 8, 6, 4] +# CHECK-NEXT: sample 88: [1, 16, 1, 1] +# CHECK-NEXT: sample 89: [1, 16, 1, 4] +# CHECK-NEXT: sample 
90: [1, 16, 2, 1] +# CHECK-NEXT: sample 91: [1, 16, 2, 4] +# CHECK-NEXT: sample 92: [1, 16, 3, 1] +# CHECK-NEXT: sample 93: [1, 16, 3, 4] +# CHECK-NEXT: sample 94: [1, 16, 4, 1] +# CHECK-NEXT: sample 95: [1, 16, 4, 4] +# CHECK-NEXT: sample 96: [1, 16, 6, 1] +# CHECK-NEXT: sample 97: [1, 16, 6, 4] +# CHECK-NEXT: sample 98: [1, 32, 1, 1] +# CHECK-NEXT: sample 99: [1, 32, 1, 4] +# CHECK-NEXT: stats {'filtered': 100, 'all': 185} +# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 32}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './k1', './i1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 32, './i1': 1, './k1': 1})] From b621ebe71b3e5d35e7e386511a5ef7364ee6e389 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Fri, 17 Oct 2025 17:27:19 +0200 Subject: [PATCH 02/23] Some cleanups --- src/xtc/schedules/descript_extend.py | 62 +++++++++++-------- src/xtc/search/strategies.py | 18 ------ .../search/test_matmul_descript_3axes.py | 2 +- .../search/test_matmul_descript_goto.py | 2 +- 4 files changed, 39 insertions(+), 45 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index ab4371c8f..e81294616 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -67,39 +67,53 @@ def apply_sample( flat_schedules = flat_schedules.copy() for schedule in flat_schedules: for k in ["splits", "tiles"]: - for d, s in schedule[k].items(): - for d_, s_ in s.items(): - if isinstance(s_, str): - schedule[k][d][d_] = sample[s_] + for dim, axes in schedule[k].items(): + for level, size in axes.items(): + if 
isinstance(size, str): + schedule[k][dim][level] = sample[size] for k in ["vectorize", "parallelize"]: - for i, s in enumerate(schedule[k]): - if isinstance(s, Tuple): - s, loop = s - s = sample.get(s, False) - if s is None or s: + for i, axes in enumerate(schedule[k]): + if isinstance(axes, Tuple): + axes, loop = axes + axes = sample.get(axes, False) + if axes is None or axes: schedule[k][i] = loop else: schedule[k].pop(i) - for d, s in schedule["unroll"].items(): - if isinstance(s, str): - val = sample[s] + for axis, size in schedule["unroll"].items(): + if isinstance(size, str): + val = sample[size] if val is None: for s__ in schedule["tiles"].values(): - for d_, s_ in s__.items(): - if d == d_: - val = s_ + for level, size in s__.items(): + if axis == level: + val = size break if val is not None: break - schedule["unroll"][d] = val - for d, axes in schedule["axes"].items(): - d_holder = f"order_{d}" + schedule["unroll"][axis] = val + for dim, packs in schedule["packs"].items(): + for i, (flag, input, pad) in enumerate(packs): + sample_flag = False + if isinstance(flag, str): + flag = sample.get(flag, False) + sample_flag = True + if not flag: + schedule["packs"][dim].pop(i) + continue + if isinstance(input, str): + input = sample.get(input, input) + sample_flag = True + if sample_flag: + schedule["packs"][dim] = (flag, input, pad) + for dim, axes in schedule["axes"].items(): + d_holder = f"order_{dim}" s = sample.get(d_holder, None) if s: sch = {} for a in s: sch[a] = axes[a] - schedule["axes"][d] = sch + schedule["axes"][dim] = sch return flat_schedules def apply_scheduler(self, flat_schedules: list[SchedDict], scheduler: Scheduler): @@ -366,8 +380,8 @@ def _extended_check_splitting_intervals( return (None, None) constraint = f"{x} < {y}" - if isinstance(y, int): - if isinstance(x, int): + if isinstance(x, int): + if isinstance(y, int): if x >= y: raise Exception( f""" @@ -377,10 +391,8 @@ def _extended_check_splitting_intervals( ) else: return (None, y - x) - 
else: - return (constraint, f"{y} - {x}") - if isinstance(x, int) and x == 0: - return (constraint, f"{y}") + if x == 0: + return (constraint, f"{y}") return (constraint, f"{y} - {x}") def annotate( diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index 55e2bb5d1..e386b005f 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -951,27 +951,10 @@ def __init__( graph: Graph, spec: dict[str, dict], constraints: list[str] = [], - vec_size: int = 16, - max_unroll: int = 256, - threads: int = 1, - max_parallelize: int = 1, - **kwargs: Any, ) -> None: self._graph = graph - self._vec_size = vec_size - self._max_unroll = max_unroll - self._threads = threads - # Schedule output operation self._op = graph.outputs_nodes[0].operation self._stats: dict[str, int] = {} - self._parallelize = self._threads > 1 - self._max_parallelize = max_parallelize - self._vectorize = self._vec_size > 1 - self._unroll = self._max_unroll != 0 - # TODO: should go into some machine description - self._arch_vreg_num = kwargs.get("vreg_num", 32) - self._arch_l1_size = kwargs.get("l1_size", 32 * 1024) - self._arch_l2_size = kwargs.get("l2_size", 1024 * 1024) self._axes = list(self._op.dims) self._sizes = self._constant_sizes() descript = DescriptExtend(abstract_axis=self._axes) @@ -1039,7 +1022,6 @@ def sample(self, num: int, seed: int | None = 0) -> Iterator[Sample]: k=num, silent=True, ) - # print(list(draw.values())[0][0]) return iter(list(draw.values())[0]) @override diff --git a/tests/filecheck/search/test_matmul_descript_3axes.py b/tests/filecheck/search/test_matmul_descript_3axes.py index 1df85036c..8b3128b3d 100644 --- a/tests/filecheck/search/test_matmul_descript_3axes.py +++ b/tests/filecheck/search/test_matmul_descript_3axes.py @@ -22,7 +22,7 @@ "explore_axis_order": None, }, } -strategy = Strategy(graph, spec, max_unroll=8) +strategy = Strategy(graph, spec) utils.print_all_opt_schedules(backend, strategy) 
utils.print_exhaustive_samples(backend, strategy, 100) diff --git a/tests/filecheck/search/test_matmul_descript_goto.py b/tests/filecheck/search/test_matmul_descript_goto.py index ba735de3b..e968a4a29 100644 --- a/tests/filecheck/search/test_matmul_descript_goto.py +++ b/tests/filecheck/search/test_matmul_descript_goto.py @@ -30,7 +30,7 @@ "R": {"i#iR": {"unroll": None}, "j#jR": {"vectorize": "j_vectorise"}}, } constraint = ["iR * jR <= 56"] -strategy = Strategy(graph, spec, constraints=constraint, max_unroll=8) +strategy = Strategy(graph, spec, constraints=constraint) utils.print_all_opt_schedules(backend, strategy) utils.print_exhaustive_samples(backend, strategy, 100) From 59690bdda024f49e1f01b96c20bb312a7a60c359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Mon, 20 Oct 2025 12:05:36 +0200 Subject: [PATCH 03/23] buffer_at support --- src/xtc/schedules/descript_extend.py | 47 +++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index e81294616..7eb63638c 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -5,6 +5,7 @@ from typing import Any, Tuple from dataclasses import dataclass import re +from networkx import constraint from typing_extensions import override from xtc.itf.schd.scheduler import Scheduler @@ -106,6 +107,17 @@ def apply_sample( sample_flag = True if sample_flag: schedule["packs"][dim] = (flag, input, pad) + for dim, buffs in schedule["buffers"].items(): + for i, (flag, pad) in enumerate(buffs): + sample_flag = False + if isinstance(flag, str): + flag = sample.get(flag, False) + sample_flag = True + if not flag: + schedule["buffers"][dim].pop(i) + continue + if sample_flag: + schedule["buffers"][dim] = (flag, pad) for dim, axes in schedule["axes"].items(): d_holder = f"order_{dim}" s = sample.get(d_holder, None) @@ -132,13 +144,16 @@ def apply_scheduler(self, 
flat_schedules: list[SchedDict], scheduler: Scheduler) for _, input, pad in p: scheduler.pack_at(s[-1], input, pad=pad) + b = schedule["buffers"].get(d, None) + if b: + scheduler.buffer_at(s[-1]) + for d, s in schedule["splits"].items(): scheduler.split(d, s, root=root) for d, s in schedule["tiles"].items(): scheduler.tile(d, s, root=root) - # print(interchange) scheduler.interchange(interchange, root=root) scheduler.vectorize(schedule["vectorize"], root=root) scheduler.parallelize(schedule["parallelize"], root=root) @@ -157,6 +172,7 @@ def _flatten_schedule( "root": root, "fusions": {}, "packs": {}, + "buffers": {}, "axis_orders": [], "axes": {}, "splits": {}, @@ -184,25 +200,40 @@ def _flatten_schedule( tree_interchange = {} tree_packs = [] tree_fusion = [] + tree_buff = [] for declaration, val in tree_val.items(): if declaration == "fusion": - # sched["fusions"][tree_declaration] = val tree_fusion.append(val) continue elif declaration == "pack": for val_ in val: if len(val_) != 3: raise Exception(f"Packing {val_} should have 3 parameters.") - param, input, pack = val_ - tree_packs.append((param, input, pack)) + param, input, pad = val_ + tree_packs.append((param, input, pad)) if isinstance(param, str): variables.append(param) constraints.append(f"0 <= {param} <= 1") if isinstance(input, str): raise Exception("Packing input cannot be a variable.") - if isinstance(pack, str): - variables.append(pack) - constraints.append(f"0 <= {pack} <= 1") + if isinstance(pad, str): + variables.append(pad) + constraints.append(f"0 <= {pad} <= 1") + continue + elif declaration in "buffer": + for val_ in val: + if len(val_) != 2: + raise Exception( + f"Bufferisation {val_} should have 2 parameters." 
+ ) + param, pad = val_ + tree_buff.append((param, pad)) + if isinstance(param, str): + variables.append(param) + constraints.append(f"0 <= {param} <= 1") + if isinstance(pad, str): + variables.append(pad) + constraints.append(f"0 <= {pad} <= 1") continue elif declaration == "explore_axis_order": sched["axis_orders"].append(tree_declaration) @@ -316,6 +347,8 @@ def _flatten_schedule( sched["packs"][tree_declaration] = tree_packs if len(tree_fusion) > 0: sched["fusions"][tree_declaration] = tree_fusion + if len(tree_buff) > 0: + sched["buffers"][tree_declaration] = tree_buff for v in tree_interchange.values(): interchange += v From c90075cd5b1140286324dec13fecc789ed93012f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Tue, 21 Oct 2025 11:40:35 +0200 Subject: [PATCH 04/23] Fixed constraint generation for splits and unfolds --- src/xtc/cli/mlir_loop.py | 1 + src/xtc/schedules/descript_extend.py | 142 +++++++++++++++++++++++---- src/xtc/search/strategies.py | 28 ++++-- 3 files changed, 141 insertions(+), 30 deletions(-) diff --git a/src/xtc/cli/mlir_loop.py b/src/xtc/cli/mlir_loop.py index fbc8746a2..95978b7a0 100644 --- a/src/xtc/cli/mlir_loop.py +++ b/src/xtc/cli/mlir_loop.py @@ -140,6 +140,7 @@ def build_node_scheduler( scheduler=scheduler, node_name=node_name, abstract_axis=scheduler.backend.dims, + abstract_axis_sizes={a: 1 for a in scheduler.backend.dims}, spec=normal_schedule, ) else: diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 7eb63638c..cdd25295d 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -5,7 +5,6 @@ from typing import Any, Tuple from dataclasses import dataclass import re -from networkx import constraint from typing_extensions import override from xtc.itf.schd.scheduler import Scheduler @@ -17,15 +16,20 @@ def descript_extend_scheduler( scheduler: Scheduler, node_name: str, abstract_axis: list[str], + abstract_axis_sizes: 
dict[str, int], spec: dict[str, dict], sample: dict[str, Any] = {}, ): - descript = DescriptExtend(abstract_axis=abstract_axis) + descript = DescriptExtend( + abstract_axis=abstract_axis, abstract_axis_sizes=abstract_axis_sizes + ) descript.apply(node_name=node_name, spec=spec, scheduler=scheduler, sample=sample) @dataclass(frozen=True) class DescriptExtend(Descript): + abstract_axis_sizes: dict[str, int] + @override def apply( self, @@ -58,8 +62,76 @@ def flatten_schedule(self, node_name: str, spec: dict[str, dict]): axis_orders = schedule["axis_orders"] for axis in axis_orders: orders[axis] = schedule["axes"][axis] + + for axis in self.abstract_axis: + all_axis_constraints = [] + for schedule in flat_schedules: + for sched in schedule["sizes"][axis]: + if len(sched) > 1: + all_axis_constraints.append(sched) + axis_constraints = [] + i = 0 + while i < len(all_axis_constraints): + sched = all_axis_constraints[i] + # print(all_axis_constraints, i, sched) + # print(sched) + if isinstance(sched[0], int): + axis_constraints.append(sched) + all_axis_constraints.pop(i) + else: + i += 1 + # print(axis, axis_constraints, all_axis_constraints) + while len(all_axis_constraints) > 0: + i = 0 + axis_constraints_acc = [] + flag_flag = False + while i < len(all_axis_constraints): + sched = all_axis_constraints[i] + flag = False + for constraint in axis_constraints: + if sched[0] == constraint[-1]: + axis_constraints_acc.append(constraint + sched[1:]) + flag = True + if flag: + all_axis_constraints.pop(i) + flag_flag = True + else: + i += 1 + if flag_flag: + axis_constraints = axis_constraints_acc + # print(axis, axis_constraints, all_axis_constraints) + for constraint in axis_constraints: + constraint.reverse() + constraint_str = "" + var_flag = False + if isinstance(constraint[0], str): + constraint_str = "1 || " + for size in constraint[:-1]: + var_flag = var_flag or isinstance(size, str) + constraint_str += f"{size} || " + constraint_str += str(constraint[-1]) + if var_flag: 
+ constraints.insert(0, constraint_str) + + # flag = False + # for constraint in axis_constraints: + # if sched[0] == constraint[-1]: + # axis_constraints_acc.append(constraint + sched[1:]) + # flag = True + # else: + # axis_constraints_acc.append(constraint) + # axis_constraints = axis_constraints_acc + # if flag: + # all_axis_constraints.pop(i) + # i = 0 + # else: + # i += 1 + # print(all_axis_constraints, i) + # print(axis, axis_constraints) + variables = list(dict.fromkeys(variables)) constraints = list(dict.fromkeys(constraints)) + # print(constraints) return (flat_schedules, variables, constraints, axes, orders) def apply_sample( @@ -106,7 +178,7 @@ def apply_sample( input = sample.get(input, input) sample_flag = True if sample_flag: - schedule["packs"][dim] = (flag, input, pad) + schedule["packs"][dim][i] = (flag, input, pad) for dim, buffs in schedule["buffers"].items(): for i, (flag, pad) in enumerate(buffs): sample_flag = False @@ -117,7 +189,7 @@ def apply_sample( schedule["buffers"][dim].pop(i) continue if sample_flag: - schedule["buffers"][dim] = (flag, pad) + schedule["buffers"][dim][i] = (flag, pad) for dim, axes in schedule["axes"].items(): d_holder = f"order_{dim}" s = sample.get(d_holder, None) @@ -176,6 +248,7 @@ def _flatten_schedule( "axis_orders": [], "axes": {}, "splits": {}, + "sizes": {}, "tiles": {a: {} for a in self.abstract_axis}, "interchange": [], "vectorize": [], @@ -188,7 +261,12 @@ def _flatten_schedule( if tile_sizes: axes_sizes: dict[str, int | str] = tile_sizes else: - axes_sizes = {a: f"[{a}]" for a in self.abstract_axis} + axes_sizes = {a: v for a, v in self.abstract_axis_sizes.items()} + # print(axes_sizes) + sched_sizes = {} + for a, v in axes_sizes.items(): + sched["sizes"][a] = [] + sched_sizes[a] = [v] sizes: dict[str, int | str | None] = {} previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} interchange: list[str] = head @@ -269,16 +347,17 @@ def _flatten_schedule( 
tree_interchange[axis_name].append(new_dim_name) else: tree_interchange[axis_name] = [new_dim_name] - inner_size = inner_size if inner_size else f"{current_size} - {x}" - inner_size_holder = f"{axis_name}_{new_dim_index}_" - constraints.append(f"{inner_size_holder} == {inner_size}") - axes_sizes[axis_name] = inner_size_holder + inner_size = ( + inner_size if inner_size else eval(f"{current_size} - {x}") + ) + axes_sizes[axis_name] = inner_size + # sched["sizes"][axis_name].append(inner_size) if lam: if isinstance(y, str): variables.append(y) constraints.append(lam) - constraints.append(f"1 || {y} || {current_size}") + # constraints.append(f"1 || {y} || {current_size}") # Fetch the schedule associated with the new dimension next_schedule = val @@ -290,6 +369,11 @@ def _flatten_schedule( head=[axis_name], ) axes_sizes[axis_name] = current_size + + # for a, v in inner_scheds[0]["sizes"].items(): + # if a != axis_name: + # inner_scheds[0]["sizes"][a] = {} + # sched["sizes"][a] += v recursive_scheds += inner_scheds continue elif "#" in declaration: @@ -301,15 +385,16 @@ def _flatten_schedule( else: loop_size = tile_size variables.append(tile_size) - constraints.append( - f"1 || {tile_size} || {axes_sizes[axis_name]}" - ) + # constraints.append( + # f"1 || {tile_size} || {axes_sizes[axis_name]}" + # ) if not loop_size: raise Exception( f"Invalid tile size: '{tile_size}' in {declaration}" ) axes_sizes[axis_name] = loop_size + sched_sizes[axis_name].append(loop_size) tile_num = len(sched["tiles"][axis_name]) loop_name = f"{axis_name}{tile_num}" sched["tiles"][axis_name][loop_name] = loop_size @@ -321,6 +406,7 @@ def _flatten_schedule( tree_interchange[axis_name] = [loop_name] elif declaration in self.abstract_axis: loop_name = declaration + axis_name = loop_name if loop_name in tree_interchange: raise Exception( f""" @@ -338,9 +424,11 @@ def _flatten_schedule( self.annotate( loop_name=loop_name, + axis_name=axis_name, sizes=sizes, annotations=val, sched=sched, + 
sched_sizes=sched_sizes[axis_name], ) sched["axes"][tree_declaration] = tree_interchange if len(tree_packs) > 0: @@ -365,6 +453,16 @@ def _flatten_schedule( sched["interchange"] = interchange sched["variables"] = variables + sched["variables"] sched["constraints"] = constraints + sched["constraints"] + # print(sched_sizes, sched["sizes"]) + for a in self.abstract_axis: + flag = True + for sched_ in sched["sizes"][a]: + if set(sched_sizes[a]) <= set(sched_): + flag = False + break + if flag: + sched["sizes"][a] = [sched_sizes[a]] + sched["sizes"][a] + # print(sched["sizes"]) return [sched] + recursive_scheds def _extended_check_splitting_intervals( @@ -412,7 +510,7 @@ def _extended_check_splitting_intervals( if y is None: return (None, None) - constraint = f"{x} < {y}" + # constraint = f"{x} < {y}" if isinstance(x, int): if isinstance(y, int): if x >= y: @@ -424,16 +522,22 @@ def _extended_check_splitting_intervals( ) else: return (None, y - x) - if x == 0: - return (constraint, f"{y}") - return (constraint, f"{y} - {x}") + raise Exception(f""" + Arguments for the split must be ints for now. 
+ ({x} or {y} on axis {axis_name}) + """) + # if x == 0: + # return (constraint, f"{y}") + # return (constraint, f"{y} - {x}") def annotate( self, loop_name: str, + axis_name: str, sizes: dict[str, int | str | None], annotations: dict[str, Any], sched: dict[str, Any], + sched_sizes: list[int | str], ): for instr, param in annotations.items(): assert isinstance(instr, str) @@ -445,9 +549,7 @@ def annotate( ufactor = param if isinstance(param, str): sched["variables"].append(param) - sched["constraints"].append( - f"1 || {param} || {sizes[loop_name]}" - ) + sched["sizes"][axis_name].append(sched_sizes + [ufactor]) sched["unroll"][loop_name] = ufactor case "vectorize": diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index e386b005f..3cef6cda3 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -951,14 +951,18 @@ def __init__( graph: Graph, spec: dict[str, dict], constraints: list[str] = [], + initialize: bool = True, ) -> None: self._graph = graph self._op = graph.outputs_nodes[0].operation self._stats: dict[str, int] = {} self._axes = list(self._op.dims) self._sizes = self._constant_sizes() - descript = DescriptExtend(abstract_axis=self._axes) + descript = DescriptExtend( + abstract_axis=self._axes, abstract_axis_sizes=dict(self._sizes) + ) self._descript = descript + self._initialized = False input_constraints = constraints self._flat_schedules, self._sample_names, constraints, axes, orders = ( descript.flatten_schedule(node_name=DEFAULT_ROOT, spec=spec) @@ -970,29 +974,32 @@ def __init__( self._axes_names = {} for a, v in axes.items(): self._axes_names[a] = v - self._orders = {} - order_constraints = [] + self._orders: dict[str, list] = {} + order_constraints: list[str] = [] for a, v in orders.items(): + assert isinstance(v, dict) permutation = list(itertools.permutations(v)) a_holder = f"order_{a}" self._orders[a_holder] = permutation order_constraints.append(f"0 <= {a_holder} <= {len(permutation) - 1}") - 
constraints = constraints + input_constraints + order_constraints - # print(constraints) - constraints = constraints_from_str(constraints, silent=True) - # print(constraints) + self._constraints = constraints + input_constraints + order_constraints + if initialize: + self._initialize() + + def _initialize(self): + if self._initialized: + return + constraints = constraints_from_str(self._constraints, silent=True) properties, constraints = hypergraph(constraints, silent=True) - # print(properties, constraints) methods = solve_with_z3( sampler_variables.keys(), properties, constraints, silent=True ) - # print(methods) enumerations = execute_static(methods, properties, constraints, silent=True) - # print(enumerations) self._properties = properties self._constraints = constraints self._methods = methods self._enumerations = enumerations + self._initialized = True @property @override @@ -1014,6 +1021,7 @@ def generate(self, scheduler: Scheduler, sample: Sample) -> None: @override def sample(self, num: int, seed: int | None = 0) -> Iterator[Sample]: + self._initialize() draw = execute_dynamic( self._methods, self._properties, From a91fb8b6000ea4579dcd9aee21897ebff8940ff0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Tue, 21 Oct 2025 11:41:04 +0200 Subject: [PATCH 05/23] Updates to tests and requirements --- requirements.txt | 2 + ...test_matmul_descript_extend_mlir_sample.py | 240 ++++++------ .../test_matmul_descript_extend_mlir_split.py | 364 ++++++++++-------- ...atmul_descript_extend_mlir_split_sample.py | 212 ---------- .../test_matmul_descript_extend_tvm_goto.py | 211 +++++----- ...est_matmul_descript_extend_tvm_strategy.py | 180 +++------ .../search/test_matmul_descript_3axes.py | 118 +----- .../search/test_matmul_descript_goto.py | 116 +----- .../search/test_matmul_descript_simple.py | 116 +----- .../search/test_matmul_descript_split.py | 131 +------ 10 files changed, 522 insertions(+), 1168 deletions(-) delete mode 100644 
tests/filecheck/schedules/test_matmul_descript_extend_mlir_split_sample.py diff --git a/requirements.txt b/requirements.txt index 38c09cafa..91f4efa67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ typing_extensions xdsl~=0.50.0 pyyaml scikit-learn +networkx +sympy diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py index 7240455cf..2dc8a5a8a 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py @@ -17,6 +17,7 @@ impl = Backend(graph, always_vectorize=False, no_alias=True) sch = impl.get_scheduler() +axes_sizes = {"i": I, "j": J, "k": K} descript_extend_scheduler( scheduler=sch, node_name="C", @@ -32,6 +33,7 @@ "j#j_inner": {"vectorize": "j_vectorize"}, }, }, + abstract_axis_sizes=axes_sizes, sample={"i_inner": 2, "j_inner": 16, "i_unroll": None, "j_vectorize": None}, ) @@ -48,122 +50,122 @@ res = executor.execute() print(f"CODE: {res}") -# CHECK: // -----// IR Dump Before transform //----- // -# CHECK-NEXT: module attributes {transform.with_named_sequence} { -# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { -# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<4x32xf32>) -# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<4x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<4x32xf32>) -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { -# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op -# CHECK-NEXT: transform.yield -# CHECK-NEXT: } -# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { 
-# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op -# CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_3 "C/k" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_5 "C/i" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_7 "C/j" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_9 "C/i0" : !transform.any_op -# CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_8) : (!transform.any_op) -> () -# CHECK-NEXT: %3 = 
transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: transform.apply_patterns to %3 { -# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract -# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns -# CHECK-NEXT: } : !transform.any_op -# CHECK-NEXT: transform.apply_patterns to %3 { -# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct -# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction -# CHECK-NEXT: } : !transform.any_op -# CHECK-NEXT: %4 = transform.structured.match attributes {"C/i0"} in %3 : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op -# CHECK-NEXT: transform.yield -# CHECK-NEXT: } -# CHECK-NEXT: } -# CHECK-NEXT: -# CHECK-NEXT: // -----// IR Dump After transform //----- // -# CHECK-NEXT: module attributes {transform.with_named_sequence} { -# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { -# CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> -# CHECK-NEXT: %c16 = arith.constant 16 : index -# CHECK-NEXT: %c2 = arith.constant 2 : index -# CHECK-NEXT: %c512 = arith.constant 512 : index -# CHECK-NEXT: %c32 = arith.constant 32 : index -# CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 -# CHECK-NEXT: %c0 = arith.constant 0 : index -# CHECK-NEXT: %c4 = arith.constant 4 : index -# CHECK-NEXT: %c1 = arith.constant 1 : index -# CHECK-NEXT: scf.for %arg3 = %c0 to %c4 step %c1 { -# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<4x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { -# CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> -# 
CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) -# CHECK-NEXT: } {"./j"} -# CHECK-NEXT: } {"./i"} -# CHECK-NEXT: scf.for %arg3 = %c0 to %c512 step %c1 { -# CHECK-NEXT: %subview = memref.subview %arg0[0, %arg3] [4, 1] [1, 1] : memref<4x512xf32> to memref<4x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_1 = memref.subview %arg1[%arg3, 0] [1, 32] [1, 1] : memref<512x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %subview_2 = memref.subview %arg2[0, 0] [4, 32] [1, 1] : memref<4x32xf32> to memref<4x32xf32, strided<[32, 1]>> -# CHECK-NEXT: scf.for %arg4 = %c0 to %c4 step %c2 { -# CHECK-NEXT: %subview_3 = memref.subview %subview[%arg4, 0] [2, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_4 = memref.subview %subview_2[%arg4, 0] [2, 32] [1, 1] : memref<4x32xf32, strided<[32, 1]>> to memref<2x32xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: scf.for %arg5 = %c0 to %c32 step %c16 { -# CHECK-NEXT: %subview_5 = memref.subview %subview_1[0, %arg5] [1, 16] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %subview_6 = memref.subview %subview_4[0, %arg5] [2, 16] [1, 1] : memref<2x32xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %c2_7 = arith.constant 2 : index -# CHECK-NEXT: %subview_8 = memref.subview %subview_3[%c0, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_9 = memref.subview %subview_6[%c0, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %0 = vector.transfer_read %subview_8[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -# 
CHECK-NEXT: %1 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %2 = vector.transfer_read %subview_9[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %3 = vector.extract %1[0] : vector<16xf32> from vector<1x16xf32> -# CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> -# CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<16xf32> -# CHECK-NEXT: %6 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> -# CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<16xf32> -# CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<16xf32> into vector<1x16xf32> -# CHECK-NEXT: vector.transfer_write %8, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %c1_10 = arith.constant 1 : index -# CHECK-NEXT: %9 = arith.muli %c1, %c1_10 : index -# CHECK-NEXT: %10 = arith.addi %c0, %9 : index -# CHECK-NEXT: %subview_11 = memref.subview %subview_3[%10, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_12 = memref.subview %subview_6[%10, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %11 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -# CHECK-NEXT: %12 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %13 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %14 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> -# 
CHECK-NEXT: %15 = vector.extract %11[0, 0] : f32 from vector<1x1xf32> -# CHECK-NEXT: %16 = vector.broadcast %15 : f32 to vector<16xf32> -# CHECK-NEXT: %17 = vector.extract %13[0] : vector<16xf32> from vector<1x16xf32> -# CHECK-NEXT: %18 = vector.fma %16, %14, %17 : vector<16xf32> -# CHECK-NEXT: %19 = vector.insert %18, %cst [0] : vector<16xf32> into vector<1x16xf32> -# CHECK-NEXT: vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: } {"C/j"} -# CHECK-NEXT: } {"C/i"} -# CHECK-NEXT: } {"C/k"} -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK-NEXT: } -# CHECK-NEXT: -# CHECK-NEXT: graph: -# CHECK-NEXT: name: matmul -# CHECK-NEXT: inputs: -# CHECK-NEXT: - %0 : 4x512xfloat32 -# CHECK-NEXT: - %1 : 512x32xfloat32 -# CHECK-NEXT: outputs: -# CHECK-NEXT: - %2 : 4x32xfloat32 -# CHECK-NEXT: nodes: -# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] -# CHECK-NEXT: -# CHECK-NEXT: CODE: 0 +#CHECK: // -----// IR Dump Before transform //----- // +#CHECK-NEXT: module attributes {transform.with_named_sequence} { +#CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<4x32xf32>) +#CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<4x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<4x32xf32>) +#CHECK-NEXT: return +#CHECK-NEXT: } +#CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +#CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +#CHECK-NEXT: transform.yield +#CHECK-NEXT: } +#CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +#CHECK-NEXT: %0 = transform.structured.match 
attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op +#CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_3 "C/k" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_5 "C/i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_7 "C/j" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_9 "C/i0" : !transform.any_op +#CHECK-NEXT: %3 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_8) : 
(!transform.any_op) -> () +#CHECK-NEXT: transform.apply_patterns to %3 { +#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %3 { +#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: %4 = transform.structured.match attributes {"C/i0"} in %3 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op +#CHECK-NEXT: transform.yield +#CHECK-NEXT: } +#CHECK-NEXT: } +#CHECK-NEXT: +#CHECK-NEXT: // -----// IR Dump After transform //----- // +#CHECK-NEXT: module attributes {transform.with_named_sequence} { +#CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> +#CHECK-NEXT: %c16 = arith.constant 16 : index +#CHECK-NEXT: %c2 = arith.constant 2 : index +#CHECK-NEXT: %c512 = arith.constant 512 : index +#CHECK-NEXT: %c32 = arith.constant 32 : index +#CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +#CHECK-NEXT: %c0 = arith.constant 0 : index +#CHECK-NEXT: %c4 = arith.constant 4 : index +#CHECK-NEXT: %c1 = arith.constant 1 : index +#CHECK-NEXT: scf.for %arg3 = %c0 to %c4 step %c1 { +#CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<4x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { +#CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 
1], offset: ?>>) +#CHECK-NEXT: } {"./j"} +#CHECK-NEXT: } {"./i"} +#CHECK-NEXT: scf.for %arg3 = %c0 to %c512 step %c1 { +#CHECK-NEXT: %subview = memref.subview %arg0[0, %arg3] [4, 1] [1, 1] : memref<4x512xf32> to memref<4x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_1 = memref.subview %arg1[%arg3, 0] [1, 32] [1, 1] : memref<512x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_2 = memref.subview %arg2[0, 0] [4, 32] [1, 1] : memref<4x32xf32> to memref<4x32xf32, strided<[32, 1]>> +#CHECK-NEXT: scf.for %arg4 = %c0 to %c4 step %c2 { +#CHECK-NEXT: %subview_3 = memref.subview %subview[%arg4, 0] [2, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_4 = memref.subview %subview_2[%arg4, 0] [2, 32] [1, 1] : memref<4x32xf32, strided<[32, 1]>> to memref<2x32xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg5 = %c0 to %c32 step %c16 { +#CHECK-NEXT: %subview_5 = memref.subview %subview_1[0, %arg5] [1, 16] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_6 = memref.subview %subview_4[0, %arg5] [2, 16] [1, 1] : memref<2x32xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %c2_7 = arith.constant 2 : index +#CHECK-NEXT: %subview_8 = memref.subview %subview_3[%c0, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_9 = memref.subview %subview_6[%c0, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %0 = vector.transfer_read %subview_8[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %1 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 
1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %2 = vector.transfer_read %subview_9[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %3 = vector.extract %1[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<16xf32> +#CHECK-NEXT: %6 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<16xf32> +#CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<16xf32> into vector<1x16xf32> +#CHECK-NEXT: vector.transfer_write %8, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %c1_10 = arith.constant 1 : index +#CHECK-NEXT: %9 = arith.muli %c1, %c1_10 : index +#CHECK-NEXT: %10 = arith.addi %c0, %9 : index +#CHECK-NEXT: %subview_11 = memref.subview %subview_3[%10, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_12 = memref.subview %subview_6[%10, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %11 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %13 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %14 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %15 = vector.extract %11[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %16 = vector.broadcast %15 : f32 to vector<16xf32> +#CHECK-NEXT: 
%17 = vector.extract %13[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %18 = vector.fma %16, %14, %17 : vector<16xf32> +#CHECK-NEXT: %19 = vector.insert %18, %cst [0] : vector<16xf32> into vector<1x16xf32> +#CHECK-NEXT: vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: } {"C/j"} +#CHECK-NEXT: } {"C/i"} +#CHECK-NEXT: } {"C/k"} +#CHECK-NEXT: return +#CHECK-NEXT: } +#CHECK-NEXT: } +#CHECK-NEXT: +#CHECK-NEXT: graph: +#CHECK-NEXT: name: matmul +#CHECK-NEXT: inputs: +#CHECK-NEXT: - %0 : 4x512xfloat32 +#CHECK-NEXT: - %1 : 512x32xfloat32 +#CHECK-NEXT: outputs: +#CHECK-NEXT: - %2 : 4x32xfloat32 +#CHECK-NEXT: nodes: +#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] +#CHECK-NEXT: +#CHECK-NEXT: CODE: 0 diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py index 21cd7a391..1bd14f14e 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py @@ -17,10 +17,12 @@ impl = Backend(graph, always_vectorize=False, no_alias=True) sch = impl.get_scheduler() +axes_sizes = {"i": I, "j": J, "k": K} descript_extend_scheduler( scheduler=sch, node_name="C", abstract_axis=["i", "j", "k"], + abstract_axis_sizes=axes_sizes, spec={ "DDR": { "j": {}, @@ -28,13 +30,13 @@ }, "L2": { "j#jDDR": {}, - "i[:iT1]": { + "i[:4]": { "R": { "i#iR1": {"unroll": None}, "j#jR": {"vectorize": None}, }, }, - "i[iT1:]": { + "i[4:]": { "R": { "i#iR2": {"unroll": None}, "j#jR": {"vectorize": None}, @@ -42,7 +44,7 @@ }, }, }, - sample={"jDDR": 16, "jR": 4, "iR1": 2, "iR2": 4, "iT1": 4}, + sample={"jDDR": 16, "jR": 4, "iR1": 2, "iR2": 4}, ) sched = sch.schedule() @@ -58,158 +60,204 @@ res = executor.execute() print(f"CODE: {res}") -# CHECK: // -----// IR Dump Before 
transform //----- // -# CHECK-NEXT: module attributes {transform.with_named_sequence} { -# CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { -# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<16x32xf32>) -# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<16x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<16x32xf32>) -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { -# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op -# CHECK-NEXT: transform.yield -# CHECK-NEXT: } -# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { -# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op -# CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_3 "C/j" : !transform.any_op -# CHECK-NEXT: 
%tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_5 "C/k" : !transform.any_op -# CHECK-NEXT: %first, %second = transform.structured.split %tiled_linalg_op_4 after 8 {dimension = 0 : i64} : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %first tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_7 "C/i[0]/i0" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/j0" : !transform.any_op -# CHECK-NEXT: %3 = transform.get_parent_op %loops_7 {isolated_from_above} : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: transform.apply_patterns to %3 { -# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract -# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns -# CHECK-NEXT: } : !transform.any_op -# CHECK-NEXT: transform.apply_patterns to %3 { -# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct -# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction -# CHECK-NEXT: } : !transform.any_op -# CHECK-NEXT: %4 = transform.structured.match attributes {"C/i[0]/i0"} in %3 : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: transform.loop.unroll %loops_7 {factor = 2 : i64} : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %second tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_11 "C/i[1]/i0" : !transform.any_op -# CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_10) : 
(!transform.any_op) -> () -# CHECK-NEXT: %5 = transform.get_parent_op %loops_11 {isolated_from_above} : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: transform.apply_patterns to %5 { -# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract -# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns -# CHECK-NEXT: } : !transform.any_op -# CHECK-NEXT: transform.apply_patterns to %5 { -# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct -# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction -# CHECK-NEXT: } : !transform.any_op -# CHECK-NEXT: %6 = transform.structured.match attributes {"C/i[1]/i0"} in %5 : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: transform.loop.unroll %loops_11 {factor = 2 : i64} : !transform.any_op -# CHECK-NEXT: %7 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: transform.apply_patterns to %7 { -# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract -# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns -# CHECK-NEXT: } : !transform.any_op -# CHECK-NEXT: transform.apply_patterns to %7 { -# CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct -# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction -# CHECK-NEXT: } : !transform.any_op -# CHECK-NEXT: transform.yield -# CHECK-NEXT: } -# CHECK-NEXT: } -# CHECK-NEXT: -# CHECK-NEXT: // -----// IR Dump After transform //----- // -# CHECK-NEXT: module attributes {transform.with_named_sequence} { -# CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { -# CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> -# CHECK-NEXT: %c4 = arith.constant 4 : index -# CHECK-NEXT: %c2 = arith.constant 2 : index -# CHECK-NEXT: %c8 = arith.constant 8 : index -# CHECK-NEXT: %c512 = arith.constant 512 : index -# 
CHECK-NEXT: %c32 = arith.constant 32 : index -# CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 -# CHECK-NEXT: %c0 = arith.constant 0 : index -# CHECK-NEXT: %c16 = arith.constant 16 : index -# CHECK-NEXT: %c1 = arith.constant 1 : index -# CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c1 { -# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<16x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { -# CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) -# CHECK-NEXT: } {"./j"} -# CHECK-NEXT: } {"./i"} -# CHECK-NEXT: scf.for %arg3 = %c0 to %c32 step %c16 { -# CHECK-NEXT: %subview = memref.subview %arg0[0, 0] [16, 512] [1, 1] : memref<16x512xf32> to memref<16x512xf32, strided<[512, 1]>> -# CHECK-NEXT: %subview_1 = memref.subview %arg1[0, %arg3] [512, 16] [1, 1] : memref<512x32xf32> to memref<512x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %subview_2 = memref.subview %arg2[0, %arg3] [16, 16] [1, 1] : memref<16x32xf32> to memref<16x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: scf.for %arg4 = %c0 to %c512 step %c1 { -# CHECK-NEXT: %subview_3 = memref.subview %subview[0, %arg4] [16, 1] [1, 1] : memref<16x512xf32, strided<[512, 1]>> to memref<16x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_4 = memref.subview %subview_1[%arg4, 0] [1, 16] [1, 1] : memref<512x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %subview_5 = memref.subview %subview_3[0, 0] [8, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<8x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_6 = memref.subview %subview_2[0, 0] [8, 16] [1, 1] : memref<16x16xf32, strided<[32, 
1], offset: ?>> to memref<8x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: scf.for %arg5 = %c0 to %c8 step %c4 { -# CHECK-NEXT: %subview_9 = memref.subview %subview_5[%arg5, 0] [2, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_10 = memref.subview %subview_6[%arg5, 0] [2, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: scf.for %arg6 = %c0 to %c16 step %c16 { -# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_9, %subview_4 : memref<2x1xf32, strided<[512, 1], offset: ?>>, memref<1x16xf32, strided<[32, 1], offset: ?>>) outs(%subview_10 : memref<2x16xf32, strided<[32, 1], offset: ?>>) -# CHECK-NEXT: } {"C/i[0]/j0"} -# CHECK-NEXT: %0 = arith.addi %arg5, %c2 : index -# CHECK-NEXT: %subview_11 = memref.subview %subview_5[%0, 0] [2, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_12 = memref.subview %subview_6[%0, 0] [2, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: scf.for %arg6 = %c0 to %c16 step %c16 { -# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_11, %subview_4 : memref<2x1xf32, strided<[512, 1], offset: ?>>, memref<1x16xf32, strided<[32, 1], offset: ?>>) outs(%subview_12 : memref<2x16xf32, strided<[32, 1], offset: ?>>) -# CHECK-NEXT: } {"C/i[0]/j0"} -# CHECK-NEXT: } {"C/i[0]/i0"} -# CHECK-NEXT: %subview_7 = memref.subview %subview_3[8, 0] [8, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<8x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_8 = memref.subview %subview_2[8, 0] [8, 16] [1, 1] : memref<16x16xf32, strided<[32, 1], offset: ?>> to memref<8x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: scf.for %arg5 = %c0 to %c8 step %c2 { -# CHECK-NEXT: %subview_9 = memref.subview %subview_7[%arg5, 0] 
[1, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_10 = memref.subview %subview_8[%arg5, 0] [1, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %0 = vector.transfer_read %subview_9[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -# CHECK-NEXT: %1 = vector.transfer_read %subview_4[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %2 = vector.transfer_read %subview_10[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %3 = vector.extract %1[0] : vector<16xf32> from vector<1x16xf32> -# CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> -# CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<16xf32> -# CHECK-NEXT: %6 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> -# CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<16xf32> -# CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<16xf32> into vector<1x16xf32> -# CHECK-NEXT: vector.transfer_write %8, %subview_10[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %9 = arith.addi %arg5, %c1 : index -# CHECK-NEXT: %subview_11 = memref.subview %subview_7[%9, 0] [1, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_12 = memref.subview %subview_8[%9, 0] [1, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %10 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -# CHECK-NEXT: %11 = vector.transfer_read %subview_4[%c0, %c0], 
%cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %12 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %13 = vector.extract %11[0] : vector<16xf32> from vector<1x16xf32> -# CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> -# CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<16xf32> -# CHECK-NEXT: %16 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> -# CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<16xf32> -# CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<16xf32> into vector<1x16xf32> -# CHECK-NEXT: vector.transfer_write %18, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: } {"C/i[1]/i0"} -# CHECK-NEXT: } {"C/k"} -# CHECK-NEXT: } {"C/j"} -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK-NEXT: } -# CHECK-NEXT: -# CHECK-NEXT: graph: -# CHECK-NEXT: name: matmul -# CHECK-NEXT: inputs: -# CHECK-NEXT: - %0 : 16x512xfloat32 -# CHECK-NEXT: - %1 : 512x32xfloat32 -# CHECK-NEXT: outputs: -# CHECK-NEXT: - %2 : 16x32xfloat32 -# CHECK-NEXT: nodes: -# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [16x512xfloat32, 512x32xfloat32] -> [16x32xfloat32] -# CHECK-NEXT: -# CHECK-NEXT: CODE: 0 +#CHECK: // -----// IR Dump Before transform //----- // +#CHECK-NEXT: module attributes {transform.with_named_sequence} { +#CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<16x32xf32>) +#CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<16x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<16x32xf32>) +#CHECK-NEXT: return +#CHECK-NEXT: } 
+#CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +#CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +#CHECK-NEXT: transform.yield +#CHECK-NEXT: } +#CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +#CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op +#CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_3 "C/j" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_5 "C/k" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [0, 4, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_7 "C/j0" : !transform.any_op +#CHECK-NEXT: %first, %second = transform.structured.split %tiled_linalg_op_6 after 4 {dimension 
= 0 : i64} : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %first tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/i0" : !transform.any_op +#CHECK-NEXT: %3 = transform.get_parent_op %loops_9 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_8) : (!transform.any_op) -> () +#CHECK-NEXT: transform.apply_patterns to %3 { +#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %3 { +#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: %4 = transform.structured.match attributes {"C/i[0]/i0"} in %3 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %second tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_11 "C/i[1]/i0" : !transform.any_op +#CHECK-NEXT: %5 = transform.get_parent_op %loops_11 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_10) : (!transform.any_op) -> () +#CHECK-NEXT: transform.apply_patterns to %5 { +#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %5 { +#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +#CHECK-NEXT: 
transform.apply_patterns.vector.lower_contraction +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: %6 = transform.structured.match attributes {"C/i[1]/i0"} in %5 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.loop.unroll %loops_11 {factor = 4 : i64} : !transform.any_op +#CHECK-NEXT: %7 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %7 { +#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %7 { +#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.yield +#CHECK-NEXT: } +#CHECK-NEXT: } +#CHECK-NEXT: +#CHECK-NEXT: // -----// IR Dump After transform //----- // +#CHECK-NEXT: module attributes {transform.with_named_sequence} { +#CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %c3 = arith.constant 3 : index +#CHECK-NEXT: %c12 = arith.constant 12 : index +#CHECK-NEXT: %c2 = arith.constant 2 : index +#CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x4xf32> +#CHECK-NEXT: %c4 = arith.constant 4 : index +#CHECK-NEXT: %c512 = arith.constant 512 : index +#CHECK-NEXT: %c32 = arith.constant 32 : index +#CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +#CHECK-NEXT: %c0 = arith.constant 0 : index +#CHECK-NEXT: %c16 = arith.constant 16 : index +#CHECK-NEXT: %c1 = arith.constant 1 : index +#CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c1 { +#CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<16x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { 
+#CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) +#CHECK-NEXT: } {"./j"} +#CHECK-NEXT: } {"./i"} +#CHECK-NEXT: scf.for %arg3 = %c0 to %c32 step %c16 { +#CHECK-NEXT: %subview = memref.subview %arg0[0, 0] [16, 512] [1, 1] : memref<16x512xf32> to memref<16x512xf32, strided<[512, 1]>> +#CHECK-NEXT: %subview_1 = memref.subview %arg1[0, %arg3] [512, 16] [1, 1] : memref<512x32xf32> to memref<512x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_2 = memref.subview %arg2[0, %arg3] [16, 16] [1, 1] : memref<16x32xf32> to memref<16x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg4 = %c0 to %c512 step %c1 { +#CHECK-NEXT: %subview_3 = memref.subview %subview[0, %arg4] [16, 1] [1, 1] : memref<16x512xf32, strided<[512, 1]>> to memref<16x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_4 = memref.subview %subview_1[%arg4, 0] [1, 16] [1, 1] : memref<512x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg5 = %c0 to %c16 step %c4 { +#CHECK-NEXT: %subview_5 = memref.subview %subview_4[0, %arg5] [1, 4] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_6 = memref.subview %subview_2[0, %arg5] [16, 4] [1, 1] : memref<16x16xf32, strided<[32, 1], offset: ?>> to memref<16x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_7 = memref.subview %subview_3[0, 0] [4, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<4x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_8 = memref.subview %subview_6[0, 0] [4, 4] [1, 1] : memref<16x4xf32, strided<[32, 1], offset: ?>> to memref<4x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg6 = %c0 to %c4 
step %c2 { +#CHECK-NEXT: %subview_11 = memref.subview %subview_7[%arg6, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_12 = memref.subview %subview_8[%arg6, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %0 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %1 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %2 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %3 = vector.extract %1[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<4xf32> +#CHECK-NEXT: %6 = vector.extract %2[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<4xf32> +#CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %8, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %9 = arith.addi %arg6, %c1 : index +#CHECK-NEXT: %subview_13 = memref.subview %subview_7[%9, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_14 = memref.subview %subview_8[%9, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = vector.transfer_read %subview_13[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %11 = 
vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_14[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %13 = vector.extract %11[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<4xf32> +#CHECK-NEXT: %16 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<4xf32> +#CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %18, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: } {"C/i[0]/i0"} +#CHECK-NEXT: %subview_9 = memref.subview %subview_3[4, 0] [12, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<12x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_10 = memref.subview %subview_6[4, 0] [12, 4] [1, 1] : memref<16x4xf32, strided<[32, 1], offset: ?>> to memref<12x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg6 = %c0 to %c12 step %c4 { +#CHECK-NEXT: %subview_11 = memref.subview %subview_9[%arg6, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_12 = memref.subview %subview_10[%arg6, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %0 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %1 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> 
+#CHECK-NEXT: %2 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %3 = vector.extract %1[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<4xf32> +#CHECK-NEXT: %6 = vector.extract %2[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<4xf32> +#CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %8, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %9 = arith.addi %arg6, %c1 : index +#CHECK-NEXT: %subview_13 = memref.subview %subview_9[%9, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_14 = memref.subview %subview_10[%9, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = vector.transfer_read %subview_13[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %11 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_14[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %13 = vector.extract %11[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<4xf32> +#CHECK-NEXT: %16 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<4xf32> +#CHECK-NEXT: %18 = 
vector.insert %17, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %18, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %19 = arith.addi %arg6, %c2 : index +#CHECK-NEXT: %subview_15 = memref.subview %subview_9[%19, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_16 = memref.subview %subview_10[%19, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %20 = vector.transfer_read %subview_15[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %21 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %22 = vector.transfer_read %subview_16[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %23 = vector.extract %21[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %24 = vector.extract %20[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %25 = vector.broadcast %24 : f32 to vector<4xf32> +#CHECK-NEXT: %26 = vector.extract %22[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %27 = vector.fma %25, %23, %26 : vector<4xf32> +#CHECK-NEXT: %28 = vector.insert %27, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %28, %subview_16[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %29 = arith.addi %arg6, %c3 : index +#CHECK-NEXT: %subview_17 = memref.subview %subview_9[%29, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_18 = memref.subview %subview_10[%29, 0] [1, 4] [1, 1] : 
memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %30 = vector.transfer_read %subview_17[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %31 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %32 = vector.transfer_read %subview_18[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %33 = vector.extract %31[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %34 = vector.extract %30[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %35 = vector.broadcast %34 : f32 to vector<4xf32> +#CHECK-NEXT: %36 = vector.extract %32[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %37 = vector.fma %35, %33, %36 : vector<4xf32> +#CHECK-NEXT: %38 = vector.insert %37, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %38, %subview_18[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: } {"C/i[1]/i0"} +#CHECK-NEXT: } {"C/j0"} +#CHECK-NEXT: } {"C/k"} +#CHECK-NEXT: } {"C/j"} +#CHECK-NEXT: return +#CHECK-NEXT: } +#CHECK-NEXT: } +#CHECK-NEXT: +#CHECK-NEXT: graph: +#CHECK-NEXT: name: matmul +#CHECK-NEXT: inputs: +#CHECK-NEXT: - %0 : 16x512xfloat32 +#CHECK-NEXT: - %1 : 512x32xfloat32 +#CHECK-NEXT: outputs: +#CHECK-NEXT: - %2 : 16x32xfloat32 +#CHECK-NEXT: nodes: +#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [16x512xfloat32, 512x32xfloat32] -> [16x32xfloat32] +#CHECK-NEXT: +#CHECK-NEXT: CODE: 0 diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split_sample.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split_sample.py deleted file mode 100644 index db11dfa29..000000000 --- 
a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split_sample.py +++ /dev/null @@ -1,212 +0,0 @@ -# RUN: python %s 2>&1 | filecheck %s - -import xtc.graphs.xtc.op as O -from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend -from xtc.schedules.descript_extend import descript_extend_scheduler - -I, J, K, dtype = 16, 32, 512, "float32" -a = O.tensor((I, K), dtype, name="A") -b = O.tensor((K, J), dtype, name="B") - -with O.graph(name="matmul") as gb: - O.matmul(a, b, name="C") - -graph = gb.graph -print(graph) - -impl = Backend(graph, always_vectorize=False, no_alias=True) - -sch = impl.get_scheduler() -descript_extend_scheduler( - scheduler=sch, - node_name="C", - abstract_axis=["i", "j", "k"], - spec={ - "DDR": { - "j": {}, - "k": {}, - "i[:i_split]": { - "Rr": { - "i#2": {"unroll": None}, - "j#16": {"vectorize": None}, - }, - }, - "i[i_split:]": { - "Rl": { - "i#2": {"unroll": None}, - "j#16": {"vectorize": None}, - }, - }, - }, - }, - sample={"i_split": 8}, -) - -sched = sch.schedule() - -comp = impl.get_compiler( - shared_lib=True, - dump_file="matmul_descript_extend_mlir_split_sample", - print_source_ir=True, - print_transformed_ir=True, -) -module = comp.compile(sched) -executor = module.get_executor(validate=True) -res = executor.execute() -print(f"CODE: {res}") - -#CHECK: // -----// IR Dump Before transform //----- // -#CHECK-NEXT: module attributes {transform.with_named_sequence} { -#CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { -#CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<16x32xf32>) -#CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<16x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<16x32xf32>) -#CHECK-NEXT: return -#CHECK-NEXT: } -#CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op 
{transform.consumed}) { -#CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op -#CHECK-NEXT: transform.yield -#CHECK-NEXT: } -#CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { -#CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -#CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op -#CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -#CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op -#CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -#CHECK-NEXT: transform.annotate %loops_3 "C/j" : !transform.any_op -#CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -#CHECK-NEXT: transform.annotate %loops_5 "C/k" : !transform.any_op -#CHECK-NEXT: %first, %second = transform.structured.split %tiled_linalg_op_4 after 8 {dimension = 0 : i64} : !transform.any_op -#CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %first tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -#CHECK-NEXT: transform.annotate %loops_7 "C/i[0]/i0" : !transform.any_op -#CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = 
transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -#CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/j0" : !transform.any_op -#CHECK-NEXT: %3 = transform.get_parent_op %loops_7 {isolated_from_above} : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %3 { -#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract -#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns -#CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %3 { -#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct -#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction -#CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: %4 = transform.structured.match attributes {"C/i[0]/i0"} in %3 : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: transform.loop.unroll %loops_7 {factor = 2 : i64} : !transform.any_op -#CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %second tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -#CHECK-NEXT: transform.annotate %loops_11 "C/i[1]/i0" : !transform.any_op -#CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_10) : (!transform.any_op) -> () -#CHECK-NEXT: %5 = transform.get_parent_op %loops_11 {isolated_from_above} : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %5 { -#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract -#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns -#CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %5 { -#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct -#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction -#CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: %6 = transform.structured.match attributes {"C/i[1]/i0"} in %5 : (!transform.any_op) -> 
!transform.any_op -#CHECK-NEXT: transform.loop.unroll %loops_11 {factor = 2 : i64} : !transform.any_op -#CHECK-NEXT: %7 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %7 { -#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract -#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns -#CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %7 { -#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct -#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction -#CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: transform.yield -#CHECK-NEXT: } -#CHECK-NEXT: } -#CHECK-NEXT: -#CHECK-NEXT: // -----// IR Dump After transform //----- // -#CHECK-NEXT: module attributes {transform.with_named_sequence} { -#CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { -#CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> -#CHECK-NEXT: %c4 = arith.constant 4 : index -#CHECK-NEXT: %c2 = arith.constant 2 : index -#CHECK-NEXT: %c8 = arith.constant 8 : index -#CHECK-NEXT: %c512 = arith.constant 512 : index -#CHECK-NEXT: %c32 = arith.constant 32 : index -#CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 -#CHECK-NEXT: %c0 = arith.constant 0 : index -#CHECK-NEXT: %c16 = arith.constant 16 : index -#CHECK-NEXT: %c1 = arith.constant 1 : index -#CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c1 { -#CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<16x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { -#CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) 
outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) -#CHECK-NEXT: } {"./j"} -#CHECK-NEXT: } {"./i"} -#CHECK-NEXT: scf.for %arg3 = %c0 to %c32 step %c16 { -#CHECK-NEXT: %subview = memref.subview %arg0[0, 0] [16, 512] [1, 1] : memref<16x512xf32> to memref<16x512xf32, strided<[512, 1]>> -#CHECK-NEXT: %subview_1 = memref.subview %arg1[0, %arg3] [512, 16] [1, 1] : memref<512x32xf32> to memref<512x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %subview_2 = memref.subview %arg2[0, %arg3] [16, 16] [1, 1] : memref<16x32xf32> to memref<16x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: scf.for %arg4 = %c0 to %c512 step %c1 { -#CHECK-NEXT: %subview_3 = memref.subview %subview[0, %arg4] [16, 1] [1, 1] : memref<16x512xf32, strided<[512, 1]>> to memref<16x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_4 = memref.subview %subview_1[%arg4, 0] [1, 16] [1, 1] : memref<512x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %subview_5 = memref.subview %subview_3[0, 0] [8, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<8x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_6 = memref.subview %subview_2[0, 0] [8, 16] [1, 1] : memref<16x16xf32, strided<[32, 1], offset: ?>> to memref<8x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: scf.for %arg5 = %c0 to %c8 step %c4 { -#CHECK-NEXT: %subview_9 = memref.subview %subview_5[%arg5, 0] [2, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_10 = memref.subview %subview_6[%arg5, 0] [2, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: scf.for %arg6 = %c0 to %c16 step %c16 { -#CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_9, %subview_4 : memref<2x1xf32, strided<[512, 1], offset: ?>>, memref<1x16xf32, strided<[32, 1], offset: ?>>) outs(%subview_10 : memref<2x16xf32, strided<[32, 
1], offset: ?>>) -#CHECK-NEXT: } {"C/i[0]/j0"} -#CHECK-NEXT: %0 = arith.addi %arg5, %c2 : index -#CHECK-NEXT: %subview_11 = memref.subview %subview_5[%0, 0] [2, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_12 = memref.subview %subview_6[%0, 0] [2, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: scf.for %arg6 = %c0 to %c16 step %c16 { -#CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_11, %subview_4 : memref<2x1xf32, strided<[512, 1], offset: ?>>, memref<1x16xf32, strided<[32, 1], offset: ?>>) outs(%subview_12 : memref<2x16xf32, strided<[32, 1], offset: ?>>) -#CHECK-NEXT: } {"C/i[0]/j0"} -#CHECK-NEXT: } {"C/i[0]/i0"} -#CHECK-NEXT: %subview_7 = memref.subview %subview_3[8, 0] [8, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<8x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_8 = memref.subview %subview_2[8, 0] [8, 16] [1, 1] : memref<16x16xf32, strided<[32, 1], offset: ?>> to memref<8x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: scf.for %arg5 = %c0 to %c8 step %c2 { -#CHECK-NEXT: %subview_9 = memref.subview %subview_7[%arg5, 0] [1, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_10 = memref.subview %subview_8[%arg5, 0] [1, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %0 = vector.transfer_read %subview_9[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %1 = vector.transfer_read %subview_4[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -#CHECK-NEXT: %2 = vector.transfer_read %subview_10[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: 
?>>, vector<1x16xf32> -#CHECK-NEXT: %3 = vector.extract %1[0] : vector<16xf32> from vector<1x16xf32> -#CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<16xf32> -#CHECK-NEXT: %6 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> -#CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<16xf32> -#CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<16xf32> into vector<1x16xf32> -#CHECK-NEXT: vector.transfer_write %8, %subview_10[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %9 = arith.addi %arg5, %c1 : index -#CHECK-NEXT: %subview_11 = memref.subview %subview_7[%9, 0] [1, 1] [1, 1] : memref<8x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_12 = memref.subview %subview_8[%9, 0] [1, 16] [1, 1] : memref<8x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %10 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %11 = vector.transfer_read %subview_4[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -#CHECK-NEXT: %12 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -#CHECK-NEXT: %13 = vector.extract %11[0] : vector<16xf32> from vector<1x16xf32> -#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<16xf32> -#CHECK-NEXT: %16 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> -#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<16xf32> -#CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<16xf32> into vector<1x16xf32> -#CHECK-NEXT: vector.transfer_write %18, 
%subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: } {"C/i[1]/i0"} -#CHECK-NEXT: } {"C/k"} -#CHECK-NEXT: } {"C/j"} -#CHECK-NEXT: return -#CHECK-NEXT: } -#CHECK-NEXT: } -#CHECK-NEXT: -#CHECK-NEXT: graph: -#CHECK-NEXT: name: matmul -#CHECK-NEXT: inputs: -#CHECK-NEXT: - %0 : 16x512xfloat32 -#CHECK-NEXT: - %1 : 512x32xfloat32 -#CHECK-NEXT: outputs: -#CHECK-NEXT: - %2 : 16x32xfloat32 -#CHECK-NEXT: nodes: -#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [16x512xfloat32, 512x32xfloat32] -> [16x32xfloat32] -#CHECK-NEXT: -#CHECK-NEXT: CODE: 0 diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py index 2c6b57231..d9644b627 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py @@ -18,10 +18,12 @@ impl = Backend(graph, always_vectorize=False, no_alias=True) sch = impl.get_scheduler() +axes_sizes = {"i": I, "j": J, "k": K} descript_extend_scheduler( scheduler=sch, node_name="C", abstract_axis=["i", "j", "k"], + abstract_axis_sizes=axes_sizes, spec={ "DDR": { "j": {"parallelize": "par"}, @@ -43,12 +45,19 @@ "L1": { "k#kL1": {"unroll": "k_unroll"}, }, - "R": { - "i#iR": {"unroll": None}, - "j#jR": {"vectorize": None} - }, + "R": {"i#iR": {"unroll": None}, "j#jR": {"vectorize": None}}, + }, + sample={ + "par": None, + "jL3": 36, + "iL2": 128, + "kL1": 16, + "k_unroll": 2, + "iR": 2, + "jR": 6, + "pack_B": None, + "pack_A": None, }, - sample={"par": None, "jL3": 36, "iL2": 128, "kL1": 16, "k_unroll": 2, "iR": 2, "jR": 6, "pack_B": None, "pack_A": None}, ) sched = sch.schedule() @@ -64,99 +73,99 @@ res = executor.execute() print(f"CODE: {res}") -#CHECK: graph: -#CHECK-NEXT: name: matmul -#CHECK-NEXT: inputs: -#CHECK-NEXT: - %0 : 512x512xfloat32 -#CHECK-NEXT: - %1 : 512x512xfloat32 -#CHECK-NEXT: outputs: -#CHECK-NEXT: - 
%2 : 512x512xfloat32 -#CHECK-NEXT: nodes: -#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [512x512xfloat32, 512x512xfloat32] -> [512x512xfloat32] -#CHECK-NEXT: -#CHECK-NEXT:# from tvm.script import ir as I -#CHECK-NEXT:# from tvm.script import tir as T -#CHECK-NEXT: -#CHECK-NEXT:@I.ir_module -#CHECK-NEXT:class Module: -#CHECK-NEXT: @T.prim_func -#CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): -#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) -#CHECK-NEXT: for i, j in T.grid(512, 512): -#CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) -#CHECK-NEXT: C_1[i * 512 + j] = T.float32(0.0) -#CHECK-NEXT: for k in range(512): -#CHECK-NEXT: cse_var_2: T.int32 = i * 512 -#CHECK-NEXT: cse_var_1: T.int32 = cse_var_2 + j -#CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) -#CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) -#CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_1[cse_var_2 + k] * _1_1[k * 512 + j] -#CHECK-NEXT:INPS = list(obj.values())[:-1] -#CHECK-NEXT:O = obj['C'] -#CHECK-NEXT:I_R0 = sch.cache_read(INPS[0], "local", [O]) -#CHECK-NEXT:i, j, = O.op.axis -#CHECK-NEXT:k, = O.op.reduce_axis -#CHECK-NEXT:j, j0 = sch[O].split(j, factor=36) -#CHECK-NEXT:i, i0 = sch[O].split(i, factor=128) -#CHECK-NEXT:k, k0 = sch[O].split(k, factor=16) -#CHECK-NEXT:k0, __u_k0 = sch[O].split(k0, factor=2) -#CHECK-NEXT:i0, i1 = sch[O].split(i0, factor=2) -#CHECK-NEXT:j0, j1 = sch[O].split(j0, factor=6) -#CHECK-NEXT:sch[O].reorder(j, k, i, j0, i0, k0, __u_k0, i1, j1) -#CHECK-NEXT:sch[I_R0].compute_at(sch[O], i) -#CHECK-NEXT:sch[I_R0].storage_align(I_R0.op.axis[-2], factor=1024, offset=16) -#CHECK-NEXT:sch[O].unroll(__u_k0) -#CHECK-NEXT:sch[O].unroll(i1) -#CHECK-NEXT:sch[O].vectorize(j1) -#CHECK-NEXT:sch[O].parallel(j) -#CHECK-NEXT: -#CHECK-NEXT:# from tvm.script import ir as I -#CHECK-NEXT:# from tvm.script import tir as T -#CHECK-NEXT: 
-#CHECK-NEXT:@I.ir_module -#CHECK-NEXT:class Module: -#CHECK-NEXT: @T.prim_func -#CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): -#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) -#CHECK-NEXT: for j_outer in T.parallel(15): -#CHECK-NEXT: _0_local = T.allocate([2048], "float32", "local") -#CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) -#CHECK-NEXT: for i_outer_init, j_inner_outer_init, i_inner_outer_init in T.grid(4, 6, 64): -#CHECK-NEXT: for j_inner_inner_init_s in range(6): -#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + j_inner_inner_init_s // 2) // 2 < 128): -#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + j_inner_inner_init_s] = T.float32(0.0) -#CHECK-NEXT: for j_inner_inner_init_s in range(6): -#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + j_inner_inner_init_s // 2) // 2 < 128): -#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + j_inner_inner_init_s + 512] = T.float32(0.0) -#CHECK-NEXT: for k_outer, i_outer in T.grid(32, 4): -#CHECK-NEXT: _0_local_1 = T.Buffer((2048,), data=_0_local, scope="local") -#CHECK-NEXT: for ax0, ax1 in T.grid(128, 16): -#CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) -#CHECK-NEXT: _0_local_1[ax0 * 16 + ax1] = _0_1[i_outer * 65536 + ax0 * 512 + k_outer * 16 + ax1] -#CHECK-NEXT: for j_inner_outer, i_inner_outer, k_inner_outer in T.grid(6, 64, 8): -#CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) -#CHECK-NEXT: for j_inner_inner_s in range(6): -#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): -#CHECK-NEXT: cse_var_3: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_2: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_1: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_3 + cse_var_2 + 
j_inner_inner_s -#CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_3 + cse_var_2 + j_inner_inner_s] -#CHECK-NEXT: for j_inner_inner_s in range(6): -#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): -#CHECK-NEXT: cse_var_6: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_5: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_4: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_6 + cse_var_5 + j_inner_inner_s + 512 -#CHECK-NEXT: C_1[cse_var_4] = C_1[cse_var_4] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 16] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_6 + cse_var_5 + j_inner_inner_s] -#CHECK-NEXT: for j_inner_inner_s in range(6): -#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): -#CHECK-NEXT: cse_var_9: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_8: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_7: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_9 + cse_var_8 + j_inner_inner_s -#CHECK-NEXT: C_1[cse_var_7] = C_1[cse_var_7] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 1] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_9 + cse_var_8 + j_inner_inner_s + 512] -#CHECK-NEXT: for j_inner_inner_s in range(6): -#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): -#CHECK-NEXT: cse_var_12: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_11: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_10: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_12 + cse_var_11 + j_inner_inner_s + 512 -#CHECK-NEXT: C_1[cse_var_10] = C_1[cse_var_10] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 17] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11 + j_inner_inner_s + 512] -#CHECK-NEXT:CODE: 0 +# CHECK: graph: +# CHECK-NEXT: name: matmul +# CHECK-NEXT: inputs: +# CHECK-NEXT: - %0 
: 512x512xfloat32 +# CHECK-NEXT: - %1 : 512x512xfloat32 +# CHECK-NEXT: outputs: +# CHECK-NEXT: - %2 : 512x512xfloat32 +# CHECK-NEXT: nodes: +# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [512x512xfloat32, 512x512xfloat32] -> [512x512xfloat32] +# CHECK-NEXT: +# CHECK-NEXT:# from tvm.script import ir as I +# CHECK-NEXT:# from tvm.script import tir as T +# CHECK-NEXT: +# CHECK-NEXT:@I.ir_module +# CHECK-NEXT:class Module: +# CHECK-NEXT: @T.prim_func +# CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): +# CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +# CHECK-NEXT: for i, j in T.grid(512, 512): +# CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) +# CHECK-NEXT: C_1[i * 512 + j] = T.float32(0.0) +# CHECK-NEXT: for k in range(512): +# CHECK-NEXT: cse_var_2: T.int32 = i * 512 +# CHECK-NEXT: cse_var_1: T.int32 = cse_var_2 + j +# CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) +# CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) +# CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_1[cse_var_2 + k] * _1_1[k * 512 + j] +# CHECK-NEXT:INPS = list(obj.values())[:-1] +# CHECK-NEXT:O = obj['C'] +# CHECK-NEXT:I_R0 = sch.cache_read(INPS[0], "local", [O]) +# CHECK-NEXT:i, j, = O.op.axis +# CHECK-NEXT:k, = O.op.reduce_axis +# CHECK-NEXT:j, j0 = sch[O].split(j, factor=36) +# CHECK-NEXT:i, i0 = sch[O].split(i, factor=128) +# CHECK-NEXT:k, k0 = sch[O].split(k, factor=16) +# CHECK-NEXT:k0, __u_k0 = sch[O].split(k0, factor=2) +# CHECK-NEXT:i0, i1 = sch[O].split(i0, factor=2) +# CHECK-NEXT:j0, j1 = sch[O].split(j0, factor=6) +# CHECK-NEXT:sch[O].reorder(j, k, i, j0, i0, k0, __u_k0, i1, j1) +# CHECK-NEXT:sch[I_R0].compute_at(sch[O], i) +# CHECK-NEXT:sch[I_R0].storage_align(I_R0.op.axis[-2], factor=1024, offset=16) +# CHECK-NEXT:sch[O].unroll(__u_k0) +# CHECK-NEXT:sch[O].unroll(i1) +# CHECK-NEXT:sch[O].vectorize(j1) +# CHECK-NEXT:sch[O].parallel(j) +# 
CHECK-NEXT: +# CHECK-NEXT:# from tvm.script import ir as I +# CHECK-NEXT:# from tvm.script import tir as T +# CHECK-NEXT: +# CHECK-NEXT:@I.ir_module +# CHECK-NEXT:class Module: +# CHECK-NEXT: @T.prim_func +# CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): +# CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +# CHECK-NEXT: for j_outer in T.parallel(15): +# CHECK-NEXT: _0_local = T.allocate([2048], "float32", "local") +# CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) +# CHECK-NEXT: for i_outer_init, j_inner_outer_init, i_inner_outer_init in T.grid(4, 6, 64): +# CHECK-NEXT: for j_inner_inner_init_s in range(6): +# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + j_inner_inner_init_s // 2) // 2 < 128): +# CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + j_inner_inner_init_s] = T.float32(0.0) +# CHECK-NEXT: for j_inner_inner_init_s in range(6): +# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + j_inner_inner_init_s // 2) // 2 < 128): +# CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + j_inner_inner_init_s + 512] = T.float32(0.0) +# CHECK-NEXT: for k_outer, i_outer in T.grid(32, 4): +# CHECK-NEXT: _0_local_1 = T.Buffer((2048,), data=_0_local, scope="local") +# CHECK-NEXT: for ax0, ax1 in T.grid(128, 16): +# CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) +# CHECK-NEXT: _0_local_1[ax0 * 16 + ax1] = _0_1[i_outer * 65536 + ax0 * 512 + k_outer * 16 + ax1] +# CHECK-NEXT: for j_inner_outer, i_inner_outer, k_inner_outer in T.grid(6, 64, 8): +# CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) +# CHECK-NEXT: for j_inner_inner_s in range(6): +# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): +# CHECK-NEXT: cse_var_3: T.int32 = j_outer * 36 +# 
CHECK-NEXT: cse_var_2: T.int32 = j_inner_outer * 6 +# CHECK-NEXT: cse_var_1: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_3 + cse_var_2 + j_inner_inner_s +# CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_3 + cse_var_2 + j_inner_inner_s] +# CHECK-NEXT: for j_inner_inner_s in range(6): +# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): +# CHECK-NEXT: cse_var_6: T.int32 = j_outer * 36 +# CHECK-NEXT: cse_var_5: T.int32 = j_inner_outer * 6 +# CHECK-NEXT: cse_var_4: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_6 + cse_var_5 + j_inner_inner_s + 512 +# CHECK-NEXT: C_1[cse_var_4] = C_1[cse_var_4] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 16] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_6 + cse_var_5 + j_inner_inner_s] +# CHECK-NEXT: for j_inner_inner_s in range(6): +# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): +# CHECK-NEXT: cse_var_9: T.int32 = j_outer * 36 +# CHECK-NEXT: cse_var_8: T.int32 = j_inner_outer * 6 +# CHECK-NEXT: cse_var_7: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_9 + cse_var_8 + j_inner_inner_s +# CHECK-NEXT: C_1[cse_var_7] = C_1[cse_var_7] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 1] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_9 + cse_var_8 + j_inner_inner_s + 512] +# CHECK-NEXT: for j_inner_inner_s in range(6): +# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): +# CHECK-NEXT: cse_var_12: T.int32 = j_outer * 36 +# CHECK-NEXT: cse_var_11: T.int32 = j_inner_outer * 6 +# CHECK-NEXT: cse_var_10: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_12 + cse_var_11 + j_inner_inner_s + 512 +# CHECK-NEXT: C_1[cse_var_10] = C_1[cse_var_10] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 17] * _1_1[k_outer * 8192 + 
k_inner_outer * 1024 + cse_var_12 + cse_var_11 + j_inner_inner_s + 512] +# CHECK-NEXT:CODE: 0 diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py index afbb47615..652e0f329 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py @@ -1,4 +1,4 @@ -# RUN: python %s 2>&1 | filecheck %s +# RUN: python -O %s 2>&1 | filecheck %s import xtc.graphs.xtc.op as O from xtc.backends.tvm import Backend @@ -51,122 +51,62 @@ res = executor.execute() print(f"CODE: {res}") -# CHECK: // -----// IR Dump Before transform //----- // -# CHECK-NEXT: module attributes {transform.with_named_sequence} { -# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { -# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<4x32xf32>) -# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<4x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<4x32xf32>) -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { -# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op -# CHECK-NEXT: transform.yield -# CHECK-NEXT: } -# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { -# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = 
transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op -# CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_3 "C/k" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_5 "C/i" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_7 "C/j" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -# CHECK-NEXT: transform.annotate %loops_9 "C/i0" : !transform.any_op -# CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_8) : (!transform.any_op) -> () -# CHECK-NEXT: %3 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: transform.apply_patterns to %3 { -# CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract -# CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns -# CHECK-NEXT: } : !transform.any_op -# CHECK-NEXT: transform.apply_patterns to %3 { -# CHECK-NEXT: 
transform.apply_patterns.vector.lower_outerproduct -# CHECK-NEXT: transform.apply_patterns.vector.lower_contraction -# CHECK-NEXT: } : !transform.any_op -# CHECK-NEXT: %4 = transform.structured.match attributes {"C/i0"} in %3 : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op -# CHECK-NEXT: transform.yield -# CHECK-NEXT: } -# CHECK-NEXT: } -# CHECK-NEXT: -# CHECK-NEXT: // -----// IR Dump After transform //----- // -# CHECK-NEXT: module attributes {transform.with_named_sequence} { -# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { -# CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> -# CHECK-NEXT: %c16 = arith.constant 16 : index -# CHECK-NEXT: %c2 = arith.constant 2 : index -# CHECK-NEXT: %c512 = arith.constant 512 : index -# CHECK-NEXT: %c32 = arith.constant 32 : index -# CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 -# CHECK-NEXT: %c0 = arith.constant 0 : index -# CHECK-NEXT: %c4 = arith.constant 4 : index -# CHECK-NEXT: %c1 = arith.constant 1 : index -# CHECK-NEXT: scf.for %arg3 = %c0 to %c4 step %c1 { -# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<4x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { -# CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) -# CHECK-NEXT: } {"./j"} -# CHECK-NEXT: } {"./i"} -# CHECK-NEXT: scf.for %arg3 = %c0 to %c512 step %c1 { -# CHECK-NEXT: %subview = memref.subview %arg0[0, %arg3] [4, 1] [1, 1] : memref<4x512xf32> to memref<4x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_1 
= memref.subview %arg1[%arg3, 0] [1, 32] [1, 1] : memref<512x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %subview_2 = memref.subview %arg2[0, 0] [4, 32] [1, 1] : memref<4x32xf32> to memref<4x32xf32, strided<[32, 1]>> -# CHECK-NEXT: scf.for %arg4 = %c0 to %c4 step %c2 { -# CHECK-NEXT: %subview_3 = memref.subview %subview[%arg4, 0] [2, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_4 = memref.subview %subview_2[%arg4, 0] [2, 32] [1, 1] : memref<4x32xf32, strided<[32, 1]>> to memref<2x32xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: scf.for %arg5 = %c0 to %c32 step %c16 { -# CHECK-NEXT: %subview_5 = memref.subview %subview_1[0, %arg5] [1, 16] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %subview_6 = memref.subview %subview_4[0, %arg5] [2, 16] [1, 1] : memref<2x32xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %c2_7 = arith.constant 2 : index -# CHECK-NEXT: %subview_8 = memref.subview %subview_3[%c0, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_9 = memref.subview %subview_6[%c0, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %0 = vector.transfer_read %subview_8[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -# CHECK-NEXT: %1 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %2 = vector.transfer_read %subview_9[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %3 = vector.extract %1[0] : vector<16xf32> from 
vector<1x16xf32> -# CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> -# CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<16xf32> -# CHECK-NEXT: %6 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> -# CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<16xf32> -# CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<16xf32> into vector<1x16xf32> -# CHECK-NEXT: vector.transfer_write %8, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %c1_10 = arith.constant 1 : index -# CHECK-NEXT: %9 = arith.muli %c1, %c1_10 : index -# CHECK-NEXT: %10 = arith.addi %c0, %9 : index -# CHECK-NEXT: %subview_11 = memref.subview %subview_3[%10, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -# CHECK-NEXT: %subview_12 = memref.subview %subview_6[%10, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: %11 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -# CHECK-NEXT: %12 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %13 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -# CHECK-NEXT: %14 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> -# CHECK-NEXT: %15 = vector.extract %11[0, 0] : f32 from vector<1x1xf32> -# CHECK-NEXT: %16 = vector.broadcast %15 : f32 to vector<16xf32> -# CHECK-NEXT: %17 = vector.extract %13[0] : vector<16xf32> from vector<1x16xf32> -# CHECK-NEXT: %18 = vector.fma %16, %14, %17 : vector<16xf32> -# CHECK-NEXT: %19 = vector.insert %18, %cst [0] : vector<16xf32> into vector<1x16xf32> -# CHECK-NEXT: 
vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> -# CHECK-NEXT: } {"C/j"} -# CHECK-NEXT: } {"C/i"} -# CHECK-NEXT: } {"C/k"} -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK-NEXT: } -# CHECK-NEXT: -# CHECK-NEXT: graph: -# CHECK-NEXT: name: matmul -# CHECK-NEXT: inputs: -# CHECK-NEXT: - %0 : 4x512xfloat32 -# CHECK-NEXT: - %1 : 512x32xfloat32 -# CHECK-NEXT: outputs: -# CHECK-NEXT: - %2 : 4x32xfloat32 -# CHECK-NEXT: nodes: -# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] -# CHECK-NEXT: -# CHECK-NEXT: CODE: 0 +#CHECK: graph: +#CHECK-NEXT: name: matmul +#CHECK-NEXT: inputs: +#CHECK-NEXT: - %0 : 4x512xfloat32 +#CHECK-NEXT: - %1 : 512x32xfloat32 +#CHECK-NEXT: outputs: +#CHECK-NEXT: - %2 : 4x32xfloat32 +#CHECK-NEXT: nodes: +#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] +#CHECK-NEXT: +#CHECK-NEXT:# from tvm.script import ir as I +#CHECK-NEXT:# from tvm.script import tir as T +#CHECK-NEXT: +#CHECK-NEXT:@I.ir_module +#CHECK-NEXT:class Module: +#CHECK-NEXT: @T.prim_func +#CHECK-NEXT: def main(_0: T.Buffer((4, 512), "float32"), _1: T.Buffer((512, 32), "float32"), C: T.Buffer((4, 32), "float32")): +#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +#CHECK-NEXT: for i, j in T.grid(4, 32): +#CHECK-NEXT: C_1 = T.Buffer((128,), data=C.data) +#CHECK-NEXT: C_1[i * 32 + j] = T.float32(0.0) +#CHECK-NEXT: for k in range(512): +#CHECK-NEXT: cse_var_1: T.int32 = i * 32 + j +#CHECK-NEXT: _0_1 = T.Buffer((2048,), data=_0.data) +#CHECK-NEXT: _1_1 = T.Buffer((16384,), data=_1.data) +#CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_1[i * 512 + k] * _1_1[k * 32 + j] +#CHECK-NEXT:O = obj['C'] +#CHECK-NEXT:i, j, = O.op.axis +#CHECK-NEXT:k, = O.op.reduce_axis +#CHECK-NEXT:j, j0 = sch[O].split(j, factor=16) +#CHECK-NEXT:i, i0 = sch[O].split(i, factor=2) 
+#CHECK-NEXT:sch[O].reorder(k, i, j, j0, i0) +#CHECK-NEXT:sch[O].unroll(i0) +#CHECK-NEXT:sch[O].vectorize(j0) +#CHECK-NEXT: +#CHECK-NEXT:# from tvm.script import ir as I +#CHECK-NEXT:# from tvm.script import tir as T +#CHECK-NEXT: +#CHECK-NEXT:@I.ir_module +#CHECK-NEXT:class Module: +#CHECK-NEXT: @T.prim_func +#CHECK-NEXT: def main(_0: T.Buffer((4, 512), "float32"), _1: T.Buffer((512, 32), "float32"), C: T.Buffer((4, 32), "float32")): +#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +#CHECK-NEXT: C_1 = T.Buffer((128,), data=C.data) +#CHECK-NEXT: for i_outer_init, j_outer_init in T.grid(2, 2): +#CHECK-NEXT: cse_var_1: T.int32 = i_outer_init * 64 + j_outer_init * 16 +#CHECK-NEXT: C_1[cse_var_1:cse_var_1 + 16] = T.Broadcast(T.float32(0.0), 16) +#CHECK-NEXT: C_1[cse_var_1 + 32:cse_var_1 + 32 + 16] = T.Broadcast(T.float32(0.0), 16) +#CHECK-NEXT: for k, i_outer, j_outer in T.grid(512, 2, 2): +#CHECK-NEXT: cse_var_6: T.int32 = j_outer * 16 +#CHECK-NEXT: cse_var_5: T.int32 = i_outer * 1024 + k +#CHECK-NEXT: cse_var_4: T.int32 = k * 32 + cse_var_6 +#CHECK-NEXT: cse_var_3: T.int32 = i_outer * 64 + cse_var_6 +#CHECK-NEXT: cse_var_2: T.int32 = cse_var_3 + 32 +#CHECK-NEXT: _0_1 = T.Buffer((2048,), data=_0.data) +#CHECK-NEXT: _1_1 = T.Buffer((16384,), data=_1.data) +#CHECK-NEXT: C_1[cse_var_3:cse_var_3 + 16] = C_1[cse_var_3:cse_var_3 + 16] + T.Broadcast(_0_1[cse_var_5], 16) * _1_1[cse_var_4:cse_var_4 + 16] +#CHECK-NEXT: C_1[cse_var_2:cse_var_2 + 16] = C_1[cse_var_2:cse_var_2 + 16] + T.Broadcast(_0_1[cse_var_5 + 512], 16) * _1_1[cse_var_4:cse_var_4 + 16] +#CHECK-NEXT:CODE: 0 diff --git a/tests/filecheck/search/test_matmul_descript_3axes.py b/tests/filecheck/search/test_matmul_descript_3axes.py index 8b3128b3d..cdbab3779 100644 --- a/tests/filecheck/search/test_matmul_descript_3axes.py +++ b/tests/filecheck/search/test_matmul_descript_3axes.py @@ -1,6 +1,6 @@ # RUN: python %s 2>&1 | filecheck %s """ -Test strategy Goto on matmul 
+Test strategy 3-axis on matmul """ import utils @@ -22,118 +22,8 @@ "explore_axis_order": None, }, } -strategy = Strategy(graph, spec) +strategy = Strategy(graph, spec, initialize=False) -utils.print_all_opt_schedules(backend, strategy) -utils.print_exhaustive_samples(backend, strategy, 100) +print(strategy._constraints) -# CHECK: schedule O0: [1, 1, 1, 0] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] -# CHECK-NEXT: schedule O1: [1, 1, 1, 0] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] -# CHECK-NEXT: schedule O2: [1, 1, 1, 1] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', 
'./j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] -# CHECK-NEXT: schedule O3: [1, 1, 1, 1] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] -# CHECK-NEXT: sample 0: [1, 1, 1, 0] -# CHECK-NEXT: sample 1: [1, 1, 1, 1] -# CHECK-NEXT: sample 2: [1, 1, 1, 2] -# CHECK-NEXT: sample 3: [1, 1, 1, 3] -# CHECK-NEXT: sample 4: [1, 1, 1, 4] -# CHECK-NEXT: sample 5: [1, 1, 1, 5] -# CHECK-NEXT: sample 6: [1, 1, 2, 0] -# CHECK-NEXT: sample 7: [1, 1, 2, 1] -# CHECK-NEXT: sample 8: [1, 1, 2, 2] -# CHECK-NEXT: sample 9: [1, 1, 2, 3] -# CHECK-NEXT: sample 10: [1, 1, 2, 4] -# CHECK-NEXT: sample 11: [1, 1, 2, 5] -# CHECK-NEXT: sample 12: [1, 1, 3, 0] -# CHECK-NEXT: sample 13: [1, 1, 3, 1] -# CHECK-NEXT: sample 14: [1, 1, 3, 2] -# CHECK-NEXT: sample 15: [1, 1, 3, 3] -# CHECK-NEXT: sample 16: [1, 1, 3, 4] -# CHECK-NEXT: sample 17: [1, 1, 3, 5] -# CHECK-NEXT: sample 18: [1, 1, 4, 0] -# CHECK-NEXT: sample 19: [1, 1, 4, 1] -# CHECK-NEXT: sample 20: [1, 1, 4, 2] -# CHECK-NEXT: sample 21: [1, 1, 4, 3] -# CHECK-NEXT: sample 22: [1, 1, 4, 4] -# CHECK-NEXT: sample 23: [1, 1, 4, 5] -# CHECK-NEXT: sample 24: [1, 1, 6, 0] -# CHECK-NEXT: sample 25: [1, 1, 6, 1] -# CHECK-NEXT: sample 26: [1, 1, 6, 2] -# CHECK-NEXT: sample 27: [1, 1, 6, 3] -# CHECK-NEXT: sample 28: [1, 1, 6, 4] -# CHECK-NEXT: sample 29: [1, 1, 6, 5] -# CHECK-NEXT: sample 30: [1, 2, 1, 0] -# CHECK-NEXT: sample 31: [1, 2, 1, 1] -# CHECK-NEXT: sample 32: 
[1, 2, 1, 2] -# CHECK-NEXT: sample 33: [1, 2, 1, 3] -# CHECK-NEXT: sample 34: [1, 2, 1, 4] -# CHECK-NEXT: sample 35: [1, 2, 1, 5] -# CHECK-NEXT: sample 36: [1, 2, 2, 0] -# CHECK-NEXT: sample 37: [1, 2, 2, 1] -# CHECK-NEXT: sample 38: [1, 2, 2, 2] -# CHECK-NEXT: sample 39: [1, 2, 2, 3] -# CHECK-NEXT: sample 40: [1, 2, 2, 4] -# CHECK-NEXT: sample 41: [1, 2, 2, 5] -# CHECK-NEXT: sample 42: [1, 2, 3, 0] -# CHECK-NEXT: sample 43: [1, 2, 3, 1] -# CHECK-NEXT: sample 44: [1, 2, 3, 2] -# CHECK-NEXT: sample 45: [1, 2, 3, 3] -# CHECK-NEXT: sample 46: [1, 2, 3, 4] -# CHECK-NEXT: sample 47: [1, 2, 3, 5] -# CHECK-NEXT: sample 48: [1, 2, 4, 0] -# CHECK-NEXT: sample 49: [1, 2, 4, 1] -# CHECK-NEXT: sample 50: [1, 2, 4, 2] -# CHECK-NEXT: sample 51: [1, 2, 4, 3] -# CHECK-NEXT: sample 52: [1, 2, 4, 4] -# CHECK-NEXT: sample 53: [1, 2, 4, 5] -# CHECK-NEXT: sample 54: [1, 2, 6, 1] -# CHECK-NEXT: sample 55: [1, 2, 6, 4] -# CHECK-NEXT: sample 56: [1, 4, 1, 0] -# CHECK-NEXT: sample 57: [1, 4, 1, 1] -# CHECK-NEXT: sample 58: [1, 4, 1, 2] -# CHECK-NEXT: sample 59: [1, 4, 1, 3] -# CHECK-NEXT: sample 60: [1, 4, 1, 4] -# CHECK-NEXT: sample 61: [1, 4, 1, 5] -# CHECK-NEXT: sample 62: [1, 4, 2, 0] -# CHECK-NEXT: sample 63: [1, 4, 2, 1] -# CHECK-NEXT: sample 64: [1, 4, 2, 2] -# CHECK-NEXT: sample 65: [1, 4, 2, 3] -# CHECK-NEXT: sample 66: [1, 4, 2, 4] -# CHECK-NEXT: sample 67: [1, 4, 2, 5] -# CHECK-NEXT: sample 68: [1, 4, 3, 1] -# CHECK-NEXT: sample 69: [1, 4, 3, 4] -# CHECK-NEXT: sample 70: [1, 4, 4, 1] -# CHECK-NEXT: sample 71: [1, 4, 4, 4] -# CHECK-NEXT: sample 72: [1, 4, 6, 1] -# CHECK-NEXT: sample 73: [1, 4, 6, 4] -# CHECK-NEXT: sample 74: [1, 8, 1, 0] -# CHECK-NEXT: sample 75: [1, 8, 1, 1] -# CHECK-NEXT: sample 76: [1, 8, 1, 2] -# CHECK-NEXT: sample 77: [1, 8, 1, 3] -# CHECK-NEXT: sample 78: [1, 8, 1, 4] -# CHECK-NEXT: sample 79: [1, 8, 1, 5] -# CHECK-NEXT: sample 80: [1, 8, 2, 1] -# CHECK-NEXT: sample 81: [1, 8, 2, 4] -# CHECK-NEXT: sample 82: [1, 8, 3, 1] -# CHECK-NEXT: sample 83: [1, 8, 3, 
4] -# CHECK-NEXT: sample 84: [1, 8, 4, 1] -# CHECK-NEXT: sample 85: [1, 8, 4, 4] -# CHECK-NEXT: sample 86: [1, 8, 6, 1] -# CHECK-NEXT: sample 87: [1, 8, 6, 4] -# CHECK-NEXT: sample 88: [1, 16, 1, 1] -# CHECK-NEXT: sample 89: [1, 16, 1, 4] -# CHECK-NEXT: sample 90: [1, 16, 2, 1] -# CHECK-NEXT: sample 91: [1, 16, 2, 4] -# CHECK-NEXT: sample 92: [1, 16, 3, 1] -# CHECK-NEXT: sample 93: [1, 16, 3, 4] -# CHECK-NEXT: sample 94: [1, 16, 4, 1] -# CHECK-NEXT: sample 95: [1, 16, 4, 4] -# CHECK-NEXT: sample 96: [1, 16, 6, 1] -# CHECK-NEXT: sample 97: [1, 16, 6, 4] -# CHECK-NEXT: sample 98: [1, 32, 1, 1] -# CHECK-NEXT: sample 99: [1, 32, 1, 4] -# CHECK-NEXT: stats {'filtered': 100, 'all': 185} -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 32}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './k1', './i1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 32, './i1': 1, './k1': 1})] +# CHECK: ['1 || kR || 12', '1 || jR || 32', '1 || iR || 21', '0 <= order_DDR <= 5', '0 <= order_R <= 5'] diff --git a/tests/filecheck/search/test_matmul_descript_goto.py b/tests/filecheck/search/test_matmul_descript_goto.py index e968a4a29..fd97325fb 100644 --- a/tests/filecheck/search/test_matmul_descript_goto.py +++ b/tests/filecheck/search/test_matmul_descript_goto.py @@ -30,118 +30,8 @@ "R": {"i#iR": {"unroll": None}, "j#jR": {"vectorize": "j_vectorise"}}, } constraint = ["iR * jR <= 56"] -strategy = Strategy(graph, spec, constraints=constraint) +strategy = Strategy(graph, spec, constraints=constraint, initialize=False) -utils.print_all_opt_schedules(backend, strategy) -utils.print_exhaustive_samples(backend, strategy, 100) 
+print(strategy._constraints) -# CHECK: schedule O0: [1, 1, 1, 0] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] -# CHECK-NEXT: schedule O1: [1, 1, 1, 0] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] -# CHECK-NEXT: schedule O2: [1, 1, 1, 1] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] -# CHECK-NEXT: schedule O3: [1, 1, 1, 1] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, 
tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] -# CHECK-NEXT: sample 0: [1, 1, 1, 0] -# CHECK-NEXT: sample 1: [1, 1, 1, 1] -# CHECK-NEXT: sample 2: [1, 1, 1, 2] -# CHECK-NEXT: sample 3: [1, 1, 1, 3] -# CHECK-NEXT: sample 4: [1, 1, 1, 4] -# CHECK-NEXT: sample 5: [1, 1, 1, 5] -# CHECK-NEXT: sample 6: [1, 1, 2, 0] -# CHECK-NEXT: sample 7: [1, 1, 2, 1] -# CHECK-NEXT: sample 8: [1, 1, 2, 2] -# CHECK-NEXT: sample 9: [1, 1, 2, 3] -# CHECK-NEXT: sample 10: [1, 1, 2, 4] -# CHECK-NEXT: sample 11: [1, 1, 2, 5] -# CHECK-NEXT: sample 12: [1, 1, 3, 0] -# CHECK-NEXT: sample 13: [1, 1, 3, 1] -# CHECK-NEXT: sample 14: [1, 1, 3, 2] -# CHECK-NEXT: sample 15: [1, 1, 3, 3] -# CHECK-NEXT: sample 16: [1, 1, 3, 4] -# CHECK-NEXT: sample 17: [1, 1, 3, 5] -# CHECK-NEXT: sample 18: [1, 1, 4, 0] -# CHECK-NEXT: sample 19: [1, 1, 4, 1] -# CHECK-NEXT: sample 20: [1, 1, 4, 2] -# CHECK-NEXT: sample 21: [1, 1, 4, 3] -# CHECK-NEXT: sample 22: [1, 1, 4, 4] -# CHECK-NEXT: sample 23: [1, 1, 4, 5] -# CHECK-NEXT: sample 24: [1, 1, 6, 0] -# CHECK-NEXT: sample 25: [1, 1, 6, 1] -# CHECK-NEXT: sample 26: [1, 1, 6, 2] -# CHECK-NEXT: sample 27: [1, 1, 6, 3] -# CHECK-NEXT: sample 28: [1, 1, 6, 4] -# CHECK-NEXT: sample 29: [1, 1, 6, 5] -# CHECK-NEXT: sample 30: [1, 2, 1, 0] -# CHECK-NEXT: sample 31: [1, 2, 1, 1] -# CHECK-NEXT: sample 32: [1, 2, 1, 2] -# CHECK-NEXT: sample 33: [1, 2, 1, 3] -# CHECK-NEXT: sample 34: [1, 2, 1, 4] -# CHECK-NEXT: sample 35: [1, 2, 1, 5] -# CHECK-NEXT: sample 36: [1, 2, 2, 0] -# CHECK-NEXT: sample 37: [1, 2, 2, 1] -# CHECK-NEXT: sample 38: [1, 2, 2, 2] -# CHECK-NEXT: sample 39: [1, 2, 2, 3] -# 
CHECK-NEXT: sample 40: [1, 2, 2, 4] -# CHECK-NEXT: sample 41: [1, 2, 2, 5] -# CHECK-NEXT: sample 42: [1, 2, 3, 0] -# CHECK-NEXT: sample 43: [1, 2, 3, 1] -# CHECK-NEXT: sample 44: [1, 2, 3, 2] -# CHECK-NEXT: sample 45: [1, 2, 3, 3] -# CHECK-NEXT: sample 46: [1, 2, 3, 4] -# CHECK-NEXT: sample 47: [1, 2, 3, 5] -# CHECK-NEXT: sample 48: [1, 2, 4, 0] -# CHECK-NEXT: sample 49: [1, 2, 4, 1] -# CHECK-NEXT: sample 50: [1, 2, 4, 2] -# CHECK-NEXT: sample 51: [1, 2, 4, 3] -# CHECK-NEXT: sample 52: [1, 2, 4, 4] -# CHECK-NEXT: sample 53: [1, 2, 4, 5] -# CHECK-NEXT: sample 54: [1, 2, 6, 1] -# CHECK-NEXT: sample 55: [1, 2, 6, 4] -# CHECK-NEXT: sample 56: [1, 4, 1, 0] -# CHECK-NEXT: sample 57: [1, 4, 1, 1] -# CHECK-NEXT: sample 58: [1, 4, 1, 2] -# CHECK-NEXT: sample 59: [1, 4, 1, 3] -# CHECK-NEXT: sample 60: [1, 4, 1, 4] -# CHECK-NEXT: sample 61: [1, 4, 1, 5] -# CHECK-NEXT: sample 62: [1, 4, 2, 0] -# CHECK-NEXT: sample 63: [1, 4, 2, 1] -# CHECK-NEXT: sample 64: [1, 4, 2, 2] -# CHECK-NEXT: sample 65: [1, 4, 2, 3] -# CHECK-NEXT: sample 66: [1, 4, 2, 4] -# CHECK-NEXT: sample 67: [1, 4, 2, 5] -# CHECK-NEXT: sample 68: [1, 4, 3, 1] -# CHECK-NEXT: sample 69: [1, 4, 3, 4] -# CHECK-NEXT: sample 70: [1, 4, 4, 1] -# CHECK-NEXT: sample 71: [1, 4, 4, 4] -# CHECK-NEXT: sample 72: [1, 4, 6, 1] -# CHECK-NEXT: sample 73: [1, 4, 6, 4] -# CHECK-NEXT: sample 74: [1, 8, 1, 0] -# CHECK-NEXT: sample 75: [1, 8, 1, 1] -# CHECK-NEXT: sample 76: [1, 8, 1, 2] -# CHECK-NEXT: sample 77: [1, 8, 1, 3] -# CHECK-NEXT: sample 78: [1, 8, 1, 4] -# CHECK-NEXT: sample 79: [1, 8, 1, 5] -# CHECK-NEXT: sample 80: [1, 8, 2, 1] -# CHECK-NEXT: sample 81: [1, 8, 2, 4] -# CHECK-NEXT: sample 82: [1, 8, 3, 1] -# CHECK-NEXT: sample 83: [1, 8, 3, 4] -# CHECK-NEXT: sample 84: [1, 8, 4, 1] -# CHECK-NEXT: sample 85: [1, 8, 4, 4] -# CHECK-NEXT: sample 86: [1, 8, 6, 1] -# CHECK-NEXT: sample 87: [1, 8, 6, 4] -# CHECK-NEXT: sample 88: [1, 16, 1, 1] -# CHECK-NEXT: sample 89: [1, 16, 1, 4] -# CHECK-NEXT: sample 90: [1, 16, 2, 1] -# 
CHECK-NEXT: sample 91: [1, 16, 2, 4] -# CHECK-NEXT: sample 92: [1, 16, 3, 1] -# CHECK-NEXT: sample 93: [1, 16, 3, 4] -# CHECK-NEXT: sample 94: [1, 16, 4, 1] -# CHECK-NEXT: sample 95: [1, 16, 4, 4] -# CHECK-NEXT: sample 96: [1, 16, 6, 1] -# CHECK-NEXT: sample 97: [1, 16, 6, 4] -# CHECK-NEXT: sample 98: [1, 32, 1, 1] -# CHECK-NEXT: sample 99: [1, 32, 1, 4] -# CHECK-NEXT: stats {'filtered': 100, 'all': 185} -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 32}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './k1', './i1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 32, './i1': 1, './k1': 1})] +# CHECK: ['1 || k_unroll || kL1 || 12', '1 || jR || jL3 || 32', '1 || iR || iL2 || 21', '0 <= pack_B <= 1', '0 <= pack_A <= 1', '0 <= j_parallel <= 1', '0 <= j_vectorise <= 1', 'iR * jR <= 56', '0 <= order_DDR <= 1'] diff --git a/tests/filecheck/search/test_matmul_descript_simple.py b/tests/filecheck/search/test_matmul_descript_simple.py index 86d7e5ae1..6d715109a 100644 --- a/tests/filecheck/search/test_matmul_descript_simple.py +++ b/tests/filecheck/search/test_matmul_descript_simple.py @@ -20,118 +20,8 @@ }, "L1": {"j#j2": {}}, } -strategy = Strategy(graph, spec, max_unroll=8) +strategy = Strategy(graph, spec, initialize=False) -utils.print_all_opt_schedules(backend, strategy) -utils.print_exhaustive_samples(backend, strategy, 100) +print(strategy._constraints) -# CHECK: schedule O0: [1, 1, 1, 0] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], 
parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] -# CHECK-NEXT: schedule O1: [1, 1, 1, 0] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] -# CHECK-NEXT: schedule O2: [1, 1, 1, 1] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] -# CHECK-NEXT: schedule O3: [1, 1, 1, 1] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': 
{'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] -# CHECK-NEXT: sample 0: [1, 1, 1, 0] -# CHECK-NEXT: sample 1: [1, 1, 1, 1] -# CHECK-NEXT: sample 2: [1, 1, 1, 2] -# CHECK-NEXT: sample 3: [1, 1, 1, 3] -# CHECK-NEXT: sample 4: [1, 1, 1, 4] -# CHECK-NEXT: sample 5: [1, 1, 1, 5] -# CHECK-NEXT: sample 6: [1, 1, 2, 0] -# CHECK-NEXT: sample 7: [1, 1, 2, 1] -# CHECK-NEXT: sample 8: [1, 1, 2, 2] -# CHECK-NEXT: sample 9: [1, 1, 2, 3] -# CHECK-NEXT: sample 10: [1, 1, 2, 4] -# CHECK-NEXT: sample 11: [1, 1, 2, 5] -# CHECK-NEXT: sample 12: [1, 1, 3, 0] -# CHECK-NEXT: sample 13: [1, 1, 3, 1] -# CHECK-NEXT: sample 14: [1, 1, 3, 2] -# CHECK-NEXT: sample 15: [1, 1, 3, 3] -# CHECK-NEXT: sample 16: [1, 1, 3, 4] -# CHECK-NEXT: sample 17: [1, 1, 3, 5] -# CHECK-NEXT: sample 18: [1, 1, 4, 0] -# CHECK-NEXT: sample 19: [1, 1, 4, 1] -# CHECK-NEXT: sample 20: [1, 1, 4, 2] -# CHECK-NEXT: sample 21: [1, 1, 4, 3] -# CHECK-NEXT: sample 22: [1, 1, 4, 4] -# CHECK-NEXT: sample 23: [1, 1, 4, 5] -# CHECK-NEXT: sample 24: [1, 1, 6, 0] -# CHECK-NEXT: sample 25: [1, 1, 6, 1] -# CHECK-NEXT: sample 26: [1, 1, 6, 2] -# CHECK-NEXT: sample 27: [1, 1, 6, 3] -# CHECK-NEXT: sample 28: [1, 1, 6, 4] -# CHECK-NEXT: sample 29: [1, 1, 6, 5] -# CHECK-NEXT: sample 30: [1, 2, 1, 0] -# CHECK-NEXT: sample 31: [1, 2, 1, 1] -# CHECK-NEXT: sample 32: [1, 2, 1, 2] -# CHECK-NEXT: sample 33: [1, 2, 1, 3] -# CHECK-NEXT: sample 34: [1, 2, 1, 4] -# CHECK-NEXT: sample 35: [1, 2, 1, 5] -# CHECK-NEXT: sample 36: [1, 2, 2, 0] -# CHECK-NEXT: sample 37: [1, 2, 2, 1] -# CHECK-NEXT: sample 38: [1, 2, 2, 2] -# CHECK-NEXT: sample 39: [1, 2, 2, 3] -# CHECK-NEXT: sample 40: [1, 2, 2, 4] -# CHECK-NEXT: sample 41: [1, 2, 2, 5] -# CHECK-NEXT: sample 42: [1, 2, 3, 0] -# CHECK-NEXT: sample 43: [1, 2, 3, 1] -# CHECK-NEXT: sample 44: [1, 2, 3, 2] -# CHECK-NEXT: sample 45: [1, 2, 3, 3] -# CHECK-NEXT: sample 46: [1, 2, 3, 4] -# 
CHECK-NEXT: sample 47: [1, 2, 3, 5] -# CHECK-NEXT: sample 48: [1, 2, 4, 0] -# CHECK-NEXT: sample 49: [1, 2, 4, 1] -# CHECK-NEXT: sample 50: [1, 2, 4, 2] -# CHECK-NEXT: sample 51: [1, 2, 4, 3] -# CHECK-NEXT: sample 52: [1, 2, 4, 4] -# CHECK-NEXT: sample 53: [1, 2, 4, 5] -# CHECK-NEXT: sample 54: [1, 2, 6, 1] -# CHECK-NEXT: sample 55: [1, 2, 6, 4] -# CHECK-NEXT: sample 56: [1, 4, 1, 0] -# CHECK-NEXT: sample 57: [1, 4, 1, 1] -# CHECK-NEXT: sample 58: [1, 4, 1, 2] -# CHECK-NEXT: sample 59: [1, 4, 1, 3] -# CHECK-NEXT: sample 60: [1, 4, 1, 4] -# CHECK-NEXT: sample 61: [1, 4, 1, 5] -# CHECK-NEXT: sample 62: [1, 4, 2, 0] -# CHECK-NEXT: sample 63: [1, 4, 2, 1] -# CHECK-NEXT: sample 64: [1, 4, 2, 2] -# CHECK-NEXT: sample 65: [1, 4, 2, 3] -# CHECK-NEXT: sample 66: [1, 4, 2, 4] -# CHECK-NEXT: sample 67: [1, 4, 2, 5] -# CHECK-NEXT: sample 68: [1, 4, 3, 1] -# CHECK-NEXT: sample 69: [1, 4, 3, 4] -# CHECK-NEXT: sample 70: [1, 4, 4, 1] -# CHECK-NEXT: sample 71: [1, 4, 4, 4] -# CHECK-NEXT: sample 72: [1, 4, 6, 1] -# CHECK-NEXT: sample 73: [1, 4, 6, 4] -# CHECK-NEXT: sample 74: [1, 8, 1, 0] -# CHECK-NEXT: sample 75: [1, 8, 1, 1] -# CHECK-NEXT: sample 76: [1, 8, 1, 2] -# CHECK-NEXT: sample 77: [1, 8, 1, 3] -# CHECK-NEXT: sample 78: [1, 8, 1, 4] -# CHECK-NEXT: sample 79: [1, 8, 1, 5] -# CHECK-NEXT: sample 80: [1, 8, 2, 1] -# CHECK-NEXT: sample 81: [1, 8, 2, 4] -# CHECK-NEXT: sample 82: [1, 8, 3, 1] -# CHECK-NEXT: sample 83: [1, 8, 3, 4] -# CHECK-NEXT: sample 84: [1, 8, 4, 1] -# CHECK-NEXT: sample 85: [1, 8, 4, 4] -# CHECK-NEXT: sample 86: [1, 8, 6, 1] -# CHECK-NEXT: sample 87: [1, 8, 6, 4] -# CHECK-NEXT: sample 88: [1, 16, 1, 1] -# CHECK-NEXT: sample 89: [1, 16, 1, 4] -# CHECK-NEXT: sample 90: [1, 16, 2, 1] -# CHECK-NEXT: sample 91: [1, 16, 2, 4] -# CHECK-NEXT: sample 92: [1, 16, 3, 1] -# CHECK-NEXT: sample 93: [1, 16, 3, 4] -# CHECK-NEXT: sample 94: [1, 16, 4, 1] -# CHECK-NEXT: sample 95: [1, 16, 4, 4] -# CHECK-NEXT: sample 96: [1, 16, 6, 1] -# CHECK-NEXT: sample 97: [1, 16, 6, 4] -# 
CHECK-NEXT: sample 98: [1, 32, 1, 1] -# CHECK-NEXT: sample 99: [1, 32, 1, 4] -# CHECK-NEXT: stats {'filtered': 100, 'all': 185} -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 32}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './k1', './i1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 32, './i1': 1, './k1': 1})] +# CHECK: ['1 || j2 || j1 || 32', '1 || i1 || 21'] diff --git a/tests/filecheck/search/test_matmul_descript_split.py b/tests/filecheck/search/test_matmul_descript_split.py index 7e446ac71..35d2d2ccf 100644 --- a/tests/filecheck/search/test_matmul_descript_split.py +++ b/tests/filecheck/search/test_matmul_descript_split.py @@ -1,6 +1,6 @@ # RUN: python %s 2>&1 | filecheck %s """ -Test strategy Goto on matmul +Test splits on matmul """ import utils @@ -12,135 +12,30 @@ "DDR": { "j": {}, "k": {}, + "i": {}, + }, + "L3": {"i#iL3": {}}, + "L2": { + "i#7": {}, }, "L1": { "j#jDDR": {}, - "i[:iT1]": { + "i[:5]": { "R": { "i#iR1": {"unroll": None}, - "j#jR": {"vectorize": None}, + "j#jR1": {"vectorize": None}, }, }, - "i[iT1:]": { + "i[5:]": { "R": { "i#iR2": {"unroll": None}, - "j#jR": {"vectorize": None}, + "j#jR2": {"vectorize": None}, }, }, }, } -strategy = Strategy(graph, spec, max_unroll=8) +strategy = Strategy(graph, spec, initialize=False) -utils.print_all_opt_schedules(backend, strategy) -utils.print_exhaustive_samples(backend, strategy, 100) +print(strategy._constraints) -# CHECK: schedule O0: [1, 1, 1, 0] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', 
'./j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] -# CHECK-NEXT: schedule O1: [1, 1, 1, 0] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './j1', './k1']}, vectorization=[], parallelization=[], unrolling={'./k1': 1, './j1': 1, './i1': 1})] -# CHECK-NEXT: schedule O2: [1, 1, 1, 1] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] -# CHECK-NEXT: schedule O3: [1, 1, 1, 1] -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 
1}, 'j': {'./j1': 1}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './i1', './k1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 1, './k1': 1, './i1': 1})] -# CHECK-NEXT: sample 0: [1, 1, 1, 0] -# CHECK-NEXT: sample 1: [1, 1, 1, 1] -# CHECK-NEXT: sample 2: [1, 1, 1, 2] -# CHECK-NEXT: sample 3: [1, 1, 1, 3] -# CHECK-NEXT: sample 4: [1, 1, 1, 4] -# CHECK-NEXT: sample 5: [1, 1, 1, 5] -# CHECK-NEXT: sample 6: [1, 1, 2, 0] -# CHECK-NEXT: sample 7: [1, 1, 2, 1] -# CHECK-NEXT: sample 8: [1, 1, 2, 2] -# CHECK-NEXT: sample 9: [1, 1, 2, 3] -# CHECK-NEXT: sample 10: [1, 1, 2, 4] -# CHECK-NEXT: sample 11: [1, 1, 2, 5] -# CHECK-NEXT: sample 12: [1, 1, 3, 0] -# CHECK-NEXT: sample 13: [1, 1, 3, 1] -# CHECK-NEXT: sample 14: [1, 1, 3, 2] -# CHECK-NEXT: sample 15: [1, 1, 3, 3] -# CHECK-NEXT: sample 16: [1, 1, 3, 4] -# CHECK-NEXT: sample 17: [1, 1, 3, 5] -# CHECK-NEXT: sample 18: [1, 1, 4, 0] -# CHECK-NEXT: sample 19: [1, 1, 4, 1] -# CHECK-NEXT: sample 20: [1, 1, 4, 2] -# CHECK-NEXT: sample 21: [1, 1, 4, 3] -# CHECK-NEXT: sample 22: [1, 1, 4, 4] -# CHECK-NEXT: sample 23: [1, 1, 4, 5] -# CHECK-NEXT: sample 24: [1, 1, 6, 0] -# CHECK-NEXT: sample 25: [1, 1, 6, 1] -# CHECK-NEXT: sample 26: [1, 1, 6, 2] -# CHECK-NEXT: sample 27: [1, 1, 6, 3] -# CHECK-NEXT: sample 28: [1, 1, 6, 4] -# CHECK-NEXT: sample 29: [1, 1, 6, 5] -# CHECK-NEXT: sample 30: [1, 2, 1, 0] -# CHECK-NEXT: sample 31: [1, 2, 1, 1] -# CHECK-NEXT: sample 32: [1, 2, 1, 2] -# CHECK-NEXT: sample 33: [1, 2, 1, 3] -# CHECK-NEXT: sample 34: [1, 2, 1, 4] -# CHECK-NEXT: sample 35: [1, 2, 1, 5] -# CHECK-NEXT: sample 36: [1, 2, 2, 0] -# CHECK-NEXT: sample 37: [1, 2, 2, 1] -# CHECK-NEXT: sample 38: [1, 2, 2, 2] -# CHECK-NEXT: sample 39: [1, 2, 2, 3] -# CHECK-NEXT: sample 40: [1, 2, 2, 4] -# CHECK-NEXT: sample 41: [1, 2, 2, 5] -# CHECK-NEXT: sample 42: [1, 2, 3, 0] -# CHECK-NEXT: sample 43: [1, 2, 3, 1] -# CHECK-NEXT: sample 44: [1, 2, 3, 2] -# CHECK-NEXT: sample 45: [1, 2, 3, 3] -# CHECK-NEXT: 
sample 46: [1, 2, 3, 4] -# CHECK-NEXT: sample 47: [1, 2, 3, 5] -# CHECK-NEXT: sample 48: [1, 2, 4, 0] -# CHECK-NEXT: sample 49: [1, 2, 4, 1] -# CHECK-NEXT: sample 50: [1, 2, 4, 2] -# CHECK-NEXT: sample 51: [1, 2, 4, 3] -# CHECK-NEXT: sample 52: [1, 2, 4, 4] -# CHECK-NEXT: sample 53: [1, 2, 4, 5] -# CHECK-NEXT: sample 54: [1, 2, 6, 1] -# CHECK-NEXT: sample 55: [1, 2, 6, 4] -# CHECK-NEXT: sample 56: [1, 4, 1, 0] -# CHECK-NEXT: sample 57: [1, 4, 1, 1] -# CHECK-NEXT: sample 58: [1, 4, 1, 2] -# CHECK-NEXT: sample 59: [1, 4, 1, 3] -# CHECK-NEXT: sample 60: [1, 4, 1, 4] -# CHECK-NEXT: sample 61: [1, 4, 1, 5] -# CHECK-NEXT: sample 62: [1, 4, 2, 0] -# CHECK-NEXT: sample 63: [1, 4, 2, 1] -# CHECK-NEXT: sample 64: [1, 4, 2, 2] -# CHECK-NEXT: sample 65: [1, 4, 2, 3] -# CHECK-NEXT: sample 66: [1, 4, 2, 4] -# CHECK-NEXT: sample 67: [1, 4, 2, 5] -# CHECK-NEXT: sample 68: [1, 4, 3, 1] -# CHECK-NEXT: sample 69: [1, 4, 3, 4] -# CHECK-NEXT: sample 70: [1, 4, 4, 1] -# CHECK-NEXT: sample 71: [1, 4, 4, 4] -# CHECK-NEXT: sample 72: [1, 4, 6, 1] -# CHECK-NEXT: sample 73: [1, 4, 6, 4] -# CHECK-NEXT: sample 74: [1, 8, 1, 0] -# CHECK-NEXT: sample 75: [1, 8, 1, 1] -# CHECK-NEXT: sample 76: [1, 8, 1, 2] -# CHECK-NEXT: sample 77: [1, 8, 1, 3] -# CHECK-NEXT: sample 78: [1, 8, 1, 4] -# CHECK-NEXT: sample 79: [1, 8, 1, 5] -# CHECK-NEXT: sample 80: [1, 8, 2, 1] -# CHECK-NEXT: sample 81: [1, 8, 2, 4] -# CHECK-NEXT: sample 82: [1, 8, 3, 1] -# CHECK-NEXT: sample 83: [1, 8, 3, 4] -# CHECK-NEXT: sample 84: [1, 8, 4, 1] -# CHECK-NEXT: sample 85: [1, 8, 4, 4] -# CHECK-NEXT: sample 86: [1, 8, 6, 1] -# CHECK-NEXT: sample 87: [1, 8, 6, 4] -# CHECK-NEXT: sample 88: [1, 16, 1, 1] -# CHECK-NEXT: sample 89: [1, 16, 1, 4] -# CHECK-NEXT: sample 90: [1, 16, 2, 1] -# CHECK-NEXT: sample 91: [1, 16, 2, 4] -# CHECK-NEXT: sample 92: [1, 16, 3, 1] -# CHECK-NEXT: sample 93: [1, 16, 3, 4] -# CHECK-NEXT: sample 94: [1, 16, 4, 1] -# CHECK-NEXT: sample 95: [1, 16, 4, 4] -# CHECK-NEXT: sample 96: [1, 16, 6, 1] -# CHECK-NEXT: 
sample 97: [1, 16, 6, 4] -# CHECK-NEXT: sample 98: [1, 32, 1, 1] -# CHECK-NEXT: sample 99: [1, 32, 1, 4] -# CHECK-NEXT: stats {'filtered': 100, 'all': 185} -# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['i', 'j'], loop_stamps=[], splits={}, tiles={'i': {}, 'j': {}}, permutation={'.': ['./i', './j']}, vectorization=[], parallelization=[], unrolling={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['i', 'j', 'k'], loop_stamps=[], splits={}, tiles={'i': {'./i1': 1}, 'j': {'./j1': 32}, 'k': {'./k1': 1}}, permutation={'.': ['./i', './j', './k', './k1', './i1', './j1']}, vectorization=['./j1'], parallelization=[], unrolling={'./j1': 32, './i1': 1, './k1': 1})] +# CHECK: ['1 || jR2 || jDDR || 32', '1 || jR1 || jDDR || 32', '1 || iR2 || 2', '1 || iR1 || 5', '7 || iL3 || 21'] From 020f8be20c299eeb02077cedeecd0533bb10b1e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Wed, 22 Oct 2025 14:11:06 +0200 Subject: [PATCH 06/23] More edits for supporting splits --- src/xtc/schedules/descript_extend.py | 105 ++++++++---------- .../search/test_matmul_descript_split.py | 2 +- .../test_matmul_descript_split_in_split.py | 46 ++++++++ 3 files changed, 93 insertions(+), 60 deletions(-) create mode 100644 tests/filecheck/search/test_matmul_descript_split_in_split.py diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index cdd25295d..39118f71f 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -73,15 +73,13 @@ def flatten_schedule(self, node_name: str, spec: dict[str, dict]): i = 0 while i < len(all_axis_constraints): sched = all_axis_constraints[i] - # print(all_axis_constraints, i, sched) - # print(sched) if isinstance(sched[0], int): axis_constraints.append(sched) all_axis_constraints.pop(i) else: i += 1 - # print(axis, axis_constraints, all_axis_constraints) - while len(all_axis_constraints) > 0: + flag_flag = True + 
while len(all_axis_constraints) > 0 and flag_flag: i = 0 axis_constraints_acc = [] flag_flag = False @@ -99,7 +97,9 @@ def flatten_schedule(self, node_name: str, spec: dict[str, dict]): i += 1 if flag_flag: axis_constraints = axis_constraints_acc - # print(axis, axis_constraints, all_axis_constraints) + + axis_constraints += all_axis_constraints + axis_constraints.reverse() for constraint in axis_constraints: constraint.reverse() constraint_str = "" @@ -113,25 +113,8 @@ def flatten_schedule(self, node_name: str, spec: dict[str, dict]): if var_flag: constraints.insert(0, constraint_str) - # flag = False - # for constraint in axis_constraints: - # if sched[0] == constraint[-1]: - # axis_constraints_acc.append(constraint + sched[1:]) - # flag = True - # else: - # axis_constraints_acc.append(constraint) - # axis_constraints = axis_constraints_acc - # if flag: - # all_axis_constraints.pop(i) - # i = 0 - # else: - # i += 1 - # print(all_axis_constraints, i) - # print(axis, axis_constraints) - variables = list(dict.fromkeys(variables)) constraints = list(dict.fromkeys(constraints)) - # print(constraints) return (flat_schedules, variables, constraints, axes, orders) def apply_sample( @@ -262,7 +245,6 @@ def _flatten_schedule( axes_sizes: dict[str, int | str] = tile_sizes else: axes_sizes = {a: v for a, v in self.abstract_axis_sizes.items()} - # print(axes_sizes) sched_sizes = {} for a, v in axes_sizes.items(): sched["sizes"][a] = [] @@ -317,7 +299,7 @@ def _flatten_schedule( sched["axis_orders"].append(tree_declaration) continue elif ":" in declaration: - axis_name, x, y = self.parse_split_declaration(declaration) + axis_name, x, y, z = self.parse_split_declaration(declaration) self._check_axis_existence(axis_name) # The only declaration where y (the cut) is None is the @@ -329,8 +311,8 @@ def _flatten_schedule( if x is None: x = cut - # print(declaration, axis_name, cut, x, y) - lam, inner_size = self._extended_check_splitting_intervals( + # assert isinstance(x, int) + 
inner_size = self._extended_check_splitting_intervals( declaration, axis_name, cut, x, y ) current_size = axes_sizes[axis_name] @@ -347,17 +329,30 @@ def _flatten_schedule( tree_interchange[axis_name].append(new_dim_name) else: tree_interchange[axis_name] = [new_dim_name] - inner_size = ( - inner_size if inner_size else eval(f"{current_size} - {x}") - ) - axes_sizes[axis_name] = inner_size - # sched["sizes"][axis_name].append(inner_size) - if lam: - if isinstance(y, str): - variables.append(y) - constraints.append(lam) - # constraints.append(f"1 || {y} || {current_size}") + inner_size = None + if y is None: + y = current_size + if isinstance(x, int): + if x == 0: + inner_size = y + elif isinstance(y, int): + inner_size = y - x + + if inner_size is None: + inner_size = root[1:] + new_dim_name + inner_size = ( + inner_size.replace("/", "") + .replace("[", "_") + .replace("]", "_") + ) + if isinstance(x, str): + constraints.append(f"{x} < {y}") + # constraints.append(f"1 || {x} || {y}") + # sched_sizes[axis_name].append(x) + constraints.append(f"{inner_size} + {x} == {y}") + + axes_sizes[axis_name] = inner_size # Fetch the schedule associated with the new dimension next_schedule = val @@ -370,10 +365,6 @@ def _flatten_schedule( ) axes_sizes[axis_name] = current_size - # for a, v in inner_scheds[0]["sizes"].items(): - # if a != axis_name: - # inner_scheds[0]["sizes"][a] = {} - # sched["sizes"][a] += v recursive_scheds += inner_scheds continue elif "#" in declaration: @@ -385,9 +376,6 @@ def _flatten_schedule( else: loop_size = tile_size variables.append(tile_size) - # constraints.append( - # f"1 || {tile_size} || {axes_sizes[axis_name]}" - # ) if not loop_size: raise Exception( f"Invalid tile size: '{tile_size}' in {declaration}" @@ -453,7 +441,6 @@ def _flatten_schedule( sched["interchange"] = interchange sched["variables"] = variables + sched["variables"] sched["constraints"] = constraints + sched["constraints"] - # print(sched_sizes, sched["sizes"]) for a in 
self.abstract_axis: flag = True for sched_ in sched["sizes"][a]: @@ -462,7 +449,6 @@ def _flatten_schedule( break if flag: sched["sizes"][a] = [sched_sizes[a]] + sched["sizes"][a] - # print(sched["sizes"]) return [sched] + recursive_scheds def _extended_check_splitting_intervals( @@ -472,7 +458,7 @@ def _extended_check_splitting_intervals( cut: int | str | None, x: int | str | None, y: int | str | None, - ) -> Tuple[str | None, int | str | None]: + ) -> int | str | None: if cut is None: raise Exception( f""" @@ -506,11 +492,10 @@ def _extended_check_splitting_intervals( ({cut} and {x} on axis {axis_name}) """ ) - + assert x == cut if y is None: - return (None, None) + return None - # constraint = f"{x} < {y}" if isinstance(x, int): if isinstance(y, int): if x >= y: @@ -521,14 +506,10 @@ def _extended_check_splitting_intervals( """ ) else: - return (None, y - x) - raise Exception(f""" - Arguments for the split must be ints for now. - ({x} or {y} on axis {axis_name}) - """) - # if x == 0: - # return (constraint, f"{y}") - # return (constraint, f"{y} - {x}") + return y - x + if x == 0: + return y + return None def annotate( self, @@ -585,10 +566,16 @@ def annotate( def parse_split_declaration( self, declaration: str, - ) -> Tuple[str, int | str | None, int | str | None]: + ) -> Tuple[str, int | str | None, int | str | None, int | str | None]: pattern = r"^(.*)\[(?:(-\w+|\w*)?):(?:(-\w+|\w*)?)\]$" match = re.match(pattern, declaration) if not match: + pattern = r"^(.*)\[:(\w*):]" + match = re.match(pattern, declaration) + if match: + prefix, z = match.group() + z = int(z) if z.isnumeric() else z + return prefix, None, None, z raise Exception(f"Wrong format {declaration}") prefix, x_str, y_str = match.groups() @@ -596,4 +583,4 @@ def parse_split_declaration( y = int(y_str) if y_str.isnumeric() else y_str x = x if x else None y = y if y else None - return prefix, x, y + return prefix, x, y, None diff --git a/tests/filecheck/search/test_matmul_descript_split.py 
b/tests/filecheck/search/test_matmul_descript_split.py index 35d2d2ccf..fd5745802 100644 --- a/tests/filecheck/search/test_matmul_descript_split.py +++ b/tests/filecheck/search/test_matmul_descript_split.py @@ -38,4 +38,4 @@ print(strategy._constraints) -# CHECK: ['1 || jR2 || jDDR || 32', '1 || jR1 || jDDR || 32', '1 || iR2 || 2', '1 || iR1 || 5', '7 || iL3 || 21'] +# CHECK: ['1 || jR1 || jDDR || 32', '1 || jR3 || jDDR || 32', '1 || jR2 || jDDR || 32', '1 || iL2 || 21', '1 || iR1 || 2', '1 || iR3 || 1', '1 || iR2 || i_1_', 'i_1_ + 6 == iL2'] diff --git a/tests/filecheck/search/test_matmul_descript_split_in_split.py b/tests/filecheck/search/test_matmul_descript_split_in_split.py new file mode 100644 index 000000000..d7494f772 --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_split_in_split.py @@ -0,0 +1,46 @@ +# RUN: python -O %s 2>&1 | filecheck %s +""" +Test splits on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = { + "DDR": { + "j": {}, + "k": {}, + "i": {}, + }, + "L3": { + "i#iL2": {}, + }, + "L2": { + "j#jDDR": {}, + "i[:6]": { + "L1": {"i#3": {}}, + "R": { + "i[:2]": { + "RR": { + "i#iR1": {"unroll": None}, + "j#jR1": {"vectorize": None}, + } + }, + "i[2:]": {"RR": {"i#iR3": {}, "j#jR3": {}}}, + }, + }, + "i[6:]": { + "R": { + "i#iR2": {"unroll": None}, + "j#jR2": {"vectorize": None}, + }, + }, + }, +} +strategy = Strategy(graph, spec, initialize=False) + +print(strategy._constraints) + +# CHECK: ['1 || jR2 || jDDR || 32', '1 || jR1 || jDDR || 32', '1 || iR2 || 2', '1 || iR1 || 5', '7 || iL3 || 21'] From 56f0a09339b0b37bc057b8acbb488167755d9880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Thu, 23 Oct 2025 16:31:58 +0200 Subject: [PATCH 07/23] Yet another fix for splits --- src/xtc/schedules/descript_extend.py | 116 +++++++++++------- .../test_matmul_descript_split_in_split.py | 4 +- 2 files 
changed, 76 insertions(+), 44 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 39118f71f..8c73912a2 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -261,6 +261,7 @@ def _flatten_schedule( tree_packs = [] tree_fusion = [] tree_buff = [] + last_split = None for declaration, val in tree_val.items(): if declaration == "fusion": tree_fusion.append(val) @@ -306,51 +307,73 @@ def _flatten_schedule( # last one, so it cannot be the previous one. cut = previous_cut[axis_name] - # When x (the starting point of the slice), is not - # specified, it is the previous cut - if x is None: - x = cut - - # assert isinstance(x, int) - inner_size = self._extended_check_splitting_intervals( - declaration, axis_name, cut, x, y - ) current_size = axes_sizes[axis_name] # Update the previous cut - previous_cut[axis_name] = y # Save the cutting points of the new dimensions if axis_name not in sched["splits"]: sched["splits"][axis_name] = {} new_dim_index = len(sched["splits"][axis_name]) new_dim_name = f"{axis_name}[{new_dim_index}]" new_axes_root_name = f"{root}/{new_dim_name}" - sched["splits"][axis_name][new_dim_name] = x if axis_name in tree_interchange: tree_interchange[axis_name].append(new_dim_name) else: tree_interchange[axis_name] = [new_dim_name] - inner_size = None - if y is None: - y = current_size - if isinstance(x, int): - if x == 0: - inner_size = y - elif isinstance(y, int): - inner_size = y - x - - if inner_size is None: - inner_size = root[1:] + new_dim_name - inner_size = ( - inner_size.replace("/", "") - .replace("[", "_") - .replace("]", "_") + if z is None: + previous_cut[axis_name] = y + sched["splits"][axis_name][new_dim_name] = x + # When x (the starting point of the slice), is not + # specified, it is the previous cut + if x is None: + x = cut + + # assert isinstance(x, int) + inner_size = self._extended_check_splitting_intervals( + declaration, axis_name, cut, 
x, y ) - if isinstance(x, str): - constraints.append(f"{x} < {y}") - # constraints.append(f"1 || {x} || {y}") - # sched_sizes[axis_name].append(x) - constraints.append(f"{inner_size} + {x} == {y}") + inner_size = None + if y is None: + y = current_size + if isinstance(x, int): + if x == 0: + inner_size = y + elif isinstance(y, int): + inner_size = y - x + + if inner_size is None: + inner_size = root[1:] + new_dim_name + inner_size = ( + inner_size.replace("/", "") + .replace("[", "_") + .replace("]", "_") + ) + if isinstance(x, str): + constraints.append(f"{x} < {y}") + # constraints.append(f"1 || {x} || {y}") + # sched_sizes[axis_name].append(x) + constraints.append(f"{inner_size} + {x} == {y}") + else: + inner_size = z + x = cut + y = current_size + if isinstance(z, int) and isinstance(x, int): + previous_cut[axis_name] = x + z + if not isinstance(y, int): + constraints.append(f"{z + x} <= {y}") + else: + new_cut = root[1:] + new_dim_name + new_cut = ( + new_cut.replace("/", "") + .replace("[", "_") + .replace("]", "_") + ) + previous_cut[axis_name] = new_cut + if last_split is not None: + a, b = last_split + constraints.append(f"{a} <= {b}") + last_split = (new_cut, y) + constraints.append(f"{z} + {x} == {new_cut}") axes_sizes[axis_name] = inner_size @@ -428,12 +451,21 @@ def _flatten_schedule( for v in tree_interchange.values(): interchange += v - # Check if the last cut of each axis is either 0 or None. - # None correspond to "until the end of the loop". 0 is the - # default value, if it has 0 then it means the axis isn't splitted. - # Any other value means the split is let in a partial state. + # Check if the last cut of each axis is either 0 or None. + # None correspond to "until the end of the loop". 0 is the + # default value, if it has 0 then it means the axis isn't splitted. + # Any other value means the split is let in a partial state. 
+ if last_split is not None: + a, b = last_split + if isinstance(a, int): + constraints.append(f"{b} in {{{a}}}") + elif isinstance(b, int): + constraints.append(f"{a} in {{{b}}}") + else: + constraints.append(f"{b} == {a}") + last_split = None for axis, cut in previous_cut.items(): - if cut is not None and cut != 0: + if cut is not None and isinstance(cut, int) and cut != 0: raise Exception( f"Splitting on axis {axis} should end but stops at {cut}" ) @@ -570,13 +602,13 @@ def parse_split_declaration( pattern = r"^(.*)\[(?:(-\w+|\w*)?):(?:(-\w+|\w*)?)\]$" match = re.match(pattern, declaration) if not match: - pattern = r"^(.*)\[:(\w*):]" + pattern = r"^(.*)\[:(\w*):\]$" match = re.match(pattern, declaration) - if match: - prefix, z = match.group() - z = int(z) if z.isnumeric() else z - return prefix, None, None, z - raise Exception(f"Wrong format {declaration}") + if not match: + raise Exception(f"Wrong format {declaration}") + prefix, z = match.groups() + z = int(z) if z.isnumeric() else z + return prefix, None, None, z prefix, x_str, y_str = match.groups() x = int(x_str) if x_str.isnumeric() else x_str diff --git a/tests/filecheck/search/test_matmul_descript_split_in_split.py b/tests/filecheck/search/test_matmul_descript_split_in_split.py index d7494f772..e397cc9ef 100644 --- a/tests/filecheck/search/test_matmul_descript_split_in_split.py +++ b/tests/filecheck/search/test_matmul_descript_split_in_split.py @@ -22,13 +22,13 @@ "i[:6]": { "L1": {"i#3": {}}, "R": { - "i[:2]": { + "i[:2:]": { "RR": { "i#iR1": {"unroll": None}, "j#jR1": {"vectorize": None}, } }, - "i[2:]": {"RR": {"i#iR3": {}, "j#jR3": {}}}, + "i[:iS:]": {"RR": {"i#iR3": {}, "j#jR3": {}}}, }, }, "i[6:]": { From b5a4e0ddac99ed25a603734a28bc1679720b40fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Thu, 23 Oct 2025 16:41:44 +0200 Subject: [PATCH 08/23] Supporting abstract matrix names for packing --- src/xtc/schedules/descript_extend.py | 8 ++++++-- 1 file changed, 6 
insertions(+), 2 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 8c73912a2..0b46402db 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -18,10 +18,13 @@ def descript_extend_scheduler( abstract_axis: list[str], abstract_axis_sizes: dict[str, int], spec: dict[str, dict], + abstract_matrix: list[str] = [], sample: dict[str, Any] = {}, ): descript = DescriptExtend( - abstract_axis=abstract_axis, abstract_axis_sizes=abstract_axis_sizes + abstract_axis=abstract_axis, + abstract_axis_sizes=abstract_axis_sizes, + abstract_matrix=abstract_matrix, ) descript.apply(node_name=node_name, spec=spec, scheduler=scheduler, sample=sample) @@ -29,6 +32,7 @@ def descript_extend_scheduler( @dataclass(frozen=True) class DescriptExtend(Descript): abstract_axis_sizes: dict[str, int] + abstract_matrix: list[str] = [] @override def apply( @@ -276,7 +280,7 @@ def _flatten_schedule( variables.append(param) constraints.append(f"0 <= {param} <= 1") if isinstance(input, str): - raise Exception("Packing input cannot be a variable.") + input = self.abstract_matrix.index(input) if isinstance(pad, str): variables.append(pad) constraints.append(f"0 <= {pad} <= 1") From a7ab203d1a064f9b34466a5273c92334e6386628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Thu, 23 Oct 2025 17:00:09 +0200 Subject: [PATCH 09/23] Adding alternative syntax for packing/bufferization --- src/xtc/schedules/descript_extend.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 0b46402db..0cbfefe2a 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -303,6 +303,22 @@ def _flatten_schedule( elif declaration == "explore_axis_order": sched["axis_orders"].append(tree_declaration) continue + elif declaration in self.abstract_matrix: + matrix_index = 
self.abstract_matrix.index(declaration) + param = val.get("bufferize", False) + pad = val.get("pad", False) + if param is None or param: + if matrix_index == len(self.abstract_matrix) - 1: + tree_buff.append((param, pad)) + else: + tree_packs.append((param, matrix_index, pad)) + if isinstance(param, str): + variables.append(param) + constraints.append(f"0 <= {param} <= 1") + if isinstance(pad, str): + variables.append(pad) + constraints.append(f"0 <= {pad} <= 1") + continue elif ":" in declaration: axis_name, x, y, z = self.parse_split_declaration(declaration) self._check_axis_existence(axis_name) From 387cbae59425f8ff1f30429995136dd14a509037 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Fri, 24 Oct 2025 11:59:17 +0200 Subject: [PATCH 10/23] YAML parsing --- requirements.txt | 1 + src/xtc/schedules/descript_extend.py | 90 +++++++++++++++++-- src/xtc/search/strategies.py | 6 +- .../search/test_matmul_descript_split.py | 4 +- .../search/test_matmul_descript_yaml_goto.py | 36 ++++++++ .../test_matmul_descript_yaml_simple.py | 26 ++++++ .../search/test_matmul_descript_yaml_split.py | 37 ++++++++ 7 files changed, 191 insertions(+), 9 deletions(-) create mode 100644 tests/filecheck/search/test_matmul_descript_yaml_goto.py create mode 100644 tests/filecheck/search/test_matmul_descript_yaml_simple.py create mode 100644 tests/filecheck/search/test_matmul_descript_yaml_split.py diff --git a/requirements.txt b/requirements.txt index 91f4efa67..ed218fc34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ pyyaml scikit-learn networkx sympy +strictyaml diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 0cbfefe2a..18d76a15e 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -5,6 +5,7 @@ from typing import Any, Tuple from dataclasses import dataclass import re +import strictyaml from typing_extensions import override from xtc.itf.schd.scheduler 
import Scheduler @@ -32,17 +33,21 @@ def descript_extend_scheduler( @dataclass(frozen=True) class DescriptExtend(Descript): abstract_axis_sizes: dict[str, int] - abstract_matrix: list[str] = [] + abstract_matrix: list[str] @override def apply( self, node_name: str, - spec: dict[str, dict], + spec: dict[str, dict] | str, scheduler: Scheduler, sample: dict[str, Any] = {}, ): - flat_schedules = self._flatten_schedule(root=node_name, spec=spec, head=[]) + if isinstance(spec, str): + dict_spec = self.parse_yaml(spec) + else: + dict_spec = spec + flat_schedules = self._flatten_schedule(root=node_name, spec=dict_spec, head=[]) variables = set() constraints = set() for schedule in flat_schedules: @@ -52,8 +57,83 @@ def apply( flat_schedules = self.apply_sample(flat_schedules, sample) self.apply_scheduler(flat_schedules, scheduler) - def flatten_schedule(self, node_name: str, spec: dict[str, dict]): - flat_schedules = self._flatten_schedule(root=node_name, spec=spec, head=[]) + def parse_yaml(self, spec: str) -> dict[str, dict]: + dspec = strictyaml.load(spec).data + assert isinstance(dspec, dict) + return self._parse_yaml(dspec) + + def _parse_yaml(self, spec: dict[str, dict]) -> dict[str, dict]: + out_dict = {} + for level, d_level in spec.items(): + level_dict = {} + for a, v in d_level.items(): + if a == "explore": + assert isinstance(v, str) + if v == "": + tmp = None + else: + try: + tmp = eval(v) + except NameError: + tmp = v + level_dict["explore_axis_order"] = tmp + elif a in self.abstract_matrix: + assert isinstance(v, str) + level_dict[a] = self._split_yaml(v) + else: + if isinstance(v, str): + d = self._split_yaml(v) + else: + assert isinstance(v, dict) + d = v + size = d.get("size", None) + if size: + # if d.get("split", None): + # raise Exception(f""" + # Axis cannot be tiled and split. + # (Axis {a} at level {level}.) 
+ # """) + d.pop("size") + a = f"{a}#{size}" + # split = d.get("split", None) + # if split: + # if isinstance(split, str): + # d_split = self._split_yaml(split) + # else: + # d_split = split + # for sub_level, sub_range in d_split.items(): + # sub_dict = self._parse_yaml({sub_level : d[sub_level]}) + # level_dict[a + sub_range] = sub_dict + # continue + if ":" in a: + level_dict[a] = self._parse_yaml(d) + continue + level_dict[a] = {} + for axis_arg, arg_val in d.items(): + level_dict[a][axis_arg] = arg_val + out_dict[level] = level_dict + return out_dict + + def _split_yaml(self, s: str) -> dict[str, Any]: + d = {} + for s in s.split(): + if "=" not in s: + d[s] = None + else: + x, y = s.split("=") + try: + tmp = eval(y) + except (NameError, SyntaxError): + tmp = y + d[x] = tmp + return d + + def flatten_schedule(self, node_name: str, spec: dict[str, dict] | str): + if isinstance(spec, str): + dict_spec = self.parse_yaml(spec) + else: + dict_spec = spec + flat_schedules = self._flatten_schedule(root=node_name, spec=dict_spec, head=[]) variables = [] constraints = [] axes = {} diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index 3cef6cda3..2307ef2a4 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -949,7 +949,7 @@ class Strategy_Descript(Strategy): def __init__( self, graph: Graph, - spec: dict[str, dict], + spec: dict[str, dict] | str, constraints: list[str] = [], initialize: bool = True, ) -> None: @@ -959,7 +959,9 @@ def __init__( self._axes = list(self._op.dims) self._sizes = self._constant_sizes() descript = DescriptExtend( - abstract_axis=self._axes, abstract_axis_sizes=dict(self._sizes) + abstract_axis=self._axes, + abstract_axis_sizes=dict(self._sizes), + abstract_matrix=["A", "B", "C"], ) self._descript = descript self._initialized = False diff --git a/tests/filecheck/search/test_matmul_descript_split.py b/tests/filecheck/search/test_matmul_descript_split.py index fd5745802..703b30ac3 100644 --- 
a/tests/filecheck/search/test_matmul_descript_split.py +++ b/tests/filecheck/search/test_matmul_descript_split.py @@ -21,13 +21,13 @@ "L1": { "j#jDDR": {}, "i[:5]": { - "R": { + "R1": { "i#iR1": {"unroll": None}, "j#jR1": {"vectorize": None}, }, }, "i[5:]": { - "R": { + "R2": { "i#iR2": {"unroll": None}, "j#jR2": {"vectorize": None}, }, diff --git a/tests/filecheck/search/test_matmul_descript_yaml_goto.py b/tests/filecheck/search/test_matmul_descript_yaml_goto.py new file mode 100644 index 000000000..71caa2df9 --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_yaml_goto.py @@ -0,0 +1,36 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test strategy Goto on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = """ +DDRj: + j: + parallelize: j_par +DDR: + k: + i: + explore: True + A: bufferize=pack_A pad + B: bufferize=pack_B pad +L3: + j: size=jL3 +L2: + i: size=iL2 +L1: + k: size=kL1 unroll=kU +R: + i: size=iR unroll + j: size=jR vectorize=jV +""" +constraint = ["iR * jR <= 56"] +strategy = Strategy(graph, spec, constraints=constraint, initialize=False) + +print(strategy._constraints) + +# CHECK: ['1 || k_unroll || kL1 || 12', '1 || jR || jL3 || 32', '1 || iR || iL2 || 21', '0 <= pack_B <= 1', '0 <= pack_A <= 1', '0 <= j_parallel <= 1', '0 <= j_vectorise <= 1', 'iR * jR <= 56', '0 <= order_DDR <= 1'] diff --git a/tests/filecheck/search/test_matmul_descript_yaml_simple.py b/tests/filecheck/search/test_matmul_descript_yaml_simple.py new file mode 100644 index 000000000..f56e1d9df --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_yaml_simple.py @@ -0,0 +1,26 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test strategy Goto on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = """ +L3: + k: + i: + j: +L2: + 
i#i1: + j#j1: +L1: + j#j2: +""" +strategy = Strategy(graph, spec, initialize=False) + +print(strategy._constraints) + +# CHECK: ['1 || j2 || j1 || 32', '1 || i1 || 21'] diff --git a/tests/filecheck/search/test_matmul_descript_yaml_split.py b/tests/filecheck/search/test_matmul_descript_yaml_split.py new file mode 100644 index 000000000..ede7ac7a6 --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_yaml_split.py @@ -0,0 +1,37 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test splits on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = """ +DDR: + j: + k: + i: +L3: + i#iL3: +L2: + i#7: +L1: + j#jDDR: + i[:5]: + R1: + i#iR1: unroll + j#jR1: vectorize + SR1: + k#SR: + i[5:]: + R2: + i#iR2: unroll + j#jR2: unroll +""" +strategy = Strategy(graph, spec, initialize=False) + +print(strategy._constraints) + +# CHECK: ['1 || jR1 || jDDR || 32', '1 || jR3 || jDDR || 32', '1 || jR2 || jDDR || 32', '1 || iL2 || 21', '1 || iR1 || 2', '1 || iR3 || 1', '1 || iR2 || i_1_', 'i_1_ + 6 == iL2'] From a6dc1e0939b0dc930118e5b8ca4f71347268af2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Fri, 31 Oct 2025 12:19:14 +0100 Subject: [PATCH 11/23] Updating sampler --- src/xtc/schedules/descript_extend.py | 12 +++++++++-- src/xtc/search/strategies.py | 21 +++++++++++++++---- .../search/test_matmul_descript_yaml_goto.py | 8 ++++--- .../test_matmul_descript_yaml_simple.py | 6 ++++-- .../search/test_matmul_descript_yaml_split.py | 14 +++++++------ 5 files changed, 44 insertions(+), 17 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 18d76a15e..915a52a79 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -449,7 +449,7 @@ def _flatten_schedule( .replace("]", "_") ) if isinstance(x, str): - constraints.append(f"{x} < {y}") + 
constraints.append(f"{x} <= {y}") # constraints.append(f"1 || {x} || {y}") # sched_sizes[axis_name].append(x) constraints.append(f"{inner_size} + {x} == {y}") @@ -461,6 +461,10 @@ def _flatten_schedule( previous_cut[axis_name] = x + z if not isinstance(y, int): constraints.append(f"{z + x} <= {y}") + elif isinstance(x, int) and x == 0: + previous_cut[axis_name] = z + if not isinstance(y, int): + constraints.append(f"{z} <= {y}") else: new_cut = root[1:] + new_dim_name new_cut = ( @@ -562,7 +566,11 @@ def _flatten_schedule( elif isinstance(b, int): constraints.append(f"{a} in {{{b}}}") else: - constraints.append(f"{b} == {a}") + for i in range(len(constraints)): + c = constraints[i] + constraints[i] = c.replace(a, b) + # constraints.remove(c) c.replace() + # constraints.append(f"{b} == {a}") last_split = None for axis, cut in previous_cut.items(): if cut is not None and isinstance(cut, int) and cut != 0: diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index 2307ef2a4..cff5a5973 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -991,14 +991,17 @@ def __init__( def _initialize(self): if self._initialized: return + max_enum = max(self._sizes.values()) constraints = constraints_from_str(self._constraints, silent=True) - properties, constraints = hypergraph(constraints, silent=True) + properties, constraints = hypergraph( + constraints, max_enum=max_enum, silent=True + ) methods = solve_with_z3( sampler_variables.keys(), properties, constraints, silent=True ) enumerations = execute_static(methods, properties, constraints, silent=True) self._properties = properties - self._constraints = constraints + self._z3_constraints = constraints self._methods = methods self._enumerations = enumerations self._initialized = True @@ -1023,16 +1026,26 @@ def generate(self, scheduler: Scheduler, sample: Sample) -> None: @override def sample(self, num: int, seed: int | None = 0) -> Iterator[Sample]: + samples = 
sample_uniques(self._sample_once_tuple, num) + for x in samples: + yield dict(zip(self.sample_names, x)) + + def sample_once(self, num: int) -> Iterator[Sample]: self._initialize() draw = execute_dynamic( self._methods, self._properties, - self._constraints, + self._z3_constraints, self._enumerations, k=num, silent=True, ) - return iter(list(draw.values())[0]) + return draw + + def _sample_once_tuple(self, num: int) -> Iterator[tuple]: + draw = self.sample_once(num) + for d in draw: + yield tuple(d.values()) @override def exhaustive(self) -> Iterator[Sample]: diff --git a/tests/filecheck/search/test_matmul_descript_yaml_goto.py b/tests/filecheck/search/test_matmul_descript_yaml_goto.py index 71caa2df9..bf37c01fb 100644 --- a/tests/filecheck/search/test_matmul_descript_yaml_goto.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_goto.py @@ -1,4 +1,4 @@ -# RUN: python %s 2>&1 | filecheck %s +# RUN: python -O %s 2>&1 | filecheck %s """ Test strategy Goto on matmul """ @@ -29,8 +29,10 @@ j: size=jR vectorize=jV """ constraint = ["iR * jR <= 56"] -strategy = Strategy(graph, spec, constraints=constraint, initialize=False) +strategy = Strategy(graph, spec, constraints=constraint) print(strategy._constraints) +print(len(list(strategy.sample(100)))) -# CHECK: ['1 || k_unroll || kL1 || 12', '1 || jR || jL3 || 32', '1 || iR || iL2 || 21', '0 <= pack_B <= 1', '0 <= pack_A <= 1', '0 <= j_parallel <= 1', '0 <= j_vectorise <= 1', 'iR * jR <= 56', '0 <= order_DDR <= 1'] +# CHECK: ['1 || kU || kL1 || 12', '1 || jR || jL3 || 32', '1 || iR || iL2 || 21', '0 <= pack_A <= 1', '0 <= pack_B <= 1', '0 <= j_par <= 1', '0 <= jV <= 1', 'iR * jR <= 56', '0 <= order_DDR <= 1'] +#CHECK-NEXT: 100 diff --git a/tests/filecheck/search/test_matmul_descript_yaml_simple.py b/tests/filecheck/search/test_matmul_descript_yaml_simple.py index f56e1d9df..9112a6321 100644 --- a/tests/filecheck/search/test_matmul_descript_yaml_simple.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_simple.py 
@@ -1,4 +1,4 @@ -# RUN: python %s 2>&1 | filecheck %s +# RUN: python -O %s 2>&1 | filecheck %s """ Test strategy Goto on matmul """ @@ -19,8 +19,10 @@ L1: j#j2: """ -strategy = Strategy(graph, spec, initialize=False) +strategy = Strategy(graph, spec) print(strategy._constraints) +print(len(list(strategy.sample(100)))) # CHECK: ['1 || j2 || j1 || 32', '1 || i1 || 21'] +# CHECK-NEXT: 84 diff --git a/tests/filecheck/search/test_matmul_descript_yaml_split.py b/tests/filecheck/search/test_matmul_descript_yaml_split.py index ede7ac7a6..c054cd4e6 100644 --- a/tests/filecheck/search/test_matmul_descript_yaml_split.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_split.py @@ -1,4 +1,4 @@ -# RUN: python %s 2>&1 | filecheck %s +# RUN: python %s -O 2>&1 | filecheck %s """ Test splits on matmul """ @@ -16,22 +16,24 @@ L3: i#iL3: L2: - i#7: + i#iL2: L1: j#jDDR: - i[:5]: + i[:iS]: R1: i#iR1: unroll j#jR1: vectorize SR1: k#SR: - i[5:]: + i[iS:]: R2: i#iR2: unroll j#jR2: unroll """ -strategy = Strategy(graph, spec, initialize=False) +strategy = Strategy(graph, spec) print(strategy._constraints) +print(len(list(strategy.sample(100)))) -# CHECK: ['1 || jR1 || jDDR || 32', '1 || jR3 || jDDR || 32', '1 || jR2 || jDDR || 32', '1 || iL2 || 21', '1 || iR1 || 2', '1 || iR3 || 1', '1 || iR2 || i_1_', 'i_1_ + 6 == iL2'] +# CHECK: ['1 || SR || 12', '1 || jR1 || jDDR || 32', '1 || jR2 || jDDR || 32', '1 || iL2 || iL3 || 21', '1 || iR1 || iS', '1 || iR2 || i_1_', 'iS <= iL2', 'i_1_ + iS == iL2'] +# CHECK-NEXT: 100 From f92028ab66ae2e60ff8d6e74bf70a1bae76eefd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Mon, 3 Nov 2025 16:24:38 +0100 Subject: [PATCH 12/23] Small fixes to descript_extend --- src/xtc/schedules/descript_extend.py | 58 ++++++++++------------------ 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 915a52a79..3df0e357c 100644 --- 
a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -66,6 +66,8 @@ def _parse_yaml(self, spec: dict[str, dict]) -> dict[str, dict]: out_dict = {} for level, d_level in spec.items(): level_dict = {} + if not isinstance(d_level, dict): + continue for a, v in d_level.items(): if a == "explore": assert isinstance(v, str) @@ -88,23 +90,8 @@ def _parse_yaml(self, spec: dict[str, dict]) -> dict[str, dict]: d = v size = d.get("size", None) if size: - # if d.get("split", None): - # raise Exception(f""" - # Axis cannot be tiled and split. - # (Axis {a} at level {level}.) - # """) d.pop("size") a = f"{a}#{size}" - # split = d.get("split", None) - # if split: - # if isinstance(split, str): - # d_split = self._split_yaml(split) - # else: - # d_split = split - # for sub_level, sub_range in d_split.items(): - # sub_dict = self._parse_yaml({sub_level : d[sub_level]}) - # level_dict[a + sub_range] = sub_dict - # continue if ":" in a: level_dict[a] = self._parse_yaml(d) continue @@ -358,12 +345,12 @@ def _flatten_schedule( tree_packs.append((param, input, pad)) if isinstance(param, str): variables.append(param) - constraints.append(f"0 <= {param} <= 1") + constraints.append(f"{param} in {{0, 1}}") if isinstance(input, str): input = self.abstract_matrix.index(input) if isinstance(pad, str): variables.append(pad) - constraints.append(f"0 <= {pad} <= 1") + constraints.append(f"{pad} in {{0, 1}}") continue elif declaration in "buffer": for val_ in val: @@ -375,10 +362,10 @@ def _flatten_schedule( tree_buff.append((param, pad)) if isinstance(param, str): variables.append(param) - constraints.append(f"0 <= {param} <= 1") + constraints.append(f"{param} in {{0, 1}}") if isinstance(pad, str): variables.append(pad) - constraints.append(f"0 <= {pad} <= 1") + constraints.append(f"{pad} in {{0, 1}}") continue elif declaration == "explore_axis_order": sched["axis_orders"].append(tree_declaration) @@ -394,10 +381,10 @@ def _flatten_schedule( 
tree_packs.append((param, matrix_index, pad)) if isinstance(param, str): variables.append(param) - constraints.append(f"0 <= {param} <= 1") + constraints.append(f"{param} in {{0, 1}}") if isinstance(pad, str): variables.append(pad) - constraints.append(f"0 <= {pad} <= 1") + constraints.append(f"{pad} in {{0, 1}}") continue elif ":" in declaration: axis_name, x, y, z = self.parse_split_declaration(declaration) @@ -555,23 +542,20 @@ def _flatten_schedule( for v in tree_interchange.values(): interchange += v - # Check if the last cut of each axis is either 0 or None. - # None correspond to "until the end of the loop". 0 is the - # default value, if it has 0 then it means the axis isn't splitted. - # Any other value means the split is let in a partial state. if last_split is not None: a, b = last_split - if isinstance(a, int): - constraints.append(f"{b} in {{{a}}}") - elif isinstance(b, int): - constraints.append(f"{a} in {{{b}}}") - else: - for i in range(len(constraints)): - c = constraints[i] - constraints[i] = c.replace(a, b) - # constraints.remove(c) c.replace() - # constraints.append(f"{b} == {a}") + if isinstance(a, int) and not isinstance(b, int): + a, b = b, a + a, b = str(a), str(b) + for i in range(len(constraints)): + c = constraints[i] + constraints[i] = c.replace(a, b) last_split = None + + # Check if the last cut of each axis is either 0 or None. + # None correspond to "until the end of the loop". 0 is the + # default value, if it has 0 then it means the axis isn't splitted. + # Any other value means the split is let in a partial state. 
for axis, cut in previous_cut.items(): if cut is not None and isinstance(cut, int) and cut != 0: raise Exception( @@ -676,7 +660,7 @@ def annotate( case "vectorize": if isinstance(param, str): sched["variables"].append(param) - sched["constraints"].append(f"0 <= {param} <= 1") + sched["constraints"].append(f"{param} in {{0, 1}}") sched["vectorize"].append((param, loop_name)) continue if param is None: @@ -689,7 +673,7 @@ def annotate( case "parallelize": if isinstance(param, str): sched["variables"].append(param) - sched["constraints"].append(f"0 <= {param} <= 1") + sched["constraints"].append(f"{param} in {{0, 1}}") sched["parallelize"].append((param, loop_name)) continue if param is None: From f9a2470dd74521b677ca6ef4e612288186246ff1 Mon Sep 17 00:00:00 2001 From: Sylvain Noiry Date: Wed, 12 Nov 2025 18:12:08 +0100 Subject: [PATCH 13/23] Bug fix --- src/xtc/schedules/descript_extend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 3df0e357c..a90e8126a 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -3,6 +3,7 @@ # Copyright (c) 2024-2026 The XTC Project Authors # from typing import Any, Tuple +from copy import deepcopy from dataclasses import dataclass import re import strictyaml @@ -191,7 +192,7 @@ def flatten_schedule(self, node_name: str, spec: dict[str, dict] | str): def apply_sample( self, flat_schedules: list[SchedDict], sample: dict[str, Any] ) -> list[SchedDict]: - flat_schedules = flat_schedules.copy() + flat_schedules = deepcopy(flat_schedules) for schedule in flat_schedules: for k in ["splits", "tiles"]: for dim, axes in schedule[k].items(): From e80cfb6ef2722ef377f92310ed764094bc640ba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Thu, 13 Nov 2025 16:56:34 +0100 Subject: [PATCH 14/23] Constraint fixes --- src/xtc/schedules/descript_extend.py | 28 +++++++++++++++++----------- 
src/xtc/search/strategies.py | 4 ++-- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index a90e8126a..bf3980912 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -173,17 +173,22 @@ def flatten_schedule(self, node_name: str, spec: dict[str, dict] | str): axis_constraints += all_axis_constraints axis_constraints.reverse() for constraint in axis_constraints: - constraint.reverse() - constraint_str = "" - var_flag = False - if isinstance(constraint[0], str): - constraint_str = "1 || " - for size in constraint[:-1]: - var_flag = var_flag or isinstance(size, str) - constraint_str += f"{size} || " - constraint_str += str(constraint[-1]) - if var_flag: - constraints.insert(0, constraint_str) + if constraint[0] == 1: + for size in constraint[1:]: + if isinstance(size, str): + constraints.append(f"{size} in {{1}}") + else: + constraint.reverse() + constraint_str = "" + var_flag = False + if isinstance(constraint[0], str): + constraint_str = "1 || " + for size in constraint[:-1]: + var_flag = var_flag or isinstance(size, str) + constraint_str += f"{size} || " + constraint_str += str(constraint[-1]) + if var_flag: + constraints.insert(0, constraint_str) variables = list(dict.fromkeys(variables)) constraints = list(dict.fromkeys(constraints)) @@ -436,6 +441,7 @@ def _flatten_schedule( .replace("[", "_") .replace("]", "_") ) + constraints.append(f"{inner_size} <= {y}") if isinstance(x, str): constraints.append(f"{x} <= {y}") # constraints.append(f"1 || {x} || {y}") diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index cff5a5973..ca244fb1c 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -983,7 +983,7 @@ def __init__( permutation = list(itertools.permutations(v)) a_holder = f"order_{a}" self._orders[a_holder] = permutation - order_constraints.append(f"0 <= {a_holder} <= {len(permutation) 
- 1}") + order_constraints.append(f"{a_holder} in {set(range(len(permutation)))}") self._constraints = constraints + input_constraints + order_constraints if initialize: self._initialize() @@ -1045,7 +1045,7 @@ def sample_once(self, num: int) -> Iterator[Sample]: def _sample_once_tuple(self, num: int) -> Iterator[tuple]: draw = self.sample_once(num) for d in draw: - yield tuple(d.values()) + yield tuple([d[x] for x in self.sample_names]) @override def exhaustive(self) -> Iterator[Sample]: From 3bebc3d0983e1d4496b9b5c6ad08323d70528cfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Mon, 17 Nov 2025 12:12:08 +0100 Subject: [PATCH 15/23] Partial tiles support --- src/xtc/schedules/descript_extend.py | 150 +++++++++++++-------------- src/xtc/search/strategies.py | 2 + 2 files changed, 76 insertions(+), 76 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index bf3980912..7ad044df4 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -22,11 +22,13 @@ def descript_extend_scheduler( spec: dict[str, dict], abstract_matrix: list[str] = [], sample: dict[str, Any] = {}, + partial_tiles: bool = False, ): descript = DescriptExtend( abstract_axis=abstract_axis, abstract_axis_sizes=abstract_axis_sizes, abstract_matrix=abstract_matrix, + partial_tiles=partial_tiles, ) descript.apply(node_name=node_name, spec=spec, scheduler=scheduler, sample=sample) @@ -35,6 +37,7 @@ def descript_extend_scheduler( class DescriptExtend(Descript): abstract_axis_sizes: dict[str, int] abstract_matrix: list[str] + partial_tiles: bool = False @override def apply( @@ -135,60 +138,60 @@ def flatten_schedule(self, node_name: str, spec: dict[str, dict] | str): for axis in axis_orders: orders[axis] = schedule["axes"][axis] - for axis in self.abstract_axis: - all_axis_constraints = [] - for schedule in flat_schedules: - for sched in schedule["sizes"][axis]: - if len(sched) > 1: - 
all_axis_constraints.append(sched) - axis_constraints = [] - i = 0 - while i < len(all_axis_constraints): - sched = all_axis_constraints[i] - if isinstance(sched[0], int): - axis_constraints.append(sched) - all_axis_constraints.pop(i) - else: - i += 1 - flag_flag = True - while len(all_axis_constraints) > 0 and flag_flag: - i = 0 - axis_constraints_acc = [] - flag_flag = False - while i < len(all_axis_constraints): - sched = all_axis_constraints[i] - flag = False - for constraint in axis_constraints: - if sched[0] == constraint[-1]: - axis_constraints_acc.append(constraint + sched[1:]) - flag = True - if flag: - all_axis_constraints.pop(i) - flag_flag = True - else: - i += 1 - if flag_flag: - axis_constraints = axis_constraints_acc - - axis_constraints += all_axis_constraints - axis_constraints.reverse() - for constraint in axis_constraints: - if constraint[0] == 1: - for size in constraint[1:]: - if isinstance(size, str): - constraints.append(f"{size} in {{1}}") - else: - constraint.reverse() - constraint_str = "" - var_flag = False - if isinstance(constraint[0], str): - constraint_str = "1 || " - for size in constraint[:-1]: - var_flag = var_flag or isinstance(size, str) - constraint_str += f"{size} || " - constraint_str += str(constraint[-1]) - if var_flag: - constraints.insert(0, constraint_str) + # for axis in self.abstract_axis: + # all_axis_constraints = [] + # for schedule in flat_schedules: + # for sched in schedule["sizes"][axis]: + # if len(sched) > 1: + # all_axis_constraints.append(sched) + # axis_constraints = [] + # i = 0 + # while i < len(all_axis_constraints): + # sched = all_axis_constraints[i] + # if isinstance(sched[0], int): + # axis_constraints.append(sched) + # all_axis_constraints.pop(i) + # else: + # i += 1 + # flag_flag = True + # while len(all_axis_constraints) > 0 and flag_flag: + # i = 0 + # axis_constraints_acc = [] + # flag_flag = False + # while i < len(all_axis_constraints): + # sched = all_axis_constraints[i] + # flag = False + # 
for constraint in axis_constraints: + # if sched[0] == constraint[-1]: + # axis_constraints_acc.append(constraint + sched[1:]) + # flag = True + # if flag: + # all_axis_constraints.pop(i) + # flag_flag = True + # else: + # i += 1 + # if flag_flag: + # axis_constraints = axis_constraints_acc + # + # axis_constraints += all_axis_constraints + # axis_constraints.reverse() + # for constraint in axis_constraints: + # if constraint[0] == 1: + # for size in constraint[1:]: + # if isinstance(size, str): + # constraints.append(f"{size} in {{1}}") + # else: + # constraint.reverse() + # constraint_str = "" + # var_flag = False + # if isinstance(constraint[0], str): + # constraint_str = "1 || " + # for size in constraint[:-1]: + # var_flag = var_flag or isinstance(size, str) + # constraint_str += f"{size} || " + # constraint_str += str(constraint[-1]) + # if var_flag: + # constraints.insert(0, constraint_str) variables = list(dict.fromkeys(variables)) constraints = list(dict.fromkeys(constraints)) @@ -308,7 +311,6 @@ def _flatten_schedule( "axis_orders": [], "axes": {}, "splits": {}, - "sizes": {}, "tiles": {a: {} for a in self.abstract_axis}, "interchange": [], "vectorize": [], @@ -322,15 +324,12 @@ def _flatten_schedule( axes_sizes: dict[str, int | str] = tile_sizes else: axes_sizes = {a: v for a, v in self.abstract_axis_sizes.items()} - sched_sizes = {} - for a, v in axes_sizes.items(): - sched["sizes"][a] = [] - sched_sizes[a] = [v] sizes: dict[str, int | str | None] = {} previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} interchange: list[str] = head constraints: list[str] = [] variables: list[str] = [] + default_leq = "<=" if self.partial_tiles else "||" # Processing the schedule for tree_declaration, tree_val in spec.items(): assert isinstance(tree_val, dict) @@ -444,8 +443,6 @@ def _flatten_schedule( constraints.append(f"{inner_size} <= {y}") if isinstance(x, str): constraints.append(f"{x} <= {y}") - # constraints.append(f"1 || {x} || {y}") 
- # sched_sizes[axis_name].append(x) constraints.append(f"{inner_size} + {x} == {y}") else: inner_size = z @@ -502,8 +499,16 @@ def _flatten_schedule( f"Invalid tile size: '{tile_size}' in {declaration}" ) + if isinstance(loop_size, str): + partial = "partial" in val + full = "full" in val + if partial and full: + raise Exception( + f"Tile {declaration} cannot be partial and full" + ) + leq = "||" if full else "<=" if partial else default_leq + constraints.append(f"{loop_size} {leq} {axes_sizes[axis_name]}") axes_sizes[axis_name] = loop_size - sched_sizes[axis_name].append(loop_size) tile_num = len(sched["tiles"][axis_name]) loop_name = f"{axis_name}{tile_num}" sched["tiles"][axis_name][loop_name] = loop_size @@ -533,11 +538,10 @@ def _flatten_schedule( self.annotate( loop_name=loop_name, - axis_name=axis_name, sizes=sizes, annotations=val, sched=sched, - sched_sizes=sched_sizes[axis_name], + constraints=constraints, ) sched["axes"][tree_declaration] = tree_interchange if len(tree_packs) > 0: @@ -572,14 +576,6 @@ def _flatten_schedule( sched["interchange"] = interchange sched["variables"] = variables + sched["variables"] sched["constraints"] = constraints + sched["constraints"] - for a in self.abstract_axis: - flag = True - for sched_ in sched["sizes"][a]: - if set(sched_sizes[a]) <= set(sched_): - flag = False - break - if flag: - sched["sizes"][a] = [sched_sizes[a]] + sched["sizes"][a] return [sched] + recursive_scheds def _extended_check_splitting_intervals( @@ -645,11 +641,10 @@ def _extended_check_splitting_intervals( def annotate( self, loop_name: str, - axis_name: str, sizes: dict[str, int | str | None], annotations: dict[str, Any], sched: dict[str, Any], - sched_sizes: list[int | str], + constraints: list[str], ): for instr, param in annotations.items(): assert isinstance(instr, str) @@ -661,7 +656,7 @@ def annotate( ufactor = param if isinstance(param, str): sched["variables"].append(param) - sched["sizes"][axis_name].append(sched_sizes + [ufactor]) + 
constraints.append(f"{ufactor} || {sizes[loop_name]}") sched["unroll"][loop_name] = ufactor case "vectorize": @@ -690,7 +685,10 @@ def annotate( raise Exception( "Parallelize should not have a parameter (Feature not implemented)" ) - + case "partial": + continue + case "full": + continue case _: raise Exception(f"Unknown annotation on {loop_name}: {instr}") diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index ca244fb1c..67ff1e535 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -951,6 +951,7 @@ def __init__( graph: Graph, spec: dict[str, dict] | str, constraints: list[str] = [], + partial_tiles: bool = False, initialize: bool = True, ) -> None: self._graph = graph @@ -962,6 +963,7 @@ def __init__( abstract_axis=self._axes, abstract_axis_sizes=dict(self._sizes), abstract_matrix=["A", "B", "C"], + partial_tiles=partial_tiles, ) self._descript = descript self._initialized = False From a5f0248ccc28399303d349fc9dab465cdff51b1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Wed, 19 Nov 2025 12:14:52 +0100 Subject: [PATCH 16/23] Finishing partial tiles support --- src/xtc/schedules/descript_extend.py | 23 ++++++++++++++++++++--- src/xtc/search/strategies.py | 15 +++++++++++++-- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 7ad044df4..a101f06f1 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -8,6 +8,7 @@ import re import strictyaml from typing_extensions import override +from copy import deepcopy from xtc.itf.schd.scheduler import Scheduler @@ -301,6 +302,7 @@ def _flatten_schedule( spec: dict[str, dict], head: list[str], tile_sizes: dict[str, int | str] | None = None, + sched_sizes: dict[str, list] | None = None, ) -> list[SchedDict]: recursive_scheds: list[SchedDict] = [] sched: SchedDict = { @@ -324,12 +326,15 @@ def _flatten_schedule( 
axes_sizes: dict[str, int | str] = tile_sizes else: axes_sizes = {a: v for a, v in self.abstract_axis_sizes.items()} + if sched_sizes is None: + sched_sizes = {} + for a, v in axes_sizes.items(): + sched_sizes[a] = [str(v)] sizes: dict[str, int | str | None] = {} previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} interchange: list[str] = head constraints: list[str] = [] variables: list[str] = [] - default_leq = "<=" if self.partial_tiles else "||" # Processing the schedule for tree_declaration, tree_val in spec.items(): assert isinstance(tree_val, dict) @@ -480,6 +485,7 @@ def _flatten_schedule( root=new_axes_root_name, tile_sizes=axes_sizes.copy(), head=[axis_name], + sched_sizes=deepcopy(sched_sizes), ) axes_sizes[axis_name] = current_size @@ -506,8 +512,19 @@ def _flatten_schedule( raise Exception( f"Tile {declaration} cannot be partial and full" ) - leq = "||" if full else "<=" if partial else default_leq - constraints.append(f"{loop_size} {leq} {axes_sizes[axis_name]}") + if partial or (not full and self.partial_tiles): + constraints.append( + f"{loop_size} <= {axes_sizes[axis_name]}" + ) + else: + s = ( + ", ".join(sched_sizes[axis_name]) + if len(sched_sizes[axis_name]) > 1 + else sched_sizes[axis_name][0] + ) + s = f"{loop_size} || {{{s}}}" + constraints.append(s) + sched_sizes[axis_name].insert(0, str(loop_size)) axes_sizes[axis_name] = loop_size tile_num = len(sched["tiles"][axis_name]) loop_name = f"{axis_name}{tile_num}" diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index 67ff1e535..58dfb9f91 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -11,7 +11,12 @@ from properties import constraints_from_str, hypergraph from properties import variables as sampler_variables -from strategy import execute_dynamic, execute_static, solve_with_z3 +from strategy import ( + execute_dynamic, + execute_static, + solve_with_z3, + pretty_print_methods, +) from xtc.itf.graph import Graph 
from xtc.itf.schd import Scheduler from xtc.itf.schd.scheduler import DEFAULT_ROOT @@ -993,7 +998,7 @@ def __init__( def _initialize(self): if self._initialized: return - max_enum = max(self._sizes.values()) + max_enum = int(1 + np.log2(max(self._sizes.values()))) constraints = constraints_from_str(self._constraints, silent=True) properties, constraints = hypergraph( constraints, max_enum=max_enum, silent=True @@ -1044,6 +1049,12 @@ def sample_once(self, num: int) -> Iterator[Sample]: ) return draw + def pretty_print_methods(self, tab: str = "\t"): + self._initialize() + pretty_print_methods( + self._methods, self._properties, self._constraints, tab=tab + ) + def _sample_once_tuple(self, num: int) -> Iterator[tuple]: draw = self.sample_once(num) for d in draw: From f65dbd69da51a1cbb466fa3590a7218a21e12981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Wed, 19 Nov 2025 16:05:30 +0100 Subject: [PATCH 17/23] Adding partial unrolls --- src/xtc/schedules/descript_extend.py | 6 +++++- src/xtc/search/strategies.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index a101f06f1..a2cf0c531 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -24,12 +24,14 @@ def descript_extend_scheduler( abstract_matrix: list[str] = [], sample: dict[str, Any] = {}, partial_tiles: bool = False, + partial_unrolls: bool = False, ): descript = DescriptExtend( abstract_axis=abstract_axis, abstract_axis_sizes=abstract_axis_sizes, abstract_matrix=abstract_matrix, partial_tiles=partial_tiles, + partial_unrolls=partial_unrolls, ) descript.apply(node_name=node_name, spec=spec, scheduler=scheduler, sample=sample) @@ -39,6 +41,7 @@ class DescriptExtend(Descript): abstract_axis_sizes: dict[str, int] abstract_matrix: list[str] partial_tiles: bool = False + partial_unrolls: bool = False @override def apply( @@ -673,7 +676,8 @@ def annotate( 
ufactor = param if isinstance(param, str): sched["variables"].append(param) - constraints.append(f"{ufactor} || {sizes[loop_name]}") + leq = "<=" if self.partial_unrolls else "||" + constraints.append(f"{ufactor} {leq} {sizes[loop_name]}") sched["unroll"][loop_name] = ufactor case "vectorize": diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index 58dfb9f91..dfcec345a 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -957,6 +957,7 @@ def __init__( spec: dict[str, dict] | str, constraints: list[str] = [], partial_tiles: bool = False, + partial_unrolls: bool = False, initialize: bool = True, ) -> None: self._graph = graph @@ -969,6 +970,7 @@ def __init__( abstract_axis_sizes=dict(self._sizes), abstract_matrix=["A", "B", "C"], partial_tiles=partial_tiles, + partial_unrolls=partial_unrolls, ) self._descript = descript self._initialized = False From bee2365bb97947fb32d185b0d8249202d83d31d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Wed, 19 Nov 2025 16:05:41 +0100 Subject: [PATCH 18/23] Update to yaml_goto test --- .../search/test_matmul_descript_yaml_goto.py | 62 ++++++++++++++----- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/tests/filecheck/search/test_matmul_descript_yaml_goto.py b/tests/filecheck/search/test_matmul_descript_yaml_goto.py index bf37c01fb..0205371aa 100644 --- a/tests/filecheck/search/test_matmul_descript_yaml_goto.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_goto.py @@ -6,33 +6,61 @@ import utils from xtc.search.strategies import Strategy_Descript as Strategy +import xtc.graphs.xtc.op as O + graph = utils.get_graph_matmul() -backend = utils.get_backend(graph) +I, J, K, dtype = 1024, 1024, 1024, "float32" + +a = O.tensor((I, K), dtype) +b = O.tensor((K, J), dtype) + +with O.graph(name="matmul") as gb: + O.matmul(a, b) +graph = gb.graph +backend = utils.get_backend(graph, "tvm") + spec = """ -DDRj: +Memory: j: - parallelize: j_par -DDR: k: 
- i: - explore: True - A: bufferize=pack_A pad - B: bufferize=pack_B pad L3: - j: size=jL3 + B: bufferize + i: L2: - i: size=iL2 + A: bufferize + j#nc: + i#mc: L1: - k: size=kL1 unroll=kU -R: - i: size=iR unroll - j: size=jR vectorize=jV + k#kc: unroll=kr +Register: + i#mr: unroll full + j#nr: vectorize full """ -constraint = ["iR * jR <= 56"] -strategy = Strategy(graph, spec, constraints=constraint) + +nb_registers = 32 +nb_fma = 2 +fma_latency = 4 +ilp = nb_fma*fma_latency +vector_size = 16 +elt_size = 4 +reorder_buffer = 256 +nb_words_L1 = 32*1024//elt_size +nb_words_L2 = 1024*1024//elt_size +nb_words_L3 = 36*1024*1024//elt_size + +constraints = [ +f"1 + nvr + nvr * mr <= {nb_registers}", +f"nr == {vector_size} * nvr", +f"nvr * mr >= {ilp}", +f"nvr * mr * kr <= {reorder_buffer}", +f"kc * nr <= {nb_words_L1}", +f"kc * mc <= {nb_words_L2}", +f"kc * nc <= {nb_words_L3}", +] +strategy = Strategy(graph, spec, constraints=constraints, partial_tiles=False, partial_unrolls=False, initialize=False) print(strategy._constraints) print(len(list(strategy.sample(100)))) -# CHECK: ['1 || kU || kL1 || 12', '1 || jR || jL3 || 32', '1 || iR || iL2 || 21', '0 <= pack_A <= 1', '0 <= pack_B <= 1', '0 <= j_par <= 1', '0 <= jV <= 1', 'iR * jR <= 56', '0 <= order_DDR <= 1'] +# CHECK: ['nc || {1024}', 'mc || {1024}', 'kc || {1024}', 'kr || kc', 'mr || {mc, 1024}', 'nr || {nc, 1024}', '1 + nvr + nvr * mr <= 32', 'nr == 16 * nvr', 'nvr * mr >= 8', 'nvr * mr * kr <= 256', 'kc * nr <= 8192', 'kc * mc <= 262144', 'kc * nc <= 9437184'] #CHECK-NEXT: 100 From 1395c3cdd8df5d3e34f19f37084969e46ebbabab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20Fr=C3=A9not?= Date: Wed, 26 Nov 2025 15:00:32 +0100 Subject: [PATCH 19/23] Sampler update --- src/xtc/schedules/descript_extend.py | 1 - src/xtc/search/strategies.py | 13 ++++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index a2cf0c531..258679a0b 
100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -8,7 +8,6 @@ import re import strictyaml from typing_extensions import override -from copy import deepcopy from xtc.itf.schd.scheduler import Scheduler diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index dfcec345a..7b621b96f 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -9,8 +9,7 @@ import itertools import numpy as np -from properties import constraints_from_str, hypergraph -from properties import variables as sampler_variables +from properties import constraints_from_str, hypergraph, Context from strategy import ( execute_dynamic, execute_static, @@ -1001,14 +1000,18 @@ def _initialize(self): if self._initialized: return max_enum = int(1 + np.log2(max(self._sizes.values()))) - constraints = constraints_from_str(self._constraints, silent=True) + context = Context() + constraints, self.constrants = constraints_from_str( + self._constraints, silent=True, context=context + ) properties, constraints = hypergraph( - constraints, max_enum=max_enum, silent=True + constraints, max_enum=max_enum, silent=True, context=context ) methods = solve_with_z3( - sampler_variables.keys(), properties, constraints, silent=True + context.variables.keys(), properties, constraints, silent=True ) enumerations = execute_static(methods, properties, constraints, silent=True) + self._context = context self._properties = properties self._z3_constraints = constraints self._methods = methods From b17f21932db8f7ea55f396dd0c72fc198d7bd347 Mon Sep 17 00:00:00 2001 From: Leon Frenot Date: Mon, 12 Jan 2026 16:22:31 +0100 Subject: [PATCH 20/23] Fixes after rebasing TODO: update tests --- requirements.txt | 1 - src/xtc/cli/mlir_loop.py | 2 +- src/xtc/schedules/descript.py | 21 +- src/xtc/schedules/descript_extend.py | 415 +++++++++++++-------------- src/xtc/search/strategies.py | 18 +- 5 files changed, 224 insertions(+), 233 deletions(-) diff --git 
a/requirements.txt b/requirements.txt index ed218fc34..6b3d57311 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ ordered-set py-cpuinfo tqdm typing_extensions -xdsl~=0.50.0 pyyaml scikit-learn networkx diff --git a/src/xtc/cli/mlir_loop.py b/src/xtc/cli/mlir_loop.py index 95978b7a0..d15fe362f 100644 --- a/src/xtc/cli/mlir_loop.py +++ b/src/xtc/cli/mlir_loop.py @@ -49,7 +49,7 @@ def main(): node_name, always_vectorize=args.always_vectorize, concluding_passes=args.concluding_passes, - no_alias=args.no_alias, + no_alias=not args.alias, extend=args.extend, ) schedulers.append(sched) diff --git a/src/xtc/schedules/descript.py b/src/xtc/schedules/descript.py index 9433b98a5..7f1873987 100644 --- a/src/xtc/schedules/descript.py +++ b/src/xtc/schedules/descript.py @@ -477,12 +477,12 @@ class LoopNestSlice: """ root: str - tiles: dict[str, dict[str, int]] - splits: dict[str, dict[str, int]] = field(default_factory=dict) + tiles: dict[str, dict[str, int | str]] + splits: dict[str, dict[str, int | str]] = field(default_factory=dict) interchange: list[str] = field(default_factory=list) vectorize: list[str] = field(default_factory=list) parallelize: list[str] = field(default_factory=list) - unroll: dict[str, int] = field(default_factory=dict) + unroll: dict[str, int | str] = field(default_factory=dict) @property def splits_to_sizes(self) -> dict[str, int]: @@ -490,6 +490,7 @@ def splits_to_sizes(self) -> dict[str, int]: for axis in self.splits: last_start = None for loop_name, start in reversed(self.splits[axis].items()): + assert isinstance(start, int) if last_start is not None: size_of_split = last_start - start splits_to_sizes[loop_name] = size_of_split @@ -501,6 +502,7 @@ def tiles_to_sizes(self) -> dict[str, int]: tiles_to_sizes: dict[str, int] = {} for tiles in self.tiles.values(): for loop, size in tiles.items(): + assert isinstance(size, int) tiles_to_sizes[loop] = size return tiles_to_sizes @@ -568,7 +570,9 @@ def _check_tiling_consistency(self) -> 
None: `{axis}#{size}`: {axis} has not been materialized yet. """ ) - seen_axes[axis] = sched.tiles[axis][loop_name] + loop_size = sched.tiles[axis][loop_name] + assert isinstance(loop_size, int) + seen_axes[axis] = loop_size def _check_sizes(self): mapper = LoopsDimsMapper.build_from_slices(self.slices) @@ -589,6 +593,7 @@ def _check_sizes(self): current_size_of_split[loop_name] = None elif loop_name in mapper.tiles_to_axis: loop_size = sched.tiles[axis][loop_name] + assert isinstance(loop_size, int) LoopNest._must_be_smaller_routine( new_size=loop_size, current_sizes=current_sizes, @@ -649,6 +654,14 @@ def descript_scheduler( descript.apply(node_name=node_name, spec=spec) +def correct_type(d: dict[str, int | str]) -> dict[str, int]: + out_d: dict[str, int] = {} + for k, v in d.items(): + assert isinstance(v, int) + out_d[k] = v + return out_d + + @dataclass(frozen=True) class Descript: """Applies a parsed and interpreted schedule to a Scheduler. diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 258679a0b..15ebb4c16 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -4,14 +4,103 @@ # from typing import Any, Tuple from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, field import re import strictyaml from typing_extensions import override from xtc.itf.schd.scheduler import Scheduler -from xtc.schedules.descript import Descript, SchedDict +from xtc.schedules.descript import Descript, LoopNest, LoopNestSlice, correct_type + + +@dataclass +class LoopNestSliceExtend(LoopNestSlice): + axis_orders: list[str] = field(default_factory=list) + axes: dict[str, dict] = field(default_factory=dict) + packs: dict[str, list] = field(default_factory=dict) + buffers: dict[str, list] = field(default_factory=dict) + fusions: dict[str, list] = field(default_factory=dict) + variables: set[str] = field(default_factory=set) + constraints: set[str] = 
field(default_factory=set) + vectorize_bool: set[tuple[str, str]] = field(default_factory=set) + parallelize_bool: set[tuple[str, str]] = field(default_factory=set) + + +@dataclass +class LoopNestExtend(LoopNest): + @override + def build_slice(self, root: str) -> LoopNestSliceExtend: + slice = LoopNestSliceExtend( + root=root, tiles={a: {} for a in self.abstract_dims} + ) + self.slices = [slice] + self.slices + return slice + + def apply_sample(self, sample: dict[str, Any]): + for schedule in self.slices: + for dim, axes in schedule.splits.items(): + for level, size in axes.items(): + if isinstance(size, str): + schedule.splits[dim][level] = sample[size] + for dim, axes in schedule.tiles.items(): + for level, size in axes.items(): + if isinstance(size, str): + schedule.tiles[dim][level] = sample[size] + for axis, size in schedule.unroll.items(): + if isinstance(size, str): + val = sample[size] + if val is None: + for s__ in schedule.tiles.values(): + for level, size in s__.items(): + if axis == level: + val = size + break + if val is not None: + break + schedule.unroll[axis] = val + if isinstance(schedule, LoopNestSliceExtend): + for axis, loop in schedule.vectorize_bool: + axis = sample.get(axis, False) + if axis is None or axis: + schedule.vectorize.append(loop) + for axis, loop in schedule.parallelize_bool: + axis = sample.get(axis, False) + if axis is None or axis: + schedule.parallelize.append(loop) + for dim, packs in schedule.packs.items(): + for i, (flag, input, pad) in enumerate(packs): + sample_flag = False + if isinstance(flag, str): + flag = sample.get(flag, False) + sample_flag = True + if not flag: + schedule.packs[dim].pop(i) + continue + if isinstance(input, str): + input = sample.get(input, input) + sample_flag = True + if sample_flag: + schedule.packs[dim][i] = (flag, input, pad) + for dim, buffs in schedule.buffers.items(): + for i, (flag, pad) in enumerate(buffs): + sample_flag = False + if isinstance(flag, str): + flag = sample.get(flag, False) 
+ sample_flag = True + if not flag: + schedule.buffers[dim].pop(i) + continue + if sample_flag: + schedule.buffers[dim][i] = (flag, pad) + for dim, axes in schedule.axes.items(): + d_holder = f"order_{dim}" + s = sample.get(d_holder, None) + if s: + sch = {} + for a in s: + sch[a] = axes[a] + schedule.axes[dim] = sch def descript_extend_scheduler( @@ -57,9 +146,10 @@ def apply( flat_schedules = self._flatten_schedule(root=node_name, spec=dict_spec, head=[]) variables = set() constraints = set() - for schedule in flat_schedules: - variables.update(schedule["variables"]) - constraints.update(schedule["constraints"]) + for schedule in flat_schedules.slices: + if isinstance(schedule, LoopNestSliceExtend): + variables.update(schedule.variables) + constraints.update(schedule.constraints) flat_schedules = self.apply_sample(flat_schedules, sample) self.apply_scheduler(flat_schedules, scheduler) @@ -132,170 +222,61 @@ def flatten_schedule(self, node_name: str, spec: dict[str, dict] | str): constraints = [] axes = {} orders = {} - for schedule in flat_schedules: - variables += schedule["variables"] - constraints += schedule["constraints"] - for axis, order in schedule["axes"].items(): - axes[f"order_{axis}"] = order - axis_orders = schedule["axis_orders"] - for axis in axis_orders: - orders[axis] = schedule["axes"][axis] - - # for axis in self.abstract_axis: - # all_axis_constraints = [] - # for schedule in flat_schedules: - # for sched in schedule["sizes"][axis]: - # if len(sched) > 1: - # all_axis_constraints.append(sched) - # axis_constraints = [] - # i = 0 - # while i < len(all_axis_constraints): - # sched = all_axis_constraints[i] - # if isinstance(sched[0], int): - # axis_constraints.append(sched) - # all_axis_constraints.pop(i) - # else: - # i += 1 - # flag_flag = True - # while len(all_axis_constraints) > 0 and flag_flag: - # i = 0 - # axis_constraints_acc = [] - # flag_flag = False - # while i < len(all_axis_constraints): - # sched = all_axis_constraints[i] - # flag 
= False - # for constraint in axis_constraints: - # if sched[0] == constraint[-1]: - # axis_constraints_acc.append(constraint + sched[1:]) - # flag = True - # if flag: - # all_axis_constraints.pop(i) - # flag_flag = True - # else: - # i += 1 - # if flag_flag: - # axis_constraints = axis_constraints_acc - # - # axis_constraints += all_axis_constraints - # axis_constraints.reverse() - # for constraint in axis_constraints: - # if constraint[0] == 1: - # for size in constraint[1:]: - # if isinstance(size, str): - # constraints.append(f"{size} in {{1}}") - # else: - # constraint.reverse() - # constraint_str = "" - # var_flag = False - # if isinstance(constraint[0], str): - # constraint_str = "1 || " - # for size in constraint[:-1]: - # var_flag = var_flag or isinstance(size, str) - # constraint_str += f"{size} || " - # constraint_str += str(constraint[-1]) - # if var_flag: - # constraints.insert(0, constraint_str) + for schedule in flat_schedules.slices: + if isinstance(schedule, LoopNestSliceExtend): + variables += schedule.variables + constraints += schedule.constraints + for axis, order in schedule.axes.items(): + axes[f"order_{axis}"] = order + axis_orders = schedule.axis_orders + for axis in axis_orders: + orders[axis] = schedule.axes[axis] variables = list(dict.fromkeys(variables)) constraints = list(dict.fromkeys(constraints)) return (flat_schedules, variables, constraints, axes, orders) def apply_sample( - self, flat_schedules: list[SchedDict], sample: dict[str, Any] - ) -> list[SchedDict]: + self, flat_schedules: LoopNestExtend, sample: dict[str, Any] + ) -> LoopNestExtend: flat_schedules = deepcopy(flat_schedules) - for schedule in flat_schedules: - for k in ["splits", "tiles"]: - for dim, axes in schedule[k].items(): - for level, size in axes.items(): - if isinstance(size, str): - schedule[k][dim][level] = sample[size] - for k in ["vectorize", "parallelize"]: - for i, axes in enumerate(schedule[k]): - if isinstance(axes, Tuple): - axes, loop = axes - axes = 
sample.get(axes, False) - if axes is None or axes: - schedule[k][i] = loop - else: - schedule[k].pop(i) - for axis, size in schedule["unroll"].items(): - if isinstance(size, str): - val = sample[size] - if val is None: - for s__ in schedule["tiles"].values(): - for level, size in s__.items(): - if axis == level: - val = size - break - if val is not None: - break - schedule["unroll"][axis] = val - for dim, packs in schedule["packs"].items(): - for i, (flag, input, pad) in enumerate(packs): - sample_flag = False - if isinstance(flag, str): - flag = sample.get(flag, False) - sample_flag = True - if not flag: - schedule["packs"][dim].pop(i) - continue - if isinstance(input, str): - input = sample.get(input, input) - sample_flag = True - if sample_flag: - schedule["packs"][dim][i] = (flag, input, pad) - for dim, buffs in schedule["buffers"].items(): - for i, (flag, pad) in enumerate(buffs): - sample_flag = False - if isinstance(flag, str): - flag = sample.get(flag, False) - sample_flag = True - if not flag: - schedule["buffers"][dim].pop(i) - continue - if sample_flag: - schedule["buffers"][dim][i] = (flag, pad) - for dim, axes in schedule["axes"].items(): - d_holder = f"order_{dim}" - s = sample.get(d_holder, None) - if s: - sch = {} - for a in s: - sch[a] = axes[a] - schedule["axes"][dim] = sch + flat_schedules.apply_sample(sample) return flat_schedules - def apply_scheduler(self, flat_schedules: list[SchedDict], scheduler: Scheduler): - self._check_flattened_schedule(flat_schedules) - for schedule in flat_schedules: - root = schedule["root"] + def apply_scheduler(self, flat_schedules: LoopNestExtend, scheduler: Scheduler): + flat_schedules.check() + for schedule in flat_schedules.slices: + assert isinstance(schedule, LoopNestSliceExtend) + root = schedule.root interchange = [] - for d, s in schedule["axes"].items(): + for d, s in schedule.axes.items(): s = list(s.values()) for s in s: interchange += s - p = schedule["packs"].get(d, None) + p = schedule.packs.get(d, 
None) if p: for _, input, pad in p: scheduler.pack_at(s[-1], input, pad=pad) - b = schedule["buffers"].get(d, None) + b = schedule.buffers.get(d, None) if b: scheduler.buffer_at(s[-1]) - for d, s in schedule["splits"].items(): + for d, s in schedule.splits.items(): + s = correct_type(s) scheduler.split(d, s, root=root) - for d, s in schedule["tiles"].items(): + for d, s in schedule.tiles.items(): + s = correct_type(s) scheduler.tile(d, s, root=root) scheduler.interchange(interchange, root=root) - scheduler.vectorize(schedule["vectorize"], root=root) - scheduler.parallelize(schedule["parallelize"], root=root) - scheduler.unroll(schedule["unroll"], root=root) + scheduler.vectorize(schedule.vectorize, root=root) + scheduler.parallelize(schedule.parallelize, root=root) + s = correct_type(schedule.unroll) + scheduler.unroll(s, root=root) @override def _flatten_schedule( @@ -305,24 +286,25 @@ def _flatten_schedule( head: list[str], tile_sizes: dict[str, int | str] | None = None, sched_sizes: dict[str, list] | None = None, - ) -> list[SchedDict]: - recursive_scheds: list[SchedDict] = [] - sched: SchedDict = { - "root": root, - "fusions": {}, - "packs": {}, - "buffers": {}, - "axis_orders": [], - "axes": {}, - "splits": {}, - "tiles": {a: {} for a in self.abstract_axis}, - "interchange": [], - "vectorize": [], - "parallelize": [], - "unroll": {}, - "variables": [], - "constraints": [], - } + ) -> LoopNestExtend: + recursive_scheds = LoopNestExtend(abstract_dims=self.abstract_axis) + sched = recursive_scheds.build_slice(root) + # sched: SchedDict = { + # "root": root, + # "fusions": {}, + # "packs": {}, + # "buffers": {}, + # "axis_orders": [], + # "axes": {}, + # "splits": {}, + # "tiles": {a: {} for a in self.abstract_axis}, + # "interchange": [], + # "vectorize": [], + # "parallelize": [], + # "unroll": {}, + # "variables": [], + # "constraints": [], + # } # State of the schedule if tile_sizes: axes_sizes: dict[str, int | str] = tile_sizes @@ -335,8 +317,8 @@ def 
_flatten_schedule( sizes: dict[str, int | str | None] = {} previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} interchange: list[str] = head - constraints: list[str] = [] - variables: list[str] = [] + # constraints: list[str] = [] + # variables: list[str] = [] # Processing the schedule for tree_declaration, tree_val in spec.items(): assert isinstance(tree_val, dict) @@ -356,13 +338,13 @@ def _flatten_schedule( param, input, pad = val_ tree_packs.append((param, input, pad)) if isinstance(param, str): - variables.append(param) - constraints.append(f"{param} in {{0, 1}}") + sched.variables.add(param) + sched.constraints.add(f"{param} in {{0, 1}}") if isinstance(input, str): input = self.abstract_matrix.index(input) if isinstance(pad, str): - variables.append(pad) - constraints.append(f"{pad} in {{0, 1}}") + sched.variables.add(pad) + sched.constraints.add(f"{pad} in {{0, 1}}") continue elif declaration in "buffer": for val_ in val: @@ -373,14 +355,14 @@ def _flatten_schedule( param, pad = val_ tree_buff.append((param, pad)) if isinstance(param, str): - variables.append(param) - constraints.append(f"{param} in {{0, 1}}") + sched.variables.add(param) + sched.constraints.add(f"{param} in {{0, 1}}") if isinstance(pad, str): - variables.append(pad) - constraints.append(f"{pad} in {{0, 1}}") + sched.variables.add(pad) + sched.constraints.add(f"{pad} in {{0, 1}}") continue elif declaration == "explore_axis_order": - sched["axis_orders"].append(tree_declaration) + sched.axis_orders.append(tree_declaration) continue elif declaration in self.abstract_matrix: matrix_index = self.abstract_matrix.index(declaration) @@ -392,11 +374,11 @@ def _flatten_schedule( else: tree_packs.append((param, matrix_index, pad)) if isinstance(param, str): - variables.append(param) - constraints.append(f"{param} in {{0, 1}}") + sched.variables.add(param) + sched.constraints.add(f"{param} in {{0, 1}}") if isinstance(pad, str): - variables.append(pad) - 
constraints.append(f"{pad} in {{0, 1}}") + sched.variables.add(pad) + sched.constraints.add(f"{pad} in {{0, 1}}") continue elif ":" in declaration: axis_name, x, y, z = self.parse_split_declaration(declaration) @@ -409,9 +391,9 @@ def _flatten_schedule( current_size = axes_sizes[axis_name] # Update the previous cut # Save the cutting points of the new dimensions - if axis_name not in sched["splits"]: - sched["splits"][axis_name] = {} - new_dim_index = len(sched["splits"][axis_name]) + if axis_name not in sched.splits: + sched.splits[axis_name] = {} + new_dim_index = len(sched.splits[axis_name]) new_dim_name = f"{axis_name}[{new_dim_index}]" new_axes_root_name = f"{root}/{new_dim_name}" if axis_name in tree_interchange: @@ -421,11 +403,12 @@ def _flatten_schedule( if z is None: previous_cut[axis_name] = y - sched["splits"][axis_name][new_dim_name] = x # When x (the starting point of the slice), is not # specified, it is the previous cut if x is None: x = cut + assert isinstance(x, int | str) + sched.splits[axis_name][new_dim_name] = x # assert isinstance(x, int) inner_size = self._extended_check_splitting_intervals( @@ -447,10 +430,10 @@ def _flatten_schedule( .replace("[", "_") .replace("]", "_") ) - constraints.append(f"{inner_size} <= {y}") + sched.constraints.add(f"{inner_size} <= {y}") if isinstance(x, str): - constraints.append(f"{x} <= {y}") - constraints.append(f"{inner_size} + {x} == {y}") + sched.constraints.add(f"{x} <= {y}") + sched.constraints.add(f"{inner_size} + {x} == {y}") else: inner_size = z x = cut @@ -458,11 +441,11 @@ def _flatten_schedule( if isinstance(z, int) and isinstance(x, int): previous_cut[axis_name] = x + z if not isinstance(y, int): - constraints.append(f"{z + x} <= {y}") + sched.constraints.add(f"{z + x} <= {y}") elif isinstance(x, int) and x == 0: previous_cut[axis_name] = z if not isinstance(y, int): - constraints.append(f"{z} <= {y}") + sched.constraints.add(f"{z} <= {y}") else: new_cut = root[1:] + new_dim_name new_cut = ( @@ 
-473,9 +456,9 @@ def _flatten_schedule( previous_cut[axis_name] = new_cut if last_split is not None: a, b = last_split - constraints.append(f"{a} <= {b}") + sched.constraints.add(f"{a} <= {b}") last_split = (new_cut, y) - constraints.append(f"{z} + {x} == {new_cut}") + sched.constraints.add(f"{z} + {x} == {new_cut}") axes_sizes[axis_name] = inner_size @@ -491,8 +474,9 @@ def _flatten_schedule( ) axes_sizes[axis_name] = current_size - recursive_scheds += inner_scheds + recursive_scheds.slices += inner_scheds.slices continue + elif "#" in declaration: axis_name, tile_size = declaration.split("#") self._check_axis_existence(axis_name) @@ -501,7 +485,7 @@ def _flatten_schedule( loop_size = int(tile_size) else: loop_size = tile_size - variables.append(tile_size) + sched.variables.add(tile_size) if not loop_size: raise Exception( f"Invalid tile size: '{tile_size}' in {declaration}" @@ -515,7 +499,7 @@ def _flatten_schedule( f"Tile {declaration} cannot be partial and full" ) if partial or (not full and self.partial_tiles): - constraints.append( + sched.constraints.add( f"{loop_size} <= {axes_sizes[axis_name]}" ) else: @@ -525,12 +509,12 @@ def _flatten_schedule( else sched_sizes[axis_name][0] ) s = f"{loop_size} || {{{s}}}" - constraints.append(s) + sched.constraints.add(s) sched_sizes[axis_name].insert(0, str(loop_size)) axes_sizes[axis_name] = loop_size - tile_num = len(sched["tiles"][axis_name]) + tile_num = len(sched.tiles[axis_name]) loop_name = f"{axis_name}{tile_num}" - sched["tiles"][axis_name][loop_name] = loop_size + sched.tiles[axis_name][loop_name] = loop_size sizes[loop_name] = loop_size if axis_name in tree_interchange: raise Exception( @@ -560,15 +544,14 @@ def _flatten_schedule( sizes=sizes, annotations=val, sched=sched, - constraints=constraints, ) - sched["axes"][tree_declaration] = tree_interchange + sched.axes[tree_declaration] = tree_interchange if len(tree_packs) > 0: - sched["packs"][tree_declaration] = tree_packs + sched.packs[tree_declaration] = 
tree_packs if len(tree_fusion) > 0: - sched["fusions"][tree_declaration] = tree_fusion + sched.fusions[tree_declaration] = tree_fusion if len(tree_buff) > 0: - sched["buffers"][tree_declaration] = tree_buff + sched.buffers[tree_declaration] = tree_buff for v in tree_interchange.values(): interchange += v @@ -577,9 +560,9 @@ def _flatten_schedule( if isinstance(a, int) and not isinstance(b, int): a, b = b, a a, b = str(a), str(b) - for i in range(len(constraints)): - c = constraints[i] - constraints[i] = c.replace(a, b) + for c in sched.constraints: + sched.constraints.remove(c) + sched.constraints.add(c.replace(a, b)) last_split = None # Check if the last cut of each axis is either 0 or None. @@ -592,10 +575,8 @@ def _flatten_schedule( f"Splitting on axis {axis} should end but stops at {cut}" ) - sched["interchange"] = interchange - sched["variables"] = variables + sched["variables"] - sched["constraints"] = constraints + sched["constraints"] - return [sched] + recursive_scheds + sched.interchange = interchange + return recursive_scheds def _extended_check_splitting_intervals( self, @@ -662,8 +643,7 @@ def annotate( loop_name: str, sizes: dict[str, int | str | None], annotations: dict[str, Any], - sched: dict[str, Any], - constraints: list[str], + sched: LoopNestSliceExtend, ): for instr, param in annotations.items(): assert isinstance(instr, str) @@ -674,19 +654,20 @@ def annotate( else: ufactor = param if isinstance(param, str): - sched["variables"].append(param) + sched.variables.add(param) leq = "<=" if self.partial_unrolls else "||" - constraints.append(f"{ufactor} {leq} {sizes[loop_name]}") - sched["unroll"][loop_name] = ufactor + sched.constraints.add(f"{ufactor} {leq} {sizes[loop_name]}") + assert isinstance(ufactor, int | str) + sched.unroll[loop_name] = ufactor case "vectorize": if isinstance(param, str): - sched["variables"].append(param) - sched["constraints"].append(f"{param} in {{0, 1}}") - sched["vectorize"].append((param, loop_name)) + 
sched.variables.add(param) + sched.constraints.add(f"{param} in {{0, 1}}") + sched.vectorize_bool.add((param, loop_name)) continue if param is None: - sched["vectorize"].append(loop_name) + sched.vectorize.append(loop_name) continue raise Exception( "Vectorize should not have a parameter (Feature not implemented)" @@ -694,12 +675,12 @@ def annotate( case "parallelize": if isinstance(param, str): - sched["variables"].append(param) - sched["constraints"].append(f"{param} in {{0, 1}}") - sched["parallelize"].append((param, loop_name)) + sched.variables.add(param) + sched.constraints.add(f"{param} in {{0, 1}}") + sched.parallelize_bool.add((param, loop_name)) continue if param is None: - sched["parallelize"].append(loop_name) + sched.parallelize.append(loop_name) continue if param is not None: raise Exception( diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index 7b621b96f..bddb1f53a 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -9,8 +9,8 @@ import itertools import numpy as np -from properties import constraints_from_str, hypergraph, Context -from strategy import ( +from xvs.properties import constraints_from_str, hypergraph, Context +from xvs.strategy import ( execute_dynamic, execute_static, solve_with_z3, @@ -993,6 +993,7 @@ def __init__( self._orders[a_holder] = permutation order_constraints.append(f"{a_holder} in {set(range(len(permutation)))}") self._constraints = constraints + input_constraints + order_constraints + self._constraints.sort() if initialize: self._initialize() @@ -1002,15 +1003,13 @@ def _initialize(self): max_enum = int(1 + np.log2(max(self._sizes.values()))) context = Context() constraints, self.constrants = constraints_from_str( - self._constraints, silent=True, context=context + self._constraints, context=context ) properties, constraints = hypergraph( - constraints, max_enum=max_enum, silent=True, context=context + constraints, max_enum=max_enum, context=context ) - methods = 
solve_with_z3( - context.variables.keys(), properties, constraints, silent=True - ) - enumerations = execute_static(methods, properties, constraints, silent=True) + methods = solve_with_z3(list(context.variables.keys()), properties, constraints) + enumerations = execute_static(methods, properties, constraints) self._context = context self._properties = properties self._z3_constraints = constraints @@ -1050,14 +1049,13 @@ def sample_once(self, num: int) -> Iterator[Sample]: self._z3_constraints, self._enumerations, k=num, - silent=True, ) return draw def pretty_print_methods(self, tab: str = "\t"): self._initialize() pretty_print_methods( - self._methods, self._properties, self._constraints, tab=tab + self._methods, self._properties, self._z3_constraints, tab=tab ) def _sample_once_tuple(self, num: int) -> Iterator[tuple]: From 14116cd7007dd0ef71f81d5c164cfe1f07bb28ea Mon Sep 17 00:00:00 2001 From: Leon Frenot Date: Mon, 12 Jan 2026 17:14:49 +0100 Subject: [PATCH 21/23] Test fixes after rebasing --- .../splitting/v_splitting_extend.mlir | 10 +- ...test_matmul_descript_extend_mlir_sample.py | 76 +++--- .../test_matmul_descript_extend_mlir_split.py | 192 +++++++------- .../test_matmul_descript_extend_tvm_goto.py | 237 +++++++++++------- ...est_matmul_descript_extend_tvm_strategy.py | 10 +- .../search/test_matmul_descript_3axes.py | 2 +- .../search/test_matmul_descript_goto.py | 2 +- .../search/test_matmul_descript_simple.py | 2 +- .../search/test_matmul_descript_split.py | 2 +- .../test_matmul_descript_split_in_split.py | 2 +- .../search/test_matmul_descript_yaml_goto.py | 2 +- .../test_matmul_descript_yaml_simple.py | 2 +- .../search/test_matmul_descript_yaml_split.py | 2 +- 13 files changed, 281 insertions(+), 260 deletions(-) diff --git a/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir b/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir index 3a644a0b3..cb29c2ca2 100644 --- 
a/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir +++ b/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir @@ -26,14 +26,12 @@ func.func @matmul(%A: memref<256x512xf64>, %B: memref<512x256xf64>, %C: memref<2 // CHECK-NEXT: } // CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { // CHECK-NEXT: %0 = transform.structured.match attributes {__node0__} in %arg0 : (!transform.any_op) -> !transform.any_op -// CHECK-NEXT: %first, %second = transform.structured.split %0 after 5 {dimension = 0 : i64} : !transform.any_op -// CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %first tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: %1 = transform.structured.split %0 after 5 {dimension = 0 : i64} : !transform.any_op +// CHECK-NEXT: %2:2 = transform.split_handle %1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %2#0 tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // CHECK-NEXT: transform.annotate %loops "__node0__/i[0]/j" : !transform.any_op -// CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op -// CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %second tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %2#1 tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // CHECK-NEXT: transform.annotate %loops_1 "__node0__/i[1]/j" : !transform.any_op -// CHECK-NEXT: %2 = transform.get_parent_op %loops_1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op -// CHECK-NEXT: %3 = transform.get_parent_op %loops {isolated_from_above} : 
(!transform.any_op) -> !transform.any_op // CHECK-NEXT: transform.yield // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py index 2dc8a5a8a..3246d7022 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py @@ -50,7 +50,7 @@ res = executor.execute() print(f"CODE: {res}") -#CHECK: // -----// IR Dump Before transform //----- // +#CHECK:// -----// IR Dump Before transform //----- // #CHECK-NEXT: module attributes {transform.with_named_sequence} { #CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { #CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 @@ -68,9 +68,8 @@ #CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op #CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) #CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op -#CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: %1 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %1 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) #CHECK-NEXT: transform.annotate %loops_3 "C/k" : !transform.any_op 
#CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) #CHECK-NEXT: transform.annotate %loops_5 "C/i" : !transform.any_op @@ -78,26 +77,26 @@ #CHECK-NEXT: transform.annotate %loops_7 "C/j" : !transform.any_op #CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) #CHECK-NEXT: transform.annotate %loops_9 "C/i0" : !transform.any_op -#CHECK-NEXT: %3 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op #CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_8) : (!transform.any_op) -> () -#CHECK-NEXT: transform.apply_patterns to %3 { +#CHECK-NEXT: transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op +#CHECK-NEXT: %2 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %2 { #CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract #CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns #CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %3 { +#CHECK-NEXT: transform.apply_patterns to %2 { #CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct #CHECK-NEXT: transform.apply_patterns.vector.lower_contraction #CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: %4 = transform.structured.match attributes {"C/i0"} in %3 : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op #CHECK-NEXT: transform.yield #CHECK-NEXT: } #CHECK-NEXT: } -#CHECK-NEXT: +#CHECK-EMPTY: #CHECK-NEXT: // -----// IR Dump After transform //----- // #CHECK-NEXT: module attributes {transform.with_named_sequence} { #CHECK-NEXT: func.func @matmul(%arg0: 
memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { #CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x16xf32> +#CHECK-NEXT: %0 = ub.poison : f32 #CHECK-NEXT: %c16 = arith.constant 16 : index #CHECK-NEXT: %c2 = arith.constant 2 : index #CHECK-NEXT: %c512 = arith.constant 512 : index @@ -123,41 +122,37 @@ #CHECK-NEXT: scf.for %arg5 = %c0 to %c32 step %c16 { #CHECK-NEXT: %subview_5 = memref.subview %subview_1[0, %arg5] [1, 16] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> #CHECK-NEXT: %subview_6 = memref.subview %subview_4[0, %arg5] [2, 16] [1, 1] : memref<2x32xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %c2_7 = arith.constant 2 : index -#CHECK-NEXT: %subview_8 = memref.subview %subview_3[%c0, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_9 = memref.subview %subview_6[%c0, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %0 = vector.transfer_read %subview_8[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %1 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -#CHECK-NEXT: %2 = vector.transfer_read %subview_9[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -#CHECK-NEXT: %3 = vector.extract %1[0] : vector<16xf32> from vector<1x16xf32> -#CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<16xf32> -#CHECK-NEXT: %6 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> -#CHECK-NEXT: %7 = vector.fma %5, %3, %6 : 
vector<16xf32> -#CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<16xf32> into vector<1x16xf32> -#CHECK-NEXT: vector.transfer_write %8, %subview_9[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %c1_10 = arith.constant 1 : index -#CHECK-NEXT: %9 = arith.muli %c1, %c1_10 : index -#CHECK-NEXT: %10 = arith.addi %c0, %9 : index -#CHECK-NEXT: %subview_11 = memref.subview %subview_3[%10, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_12 = memref.subview %subview_6[%10, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %11 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %12 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -#CHECK-NEXT: %13 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> -#CHECK-NEXT: %14 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> -#CHECK-NEXT: %15 = vector.extract %11[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %16 = vector.broadcast %15 : f32 to vector<16xf32> -#CHECK-NEXT: %17 = vector.extract %13[0] : vector<16xf32> from vector<1x16xf32> -#CHECK-NEXT: %18 = vector.fma %16, %14, %17 : vector<16xf32> -#CHECK-NEXT: %19 = vector.insert %18, %cst [0] : vector<16xf32> into vector<1x16xf32> -#CHECK-NEXT: vector.transfer_write %19, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_7 = memref.subview %subview_3[%c0, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], 
offset: ?>> +#CHECK-NEXT: %subview_8 = memref.subview %subview_6[%c0, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %1 = vector.transfer_read %subview_7[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %2 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %3 = vector.transfer_read %subview_8[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %4 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<16xf32> +#CHECK-NEXT: %7 = vector.extract %3[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<16xf32> +#CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<16xf32> into vector<1x16xf32> +#CHECK-NEXT: vector.transfer_write %9, %subview_8[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_9 = memref.subview %subview_3[%c1, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_10 = memref.subview %subview_6[%c1, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = vector.transfer_read %subview_9[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %11 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_10[%c0, %c0], %0 {in_bounds = [true, 
true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %13 = vector.extract %11[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<16xf32> +#CHECK-NEXT: %16 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<16xf32> +#CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<16xf32> into vector<1x16xf32> +#CHECK-NEXT: vector.transfer_write %18, %subview_10[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> #CHECK-NEXT: } {"C/j"} #CHECK-NEXT: } {"C/i"} #CHECK-NEXT: } {"C/k"} #CHECK-NEXT: return #CHECK-NEXT: } #CHECK-NEXT: } -#CHECK-NEXT: +#CHECK-EMPTY: #CHECK-NEXT: graph: #CHECK-NEXT: name: matmul #CHECK-NEXT: inputs: @@ -167,5 +162,6 @@ #CHECK-NEXT: - %2 : 4x32xfloat32 #CHECK-NEXT: nodes: #CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] -#CHECK-NEXT: +#CHECK-EMPTY: #CHECK-NEXT: CODE: 0 + diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py index 1bd14f14e..b1d41aba2 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py @@ -78,63 +78,44 @@ #CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op #CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) #CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op -#CHECK-NEXT: %1 = transform.get_parent_op %loops {isolated_from_above} : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: %2 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> 
!transform.any_op -#CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: %1 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %1 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) #CHECK-NEXT: transform.annotate %loops_3 "C/j" : !transform.any_op #CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) #CHECK-NEXT: transform.annotate %loops_5 "C/k" : !transform.any_op #CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [0, 4, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) #CHECK-NEXT: transform.annotate %loops_7 "C/j0" : !transform.any_op -#CHECK-NEXT: %first, %second = transform.structured.split %tiled_linalg_op_6 after 4 {dimension = 0 : i64} : !transform.any_op -#CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %first tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: %2 = transform.structured.split %tiled_linalg_op_6 after 4 {dimension = 0 : i64} : !transform.any_op +#CHECK-NEXT: %3:2 = transform.split_handle %2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %3#0 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) #CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/i0" : !transform.any_op -#CHECK-NEXT: %3 = transform.get_parent_op %loops_9 {isolated_from_above} : (!transform.any_op) -> !transform.any_op #CHECK-NEXT: transform.include @_vecto failures(suppress) 
(%tiled_linalg_op_8) : (!transform.any_op) -> () -#CHECK-NEXT: transform.apply_patterns to %3 { -#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract -#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns -#CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %3 { -#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct -#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction -#CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: %4 = transform.structured.match attributes {"C/i[0]/i0"} in %3 : (!transform.any_op) -> !transform.any_op #CHECK-NEXT: transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op -#CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %second tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %3#1 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) #CHECK-NEXT: transform.annotate %loops_11 "C/i[1]/i0" : !transform.any_op -#CHECK-NEXT: %5 = transform.get_parent_op %loops_11 {isolated_from_above} : (!transform.any_op) -> !transform.any_op #CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_10) : (!transform.any_op) -> () -#CHECK-NEXT: transform.apply_patterns to %5 { -#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract -#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns -#CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %5 { -#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct -#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction -#CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: %6 = transform.structured.match attributes {"C/i[1]/i0"} in %5 : (!transform.any_op) -> !transform.any_op #CHECK-NEXT: transform.loop.unroll %loops_11 {factor = 4 : i64} : !transform.any_op -#CHECK-NEXT: %7 
= transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %7 { +#CHECK-NEXT: %4 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %4 { #CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract #CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns #CHECK-NEXT: } : !transform.any_op -#CHECK-NEXT: transform.apply_patterns to %7 { +#CHECK-NEXT: transform.apply_patterns to %4 { #CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct #CHECK-NEXT: transform.apply_patterns.vector.lower_contraction #CHECK-NEXT: } : !transform.any_op #CHECK-NEXT: transform.yield #CHECK-NEXT: } #CHECK-NEXT: } -#CHECK-NEXT: +#CHECK-EMPTY: #CHECK-NEXT: // -----// IR Dump After transform //----- // #CHECK-NEXT: module attributes {transform.with_named_sequence} { #CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x4xf32> #CHECK-NEXT: %c3 = arith.constant 3 : index #CHECK-NEXT: %c12 = arith.constant 12 : index +#CHECK-NEXT: %0 = ub.poison : f32 #CHECK-NEXT: %c2 = arith.constant 2 : index -#CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x4xf32> #CHECK-NEXT: %c4 = arith.constant 4 : index #CHECK-NEXT: %c512 = arith.constant 512 : index #CHECK-NEXT: %c32 = arith.constant 32 : index @@ -164,84 +145,84 @@ #CHECK-NEXT: scf.for %arg6 = %c0 to %c4 step %c2 { #CHECK-NEXT: %subview_11 = memref.subview %subview_7[%arg6, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> #CHECK-NEXT: %subview_12 = memref.subview %subview_8[%arg6, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> 
-#CHECK-NEXT: %0 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %1 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %2 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %3 = vector.extract %1[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<4xf32> -#CHECK-NEXT: %6 = vector.extract %2[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<4xf32> -#CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %8, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %9 = arith.addi %arg6, %c1 : index -#CHECK-NEXT: %subview_13 = memref.subview %subview_7[%9, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_14 = memref.subview %subview_8[%9, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %10 = vector.transfer_read %subview_13[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %11 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %12 = vector.transfer_read %subview_14[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %13 = vector.extract %11[0] : vector<4xf32> 
from vector<1x4xf32> -#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<4xf32> -#CHECK-NEXT: %16 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<4xf32> -#CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %18, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %1 = vector.transfer_read %subview_11[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %2 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %3 = vector.transfer_read %subview_12[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %4 = vector.extract %2[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<4xf32> +#CHECK-NEXT: %7 = vector.extract %3[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<4xf32> +#CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %9, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = arith.addi %arg6, %c1 : index +#CHECK-NEXT: %subview_13 = memref.subview %subview_7[%10, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_14 = memref.subview %subview_8[%10, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %11 = 
vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %13 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %14 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %15 = vector.extract %11[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %16 = vector.broadcast %15 : f32 to vector<4xf32> +#CHECK-NEXT: %17 = vector.extract %13[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %18 = vector.fma %16, %14, %17 : vector<4xf32> +#CHECK-NEXT: %19 = vector.insert %18, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %19, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> #CHECK-NEXT: } {"C/i[0]/i0"} #CHECK-NEXT: %subview_9 = memref.subview %subview_3[4, 0] [12, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<12x1xf32, strided<[512, 1], offset: ?>> #CHECK-NEXT: %subview_10 = memref.subview %subview_6[4, 0] [12, 4] [1, 1] : memref<16x4xf32, strided<[32, 1], offset: ?>> to memref<12x4xf32, strided<[32, 1], offset: ?>> #CHECK-NEXT: scf.for %arg6 = %c0 to %c12 step %c4 { #CHECK-NEXT: %subview_11 = memref.subview %subview_9[%arg6, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> #CHECK-NEXT: %subview_12 = memref.subview %subview_10[%arg6, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %0 = vector.transfer_read %subview_11[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %1 
= vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %2 = vector.transfer_read %subview_12[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %3 = vector.extract %1[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %4 = vector.extract %0[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %5 = vector.broadcast %4 : f32 to vector<4xf32> -#CHECK-NEXT: %6 = vector.extract %2[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %7 = vector.fma %5, %3, %6 : vector<4xf32> -#CHECK-NEXT: %8 = vector.insert %7, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %8, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %9 = arith.addi %arg6, %c1 : index -#CHECK-NEXT: %subview_13 = memref.subview %subview_9[%9, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_14 = memref.subview %subview_10[%9, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %10 = vector.transfer_read %subview_13[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %11 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %12 = vector.transfer_read %subview_14[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %13 = vector.extract %11[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<4xf32> -#CHECK-NEXT: %16 = 
vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<4xf32> -#CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %18, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %19 = arith.addi %arg6, %c2 : index -#CHECK-NEXT: %subview_15 = memref.subview %subview_9[%19, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_16 = memref.subview %subview_10[%19, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %20 = vector.transfer_read %subview_15[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %21 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %22 = vector.transfer_read %subview_16[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %23 = vector.extract %21[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %24 = vector.extract %20[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %25 = vector.broadcast %24 : f32 to vector<4xf32> -#CHECK-NEXT: %26 = vector.extract %22[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %27 = vector.fma %25, %23, %26 : vector<4xf32> -#CHECK-NEXT: %28 = vector.insert %27, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %28, %subview_16[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %29 = arith.addi %arg6, %c3 : index -#CHECK-NEXT: %subview_17 = memref.subview %subview_9[%29, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], 
offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_18 = memref.subview %subview_10[%29, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %30 = vector.transfer_read %subview_17[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %31 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %32 = vector.transfer_read %subview_18[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %33 = vector.extract %31[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %34 = vector.extract %30[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %35 = vector.broadcast %34 : f32 to vector<4xf32> -#CHECK-NEXT: %36 = vector.extract %32[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %37 = vector.fma %35, %33, %36 : vector<4xf32> -#CHECK-NEXT: %38 = vector.insert %37, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %38, %subview_18[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %1 = vector.transfer_read %subview_11[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %2 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %3 = vector.transfer_read %subview_12[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %4 = vector.extract %2[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<4xf32> 
+#CHECK-NEXT: %7 = vector.extract %3[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<4xf32> +#CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %9, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = arith.addi %arg6, %c1 : index +#CHECK-NEXT: %subview_13 = memref.subview %subview_9[%10, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_14 = memref.subview %subview_10[%10, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %11 = vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %13 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %14 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %15 = vector.extract %11[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %16 = vector.broadcast %15 : f32 to vector<4xf32> +#CHECK-NEXT: %17 = vector.extract %13[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %18 = vector.fma %16, %14, %17 : vector<4xf32> +#CHECK-NEXT: %19 = vector.insert %18, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %19, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %20 = arith.addi %arg6, %c2 : index +#CHECK-NEXT: %subview_15 = memref.subview %subview_9[%20, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], 
offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_16 = memref.subview %subview_10[%20, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %21 = vector.transfer_read %subview_15[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %22 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %23 = vector.transfer_read %subview_16[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %24 = vector.extract %22[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %25 = vector.extract %21[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %26 = vector.broadcast %25 : f32 to vector<4xf32> +#CHECK-NEXT: %27 = vector.extract %23[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %28 = vector.fma %26, %24, %27 : vector<4xf32> +#CHECK-NEXT: %29 = vector.insert %28, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %29, %subview_16[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %30 = arith.addi %arg6, %c3 : index +#CHECK-NEXT: %subview_17 = memref.subview %subview_9[%30, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_18 = memref.subview %subview_10[%30, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %31 = vector.transfer_read %subview_17[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %32 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, 
vector<1x4xf32> +#CHECK-NEXT: %33 = vector.transfer_read %subview_18[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %34 = vector.extract %32[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %35 = vector.extract %31[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %36 = vector.broadcast %35 : f32 to vector<4xf32> +#CHECK-NEXT: %37 = vector.extract %33[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %38 = vector.fma %36, %34, %37 : vector<4xf32> +#CHECK-NEXT: %39 = vector.insert %38, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %39, %subview_18[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> #CHECK-NEXT: } {"C/i[1]/i0"} #CHECK-NEXT: } {"C/j0"} #CHECK-NEXT: } {"C/k"} @@ -249,7 +230,7 @@ #CHECK-NEXT: return #CHECK-NEXT: } #CHECK-NEXT: } -#CHECK-NEXT: +#CHECK-EMPTY: #CHECK-NEXT: graph: #CHECK-NEXT: name: matmul #CHECK-NEXT: inputs: @@ -259,5 +240,6 @@ #CHECK-NEXT: - %2 : 16x32xfloat32 #CHECK-NEXT: nodes: #CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [16x512xfloat32, 512x32xfloat32] -> [16x32xfloat32] -#CHECK-NEXT: +#CHECK-EMPTY: #CHECK-NEXT: CODE: 0 + diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py index d9644b627..4e6aedc89 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py @@ -73,99 +73,144 @@ res = executor.execute() print(f"CODE: {res}") -# CHECK: graph: -# CHECK-NEXT: name: matmul -# CHECK-NEXT: inputs: -# CHECK-NEXT: - %0 : 512x512xfloat32 -# CHECK-NEXT: - %1 : 512x512xfloat32 -# CHECK-NEXT: outputs: -# CHECK-NEXT: - %2 : 512x512xfloat32 -# CHECK-NEXT: nodes: -# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [512x512xfloat32, 512x512xfloat32] -> [512x512xfloat32] -# CHECK-NEXT: -# CHECK-NEXT:# 
from tvm.script import ir as I -# CHECK-NEXT:# from tvm.script import tir as T -# CHECK-NEXT: -# CHECK-NEXT:@I.ir_module -# CHECK-NEXT:class Module: -# CHECK-NEXT: @T.prim_func -# CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): -# CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) -# CHECK-NEXT: for i, j in T.grid(512, 512): -# CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) -# CHECK-NEXT: C_1[i * 512 + j] = T.float32(0.0) -# CHECK-NEXT: for k in range(512): -# CHECK-NEXT: cse_var_2: T.int32 = i * 512 -# CHECK-NEXT: cse_var_1: T.int32 = cse_var_2 + j -# CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) -# CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) -# CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_1[cse_var_2 + k] * _1_1[k * 512 + j] -# CHECK-NEXT:INPS = list(obj.values())[:-1] -# CHECK-NEXT:O = obj['C'] -# CHECK-NEXT:I_R0 = sch.cache_read(INPS[0], "local", [O]) -# CHECK-NEXT:i, j, = O.op.axis -# CHECK-NEXT:k, = O.op.reduce_axis -# CHECK-NEXT:j, j0 = sch[O].split(j, factor=36) -# CHECK-NEXT:i, i0 = sch[O].split(i, factor=128) -# CHECK-NEXT:k, k0 = sch[O].split(k, factor=16) -# CHECK-NEXT:k0, __u_k0 = sch[O].split(k0, factor=2) -# CHECK-NEXT:i0, i1 = sch[O].split(i0, factor=2) -# CHECK-NEXT:j0, j1 = sch[O].split(j0, factor=6) -# CHECK-NEXT:sch[O].reorder(j, k, i, j0, i0, k0, __u_k0, i1, j1) -# CHECK-NEXT:sch[I_R0].compute_at(sch[O], i) -# CHECK-NEXT:sch[I_R0].storage_align(I_R0.op.axis[-2], factor=1024, offset=16) -# CHECK-NEXT:sch[O].unroll(__u_k0) -# CHECK-NEXT:sch[O].unroll(i1) -# CHECK-NEXT:sch[O].vectorize(j1) -# CHECK-NEXT:sch[O].parallel(j) -# CHECK-NEXT: -# CHECK-NEXT:# from tvm.script import ir as I -# CHECK-NEXT:# from tvm.script import tir as T -# CHECK-NEXT: -# CHECK-NEXT:@I.ir_module -# CHECK-NEXT:class Module: -# CHECK-NEXT: @T.prim_func -# CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: 
T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): -# CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) -# CHECK-NEXT: for j_outer in T.parallel(15): -# CHECK-NEXT: _0_local = T.allocate([2048], "float32", "local") -# CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) -# CHECK-NEXT: for i_outer_init, j_inner_outer_init, i_inner_outer_init in T.grid(4, 6, 64): -# CHECK-NEXT: for j_inner_inner_init_s in range(6): -# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + j_inner_inner_init_s // 2) // 2 < 128): -# CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + j_inner_inner_init_s] = T.float32(0.0) -# CHECK-NEXT: for j_inner_inner_init_s in range(6): -# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + j_inner_inner_init_s // 2) // 2 < 128): -# CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + j_inner_inner_init_s + 512] = T.float32(0.0) -# CHECK-NEXT: for k_outer, i_outer in T.grid(32, 4): -# CHECK-NEXT: _0_local_1 = T.Buffer((2048,), data=_0_local, scope="local") -# CHECK-NEXT: for ax0, ax1 in T.grid(128, 16): -# CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) -# CHECK-NEXT: _0_local_1[ax0 * 16 + ax1] = _0_1[i_outer * 65536 + ax0 * 512 + k_outer * 16 + ax1] -# CHECK-NEXT: for j_inner_outer, i_inner_outer, k_inner_outer in T.grid(6, 64, 8): -# CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) -# CHECK-NEXT: for j_inner_inner_s in range(6): -# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): -# CHECK-NEXT: cse_var_3: T.int32 = j_outer * 36 -# CHECK-NEXT: cse_var_2: T.int32 = j_inner_outer * 6 -# CHECK-NEXT: cse_var_1: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_3 + cse_var_2 + j_inner_inner_s -# CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2] * 
_1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_3 + cse_var_2 + j_inner_inner_s] -# CHECK-NEXT: for j_inner_inner_s in range(6): -# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): -# CHECK-NEXT: cse_var_6: T.int32 = j_outer * 36 -# CHECK-NEXT: cse_var_5: T.int32 = j_inner_outer * 6 -# CHECK-NEXT: cse_var_4: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_6 + cse_var_5 + j_inner_inner_s + 512 -# CHECK-NEXT: C_1[cse_var_4] = C_1[cse_var_4] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 16] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_6 + cse_var_5 + j_inner_inner_s] -# CHECK-NEXT: for j_inner_inner_s in range(6): -# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): -# CHECK-NEXT: cse_var_9: T.int32 = j_outer * 36 -# CHECK-NEXT: cse_var_8: T.int32 = j_inner_outer * 6 -# CHECK-NEXT: cse_var_7: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_9 + cse_var_8 + j_inner_inner_s -# CHECK-NEXT: C_1[cse_var_7] = C_1[cse_var_7] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 1] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_9 + cse_var_8 + j_inner_inner_s + 512] -# CHECK-NEXT: for j_inner_inner_s in range(6): -# CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + j_inner_inner_s // 2) // 2 < 128): -# CHECK-NEXT: cse_var_12: T.int32 = j_outer * 36 -# CHECK-NEXT: cse_var_11: T.int32 = j_inner_outer * 6 -# CHECK-NEXT: cse_var_10: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_12 + cse_var_11 + j_inner_inner_s + 512 -# CHECK-NEXT: C_1[cse_var_10] = C_1[cse_var_10] + _0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 17] * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11 + j_inner_inner_s + 512] -# CHECK-NEXT:CODE: 0 +#CHECK: graph: +#CHECK-NEXT: name: matmul +#CHECK-NEXT: inputs: +#CHECK-NEXT: - %0 : 512x512xfloat32 +#CHECK-NEXT: - %1 : 512x512xfloat32 +#CHECK-NEXT: outputs: +#CHECK-NEXT: - %2 : 
512x512xfloat32 +#CHECK-NEXT: nodes: +#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [512x512xfloat32, 512x512xfloat32] -> [512x512xfloat32] +#CHECK-EMPTY: +#CHECK-NEXT:# from tvm.script import ir as I +#CHECK-NEXT:# from tvm.script import tir as T +#CHECK-EMPTY: +#CHECK-NEXT:@I.ir_module +#CHECK-NEXT:class Module: +#CHECK-NEXT: @T.prim_func +#CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): +#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +#CHECK-NEXT: for i, j in T.grid(512, 512): +#CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) +#CHECK-NEXT: C_1[i * 512 + j] = T.float32(0.0) +#CHECK-NEXT: for k in range(512): +#CHECK-NEXT: cse_var_2: T.int32 = i * 512 +#CHECK-NEXT: cse_var_1: T.int32 = cse_var_2 + j +#CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) +#CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) +#CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_1[cse_var_2 + k] * _1_1[k * 512 + j] +#CHECK-NEXT:INPS = list(obj.values())[:-1] +#CHECK-NEXT:O = obj['C'] +#CHECK-NEXT:I_R0 = sch.cache_read(INPS[0], "local", [O]) +#CHECK-NEXT:i, j, = O.op.axis +#CHECK-NEXT:k, = O.op.reduce_axis +#CHECK-NEXT:j, j0 = sch[O].split(j, factor=36) +#CHECK-NEXT:i, i0 = sch[O].split(i, factor=128) +#CHECK-NEXT:k, k0 = sch[O].split(k, factor=16) +#CHECK-NEXT:k0, __u_k0 = sch[O].split(k0, factor=2) +#CHECK-NEXT:i0, i1 = sch[O].split(i0, factor=2) +#CHECK-NEXT:j0, j1 = sch[O].split(j0, factor=6) +#CHECK-NEXT:j1, __v_j1 = sch[O].split(j1, factor=2) +#CHECK-NEXT:sch[O].reorder(j, k, i, j0, i0, k0, __u_k0, i1, j1, __v_j1) +#CHECK-NEXT:sch[I_R0].compute_at(sch[O], i) +#CHECK-NEXT:sch[I_R0].storage_align(I_R0.op.axis[-2], factor=1024, offset=16) +#CHECK-NEXT:sch[O].unroll(__u_k0) +#CHECK-NEXT:sch[O].unroll(i1) +#CHECK-NEXT:sch[O].unroll(j1) +#CHECK-NEXT:sch[O].vectorize(__v_j1) +#CHECK-NEXT:sch[O].parallel(j) +#CHECK-EMPTY: +#CHECK-NEXT:# from 
tvm.script import ir as I +#CHECK-NEXT:# from tvm.script import tir as T +#CHECK-EMPTY: +#CHECK-NEXT:@I.ir_module +#CHECK-NEXT:class Module: +#CHECK-NEXT: @T.prim_func +#CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): +#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +#CHECK-NEXT: for j_outer in T.parallel(15): +#CHECK-NEXT: _0_local = T.allocate([2048], "float32", "local") +#CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) +#CHECK-NEXT: for i_outer_init, j_inner_outer_init, i_inner_outer_init in T.grid(4, 6, 64): +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer_init * 3 // 2 < 128): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + 1) // 2 < 128): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 2:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 2 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer_init * 3 // 2 < 127): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 4:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 4 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer_init * 3 // 2 < 128): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 512:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 512 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + 1) // 
2 < 128): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 514:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 514 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer_init * 3 // 2 < 127): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 516:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 516 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: for k_outer, i_outer in T.grid(32, 4): +#CHECK-NEXT: _0_local_1 = T.Buffer((2048,), data=_0_local, scope="local") +#CHECK-NEXT: for ax0, ax1 in T.grid(128, 16): +#CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) +#CHECK-NEXT: _0_local_1[ax0 * 16 + ax1] = _0_1[i_outer * 65536 + ax0 * 512 + k_outer * 16 + ax1] +#CHECK-NEXT: for j_inner_outer, i_inner_outer, k_inner_outer in T.grid(6, 64, 8): +#CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_3: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_2: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_1: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_3 + cse_var_2 +#CHECK-NEXT: C_1[cse_var_1:cse_var_1 + 2] = C_1[cse_var_1:cse_var_1 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_3 + cse_var_2:k_outer * 8192 + k_inner_outer * 1024 + cse_var_3 + cse_var_2 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_6: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_5: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_4: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_6 + cse_var_5 + 2 +#CHECK-NEXT: C_1[cse_var_4:cse_var_4 + 2] = C_1[cse_var_4:cse_var_4 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + 
k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_6 + cse_var_5 + 2:k_outer * 8192 + k_inner_outer * 1024 + cse_var_6 + cse_var_5 + 2 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_9: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_8: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_7: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_9 + cse_var_8 + 4 +#CHECK-NEXT: C_1[cse_var_7:cse_var_7 + 2] = C_1[cse_var_7:cse_var_7 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_9 + cse_var_8 + 4:k_outer * 8192 + k_inner_outer * 1024 + cse_var_9 + cse_var_8 + 4 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_12: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_11: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_10: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_12 + cse_var_11 + 512 +#CHECK-NEXT: C_1[cse_var_10:cse_var_10 + 2] = C_1[cse_var_10:cse_var_10 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 16], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11:k_outer * 8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_15: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_14: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_13: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_15 + cse_var_14 + 514 +#CHECK-NEXT: C_1[cse_var_13:cse_var_13 + 2] = C_1[cse_var_13:cse_var_13 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 16], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_15 + cse_var_14 + 2:k_outer * 8192 + k_inner_outer * 1024 + cse_var_15 + cse_var_14 + 2 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_18: T.int32 = 
j_outer * 36 +#CHECK-NEXT: cse_var_17: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_16: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_18 + cse_var_17 + 516 +#CHECK-NEXT: C_1[cse_var_16:cse_var_16 + 2] = C_1[cse_var_16:cse_var_16 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 16], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_18 + cse_var_17 + 4:k_outer * 8192 + k_inner_outer * 1024 + cse_var_18 + cse_var_17 + 4 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_21: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_20: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_19: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_21 + cse_var_20 +#CHECK-NEXT: C_1[cse_var_19:cse_var_19 + 2] = C_1[cse_var_19:cse_var_19 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_21 + cse_var_20 + 512:k_outer * 8192 + k_inner_outer * 1024 + cse_var_21 + cse_var_20 + 512 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_24: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_23: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_22: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_24 + cse_var_23 + 2 +#CHECK-NEXT: C_1[cse_var_22:cse_var_22 + 2] = C_1[cse_var_22:cse_var_22 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_24 + cse_var_23 + 514:k_outer * 8192 + k_inner_outer * 1024 + cse_var_24 + cse_var_23 + 514 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_27: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_26: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_25: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_27 + cse_var_26 + 4 +#CHECK-NEXT: C_1[cse_var_25:cse_var_25 + 2] = C_1[cse_var_25:cse_var_25 + 2] + 
T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_27 + cse_var_26 + 516:k_outer * 8192 + k_inner_outer * 1024 + cse_var_27 + cse_var_26 + 516 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_30: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_29: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_28: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_30 + cse_var_29 + 512 +#CHECK-NEXT: C_1[cse_var_28:cse_var_28 + 2] = C_1[cse_var_28:cse_var_28 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 17], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_30 + cse_var_29 + 512:k_outer * 8192 + k_inner_outer * 1024 + cse_var_30 + cse_var_29 + 512 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_33: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_32: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_31: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_33 + cse_var_32 + 514 +#CHECK-NEXT: C_1[cse_var_31:cse_var_31 + 2] = C_1[cse_var_31:cse_var_31 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 17], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_33 + cse_var_32 + 514:k_outer * 8192 + k_inner_outer * 1024 + cse_var_33 + cse_var_32 + 514 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_36: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_35: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_34: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_36 + cse_var_35 + 516 +#CHECK-NEXT: C_1[cse_var_34:cse_var_34 + 2] = C_1[cse_var_34:cse_var_34 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 17], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_36 + cse_var_35 + 516:k_outer * 8192 + k_inner_outer * 1024 + cse_var_36 + cse_var_35 + 516 + 2] +#CHECK:CODE: 0 
+ diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py index 652e0f329..ad7dd95ae 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py @@ -51,7 +51,7 @@ res = executor.execute() print(f"CODE: {res}") -#CHECK: graph: +#CHECK:graph: #CHECK-NEXT: name: matmul #CHECK-NEXT: inputs: #CHECK-NEXT: - %0 : 4x512xfloat32 @@ -60,10 +60,10 @@ #CHECK-NEXT: - %2 : 4x32xfloat32 #CHECK-NEXT: nodes: #CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] -#CHECK-NEXT: +#CHECK-EMPTY: #CHECK-NEXT:# from tvm.script import ir as I #CHECK-NEXT:# from tvm.script import tir as T -#CHECK-NEXT: +#CHECK-EMPTY: #CHECK-NEXT:@I.ir_module #CHECK-NEXT:class Module: #CHECK-NEXT: @T.prim_func @@ -85,10 +85,10 @@ #CHECK-NEXT:sch[O].reorder(k, i, j, j0, i0) #CHECK-NEXT:sch[O].unroll(i0) #CHECK-NEXT:sch[O].vectorize(j0) -#CHECK-NEXT: +#CHECK-EMPTY: #CHECK-NEXT:# from tvm.script import ir as I #CHECK-NEXT:# from tvm.script import tir as T -#CHECK-NEXT: +#CHECK-EMPTY: #CHECK-NEXT:@I.ir_module #CHECK-NEXT:class Module: #CHECK-NEXT: @T.prim_func diff --git a/tests/filecheck/search/test_matmul_descript_3axes.py b/tests/filecheck/search/test_matmul_descript_3axes.py index cdbab3779..0b1049662 100644 --- a/tests/filecheck/search/test_matmul_descript_3axes.py +++ b/tests/filecheck/search/test_matmul_descript_3axes.py @@ -26,4 +26,4 @@ print(strategy._constraints) -# CHECK: ['1 || kR || 12', '1 || jR || 32', '1 || iR || 21', '0 <= order_DDR <= 5', '0 <= order_R <= 5'] +# CHECK: ['iR || {21}', 'jR || {32}', 'kR || {12}', 'order_DDR in {0, 1, 2, 3, 4, 5}', 'order_R in {0, 1, 2, 3, 4, 5}'] diff --git a/tests/filecheck/search/test_matmul_descript_goto.py b/tests/filecheck/search/test_matmul_descript_goto.py index fd97325fb..ba6f16bf5 100644 --- 
a/tests/filecheck/search/test_matmul_descript_goto.py +++ b/tests/filecheck/search/test_matmul_descript_goto.py @@ -34,4 +34,4 @@ print(strategy._constraints) -# CHECK: ['1 || k_unroll || kL1 || 12', '1 || jR || jL3 || 32', '1 || iR || iL2 || 21', '0 <= pack_B <= 1', '0 <= pack_A <= 1', '0 <= j_parallel <= 1', '0 <= j_vectorise <= 1', 'iR * jR <= 56', '0 <= order_DDR <= 1'] +# CHECK: ['iL2 || {21}', 'iR * jR <= 56', 'iR || {iL2, 21}', 'jL3 || {32}', 'jR || {jL3, 32}', 'j_parallel in {0, 1}', 'j_vectorise in {0, 1}', 'kL1 || {12}', 'k_unroll || kL1', 'order_DDR in {0, 1}', 'pack_A in {0, 1}', 'pack_B in {0, 1}'] diff --git a/tests/filecheck/search/test_matmul_descript_simple.py b/tests/filecheck/search/test_matmul_descript_simple.py index 6d715109a..039f8d40a 100644 --- a/tests/filecheck/search/test_matmul_descript_simple.py +++ b/tests/filecheck/search/test_matmul_descript_simple.py @@ -24,4 +24,4 @@ print(strategy._constraints) -# CHECK: ['1 || j2 || j1 || 32', '1 || i1 || 21'] +# CHECK: ['i1 || {21}', 'j1 || {32}', 'j2 || {j1, 32}'] diff --git a/tests/filecheck/search/test_matmul_descript_split.py b/tests/filecheck/search/test_matmul_descript_split.py index 703b30ac3..735583ef7 100644 --- a/tests/filecheck/search/test_matmul_descript_split.py +++ b/tests/filecheck/search/test_matmul_descript_split.py @@ -38,4 +38,4 @@ print(strategy._constraints) -# CHECK: ['1 || jR1 || jDDR || 32', '1 || jR3 || jDDR || 32', '1 || jR2 || jDDR || 32', '1 || iL2 || 21', '1 || iR1 || 2', '1 || iR3 || 1', '1 || iR2 || i_1_', 'i_1_ + 6 == iL2'] +# CHECK: ['iL3 || {21}', 'iR1 || {7, iL3, 21}', 'iR2 || {7, iL3, 21}', 'jDDR || {32}', 'jR1 || {jDDR, 32}', 'jR2 || {jDDR, 32}'] diff --git a/tests/filecheck/search/test_matmul_descript_split_in_split.py b/tests/filecheck/search/test_matmul_descript_split_in_split.py index e397cc9ef..1e4942c4f 100644 --- a/tests/filecheck/search/test_matmul_descript_split_in_split.py +++ b/tests/filecheck/search/test_matmul_descript_split_in_split.py @@ -43,4 
+43,4 @@ print(strategy._constraints) -# CHECK: ['1 || jR2 || jDDR || 32', '1 || jR1 || jDDR || 32', '1 || iR2 || 2', '1 || iR1 || 5', '7 || iL3 || 21'] +# CHECK: ['iL2 || {21}', 'iR1 || {3, iL2, 21}', 'iR2 || {iL2, 21}', 'iR3 || {3, iL2, 21}', 'iS + 2 == 3', 'i_1_ + 6 == iL2', 'i_1_ <= iL2', 'jDDR || {32}', 'jR1 || {jDDR, 32}', 'jR2 || {jDDR, 32}', 'jR3 || {jDDR, 32}'] diff --git a/tests/filecheck/search/test_matmul_descript_yaml_goto.py b/tests/filecheck/search/test_matmul_descript_yaml_goto.py index 0205371aa..28d121e36 100644 --- a/tests/filecheck/search/test_matmul_descript_yaml_goto.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_goto.py @@ -62,5 +62,5 @@ print(strategy._constraints) print(len(list(strategy.sample(100)))) -# CHECK: ['nc || {1024}', 'mc || {1024}', 'kc || {1024}', 'kr || kc', 'mr || {mc, 1024}', 'nr || {nc, 1024}', '1 + nvr + nvr * mr <= 32', 'nr == 16 * nvr', 'nvr * mr >= 8', 'nvr * mr * kr <= 256', 'kc * nr <= 8192', 'kc * mc <= 262144', 'kc * nc <= 9437184'] +# CHECK: ['1 + nvr + nvr * mr <= 32', 'kc * mc <= 262144', 'kc * nc <= 9437184', 'kc * nr <= 8192', 'kc || {1024}', 'kr || kc', 'mc || {1024}', 'mr || {mc, 1024}', 'nc || {1024}', 'nr == 16 * nvr', 'nr || {nc, 1024}', 'nvr * mr * kr <= 256', 'nvr * mr >= 8'] #CHECK-NEXT: 100 diff --git a/tests/filecheck/search/test_matmul_descript_yaml_simple.py b/tests/filecheck/search/test_matmul_descript_yaml_simple.py index 9112a6321..ee6ef4713 100644 --- a/tests/filecheck/search/test_matmul_descript_yaml_simple.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_simple.py @@ -24,5 +24,5 @@ print(strategy._constraints) print(len(list(strategy.sample(100)))) -# CHECK: ['1 || j2 || j1 || 32', '1 || i1 || 21'] +# CHECK: ['i1 || {21}', 'j1 || {32}', 'j2 || {j1, 32}'] # CHECK-NEXT: 84 diff --git a/tests/filecheck/search/test_matmul_descript_yaml_split.py b/tests/filecheck/search/test_matmul_descript_yaml_split.py index c054cd4e6..82d441285 100644 --- 
a/tests/filecheck/search/test_matmul_descript_yaml_split.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_split.py @@ -35,5 +35,5 @@ print(strategy._constraints) print(len(list(strategy.sample(100)))) -# CHECK: ['1 || SR || 12', '1 || jR1 || jDDR || 32', '1 || jR2 || jDDR || 32', '1 || iL2 || iL3 || 21', '1 || iR1 || iS', '1 || iR2 || i_1_', 'iS <= iL2', 'i_1_ + iS == iL2'] +# CHECK: ['SR || {12}', 'iL2 || {iL3, 21}', 'iL3 || {21}', 'iR1 || {iL2, iL3, 21}', 'iR2 || {iL2, iL3, 21}', 'iS <= iL2', 'i_1_ + iS == iL2', 'i_1_ <= iL2', 'jDDR || {32}', 'jR1 || {jDDR, 32}', 'jR2 || {jDDR, 32}'] # CHECK-NEXT: 100 From c25f25c72f8f7b699ccd09d45d5437bde4c445b3 Mon Sep 17 00:00:00 2001 From: Leon Frenot Date: Fri, 16 Jan 2026 10:50:52 +0100 Subject: [PATCH 22/23] Partial Goto --- src/xtc/schedules/descript_extend.py | 18 ------------------ .../search/test_matmul_descript_yaml_goto.py | 4 ++-- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 15ebb4c16..534a5b669 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -289,22 +289,6 @@ def _flatten_schedule( ) -> LoopNestExtend: recursive_scheds = LoopNestExtend(abstract_dims=self.abstract_axis) sched = recursive_scheds.build_slice(root) - # sched: SchedDict = { - # "root": root, - # "fusions": {}, - # "packs": {}, - # "buffers": {}, - # "axis_orders": [], - # "axes": {}, - # "splits": {}, - # "tiles": {a: {} for a in self.abstract_axis}, - # "interchange": [], - # "vectorize": [], - # "parallelize": [], - # "unroll": {}, - # "variables": [], - # "constraints": [], - # } # State of the schedule if tile_sizes: axes_sizes: dict[str, int | str] = tile_sizes @@ -317,8 +301,6 @@ def _flatten_schedule( sizes: dict[str, int | str | None] = {} previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} interchange: list[str] = head - # constraints: list[str] = [] - # 
variables: list[str] = [] # Processing the schedule for tree_declaration, tree_val in spec.items(): assert isinstance(tree_val, dict) diff --git a/tests/filecheck/search/test_matmul_descript_yaml_goto.py b/tests/filecheck/search/test_matmul_descript_yaml_goto.py index 28d121e36..ba6070d76 100644 --- a/tests/filecheck/search/test_matmul_descript_yaml_goto.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_goto.py @@ -57,10 +57,10 @@ f"kc * mc <= {nb_words_L2}", f"kc * nc <= {nb_words_L3}", ] -strategy = Strategy(graph, spec, constraints=constraints, partial_tiles=False, partial_unrolls=False, initialize=False) +strategy = Strategy(graph, spec, constraints=constraints, partial_tiles=True, partial_unrolls=True, initialize=False) print(strategy._constraints) print(len(list(strategy.sample(100)))) -# CHECK: ['1 + nvr + nvr * mr <= 32', 'kc * mc <= 262144', 'kc * nc <= 9437184', 'kc * nr <= 8192', 'kc || {1024}', 'kr || kc', 'mc || {1024}', 'mr || {mc, 1024}', 'nc || {1024}', 'nr == 16 * nvr', 'nr || {nc, 1024}', 'nvr * mr * kr <= 256', 'nvr * mr >= 8'] +# CHECK: ['1 + nvr + nvr * mr <= 32', 'kc * mc <= 262144', 'kc * nc <= 9437184', 'kc * nr <= 8192', 'kc <= 1024', 'kr <= kc', 'mc <= 1024', 'mr || {mc, 1024}', 'nc <= 1024', 'nr == 16 * nvr', 'nr || {nc, 1024}', 'nvr * mr * kr <= 256', 'nvr * mr >= 8'] #CHECK-NEXT: 100 From d254ce656384a24d680086142d6d3195d1820675 Mon Sep 17 00:00:00 2001 From: Leon Frenot Date: Wed, 21 Jan 2026 13:39:58 +0100 Subject: [PATCH 23/23] Fixes after more rebasing --- requirements.txt | 1 + src/xtc/cli/mlir_loop.py | 6 +- src/xtc/schedules/descript.py | 194 ++- src/xtc/schedules/descript_extend.py | 1096 +++++++++-------- src/xtc/search/strategies.py | 42 +- .../splitting/v_splitting_extend.mlir | 51 +- .../tiling/i_invalide_argument.mlir | 3 +- .../i_one_axis_positive_negative_tiling.mlir | 2 +- ...test_matmul_descript_extend_mlir_sample.py | 4 - .../test_matmul_descript_extend_mlir_split.py | 154 ++- 
.../test_matmul_descript_extend_tvm_goto.py | 167 ++- ...est_matmul_descript_extend_tvm_strategy.py | 8 +- .../search/test_matmul_descript_3axes.py | 8 +- .../search/test_matmul_descript_goto.py | 18 +- .../search/test_matmul_descript_simple.py | 21 +- .../search/test_matmul_descript_split.py | 22 +- .../test_matmul_descript_split_in_split.py | 28 +- .../search/test_matmul_descript_yaml_goto.py | 7 +- .../test_matmul_descript_yaml_simple.py | 5 +- .../search/test_matmul_descript_yaml_split.py | 9 +- 20 files changed, 933 insertions(+), 913 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6b3d57311..e3fdb8522 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ scikit-learn networkx sympy strictyaml +types-PyYAML diff --git a/src/xtc/cli/mlir_loop.py b/src/xtc/cli/mlir_loop.py index d15fe362f..0d4eb71f3 100644 --- a/src/xtc/cli/mlir_loop.py +++ b/src/xtc/cli/mlir_loop.py @@ -182,11 +182,13 @@ def normalize_extend_schedule( assert isinstance(instr, str) if isinstance(param, builtin.UnitAttr): annotations[instr] = None - elif isinstance(param, builtin.IntegerAttr): + elif isinstance(param, builtin.IntegerAttr) or isinstance( + param, builtin.StringAttr + ): annotations[instr] = param.value.data else: raise Exception( - "Annotation parameter should be void or int." + "Annotation parameter should be void, int, or str." ) elif not isinstance(val, builtin.UnitAttr): diff --git a/src/xtc/schedules/descript.py b/src/xtc/schedules/descript.py index 7f1873987..62297f726 100644 --- a/src/xtc/schedules/descript.py +++ b/src/xtc/schedules/descript.py @@ -41,10 +41,12 @@ class Annotations: parallelize: True if parallelization was requested. 
""" - unroll_factor: int | None = None + unroll_factor: int | str | None = None unroll_specified: bool = False - vectorize: bool = False - parallelize: bool = False + vectorize: bool | str = False + parallelize: bool | str = False + partial: bool = False + full: bool = False @dataclass(frozen=True) @@ -52,12 +54,15 @@ class SplitDecl: """AST Type: a split declaration like 'axis[start:end]'.""" axis: str - start: int | None - end: int | None + start: int | str | None + end: int | str | None body: ScheduleSpec + size: int | str | None = None @override def __str__(self) -> str: + if self.size is not None: + return f"{self.axis}[:{self.size}:]" start_str = "" if self.start is None else str(self.start) end_str = "" if self.end is None else str(self.end) decl = f"{self.axis}[{start_str}:{end_str}]" @@ -69,7 +74,7 @@ class TileDecl: """AST Type: a tile declaration like 'axis#size'.""" axis: str - size: int + size: int | str annotations: Annotations @override @@ -85,7 +90,36 @@ class AxisDecl: annotations: Annotations -ScheduleItem = SplitDecl | TileDecl | AxisDecl +@dataclass(frozen=True) +class FusionDecl: + """AST Type: a fusion declaration""" + + +@dataclass(frozen=True) +class PackDecl: + """AST Type: a packing declaration""" + + param: str | bool + input: str + pad: str | bool + + +@dataclass(frozen=True) +class BufferDecl: + """AST Type: a bufferisation declaration""" + + param: str | bool + pad: str + + +@dataclass(frozen=True) +class ExploreDecl: + level: str + + +ScheduleItem = ( + SplitDecl | TileDecl | AxisDecl | FusionDecl | PackDecl | BufferDecl | ExploreDecl +) @dataclass(frozen=True) @@ -144,10 +178,12 @@ def _parse_tile(self, declaration: str, value: dict) -> TileDecl: axis_name, size_str = parts - try: - size = int(size_str) - except ValueError: - raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.") + size = int(size_str) if size_str.isnumeric() else size_str + + # try: + # size = int(size_str) + # except ValueError: + # raise 
ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.") annotations = self._parse_annotations(value, declaration) return TileDecl(axis=axis_name, size=size, annotations=annotations) @@ -234,8 +270,8 @@ def _interpret_spec( slice = loop_nest.build_slice(root) # Track state during interpretation - sizes: dict[str, int] = {} - previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis} + sizes: dict[str, int | str] = {} + previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} interchange: list[str] = list(head) for item in spec.items: @@ -267,7 +303,7 @@ def _interpret_split( loop_nest: LoopNest, root: str, interchange: list[str], - previous_cut: dict[str, int | None], + previous_cut: dict[str, int | str | None], ) -> None: """Interpret a split declaration.""" axis_name = item.axis @@ -283,10 +319,8 @@ def _interpret_split( # it is the previous cut if x is None: x = cut - assert x is not None - self._check_splitting_intervals(item, cut, x) - + assert x is not None # Update the previous cut previous_cut[axis_name] = y @@ -308,12 +342,15 @@ def _interpret_tile( item: TileDecl, slice: LoopNestSlice, interchange: list[str], - sizes: dict[str, int], + sizes: dict[str, int | str], ) -> str: """Interpret a tile declaration. Returns the loop name.""" self._check_axis_existence(item.axis) tile_num = len(slice.tiles[item.axis]) loop_name = f"{item.axis}{tile_num}" + if not isinstance(item.size, int): + raise ScheduleInterpretError(f"`{item}`: {item.size} is not an integer.") + assert isinstance(item.size, int) if item.size <= 0: raise ScheduleInterpretError( f"`{item}`: tile sizes should be strictly positive." 
@@ -354,7 +391,7 @@ def _apply_annotations( self, annotations: Annotations, loop_name: str, - sizes: dict[str, int], + sizes: dict[str, int | str], slice: LoopNestSlice, ) -> None: """Apply annotations to a loop in the slice.""" @@ -367,7 +404,7 @@ def _apply_annotations( f"{loop_name}'s size being unknown, an unroll factor is needed." ) unroll_factor = sizes[loop_name] - elif unroll_factor <= 0: + elif isinstance(unroll_factor, int) and unroll_factor <= 0: raise ScheduleInterpretError( f'`{{"unroll" = {unroll_factor}}}`: unroll parameter should be strictly positive.' ) @@ -382,27 +419,46 @@ def _apply_annotations( def _check_splitting_intervals( self, item: SplitDecl, - cut: int | None, - x: int, - ) -> None: + cut: int | str | None, + x: int | str | None, + ) -> int | str | None: """Check that split intervals are valid and contiguous.""" - + y = item.end if cut is None: raise ScheduleInterpretError(f"{item}: {item.axis} already covered.") - if x > cut: - raise ScheduleInterpretError( - f"{item}: splitting doesn't fully cover {item.axis} (jumps from {cut} to {x})." - ) - elif x < cut: + if x is None: raise ScheduleInterpretError( - f"{item}: the segment begins at {x} but the previous one ends at {cut}." + f"x is None, but cut: {cut} is not, this should be unreachable." ) + if isinstance(x, int) and isinstance(cut, int): + if x > cut: + raise ScheduleInterpretError( + f"{item}: splitting doesn't fully cover {item.axis} (jumps from {cut} to {x})." + ) + elif x < cut: + raise ScheduleInterpretError( + f"{item}: the segment begins at {x} but the previous one ends at {cut}." + ) + else: + if x != cut: + raise ScheduleInterpretError( + f"{item}: Splitting ends at {cut} and begins at {x}. These need to be the same." + ) + if y is None: + return None - if item.end is not None and x >= item.end: - raise ScheduleInterpretError( - f"{item}: the ending point should be greater than the starting point." 
- ) + if isinstance(x, int): + if isinstance(y, int): + if x >= y: + raise ScheduleInterpretError( + f"{item}: the ending point should be greater than the starting point." + ) + else: + return y - x + if x == 0: + return y + return None @dataclass @@ -506,6 +562,34 @@ def tiles_to_sizes(self) -> dict[str, int]: tiles_to_sizes[loop] = size return tiles_to_sizes + @property + def int_tiles(self) -> dict[str, dict[str, int]]: + return self._int_dict(self.tiles) + + @property + def int_splits(self) -> dict[str, dict[str, int]]: + return self._int_dict(self.splits) + + @property + def int_unroll(self) -> dict[str, int]: + out = {} + for x, v in self.unroll.items(): + if isinstance(v, str) and v.isnumeric(): + v = int(v) + assert isinstance(v, int) + out[x] = v + return out + + def _int_dict(self, input: dict[str, dict[str, Any]]) -> dict[str, dict[str, int]]: + out: dict[str, dict[str, int]] = {} + for x, v in input.items(): + v_dict: dict[str, int] = {} + for x_v, v_v in v.items(): + assert isinstance(v_v, int) + v_dict[x_v] = v_v + out[x] = v_dict + return out + @dataclass class LoopNest: @@ -616,6 +700,7 @@ def _check_sizes(self): if loop_name in sched.unroll: unroll_factor = sched.unroll[loop_name] + assert isinstance(unroll_factor, int) if loop_size and loop_size < unroll_factor: raise ScheduleValidationError( f'`{{"unroll" = {unroll_factor}}}`: unroll factor should be smaller than {loop_size}.' @@ -650,19 +735,11 @@ def descript_scheduler( abstract_axis: The list of abstract axis names (e.g., ["m", "n", "k"]). spec: The schedule specification as a nested dict. 
""" - descript = Descript(scheduler=scheduler, abstract_axis=abstract_axis) - descript.apply(node_name=node_name, spec=spec) - - -def correct_type(d: dict[str, int | str]) -> dict[str, int]: - out_d: dict[str, int] = {} - for k, v in d.items(): - assert isinstance(v, int) - out_d[k] = v - return out_d + descript = Descript(abstract_axis=abstract_axis) + descript.apply(scheduler=scheduler, node_name=node_name, spec=spec) -@dataclass(frozen=True) +@dataclass(frozen=False) class Descript: """Applies a parsed and interpreted schedule to a Scheduler. @@ -674,10 +751,11 @@ class Descript: 4. Apply: LoopNest -> Scheduler """ - scheduler: Scheduler abstract_axis: list[str] - def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: + def apply( + self, node_name: str, spec: dict[str, dict[str, Any]], scheduler: Scheduler + ) -> None: """Parse, interpret, validate, and apply a schedule specification. Args: @@ -701,22 +779,22 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: loop_nest.check() # Apply the schedule to the scheduler - self._apply_loop_nest(loop_nest) + self._apply_loop_nest(loop_nest, scheduler) - def _apply_loop_nest(self, loop_nest: LoopNest) -> None: + def _apply_loop_nest(self, loop_nest: LoopNest, scheduler: Scheduler) -> None: """Apply a LoopNest to the scheduler.""" - self.scheduler.set_dims(self.abstract_axis) + scheduler.set_dims(self.abstract_axis) for slice in loop_nest.slices: root = slice.root - for d, s in slice.splits.items(): - self.scheduler.split(d, s, root=root) + for d, s in slice.int_splits.items(): + scheduler.split(d, s, root=root) - for d, s in slice.tiles.items(): - self.scheduler.tile(d, s, root=root) + for d, s in slice.int_tiles.items(): + scheduler.tile(d, s, root=root) - self.scheduler.interchange(slice.interchange, root=root) - self.scheduler.vectorize(slice.vectorize, root=root) - self.scheduler.parallelize(slice.parallelize, root=root) - self.scheduler.unroll(slice.unroll, root=root) + 
scheduler.interchange(slice.interchange, root=root) + scheduler.vectorize(slice.vectorize, root=root) + scheduler.parallelize(slice.parallelize, root=root) + scheduler.unroll(slice.int_unroll, root=root) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py index 534a5b669..a443eba82 100644 --- a/src/xtc/schedules/descript_extend.py +++ b/src/xtc/schedules/descript_extend.py @@ -2,16 +2,34 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024-2026 The XTC Project Authors # -from typing import Any, Tuple +from typing import Any from copy import deepcopy from dataclasses import dataclass, field import re + import strictyaml from typing_extensions import override from xtc.itf.schd.scheduler import Scheduler -from xtc.schedules.descript import Descript, LoopNest, LoopNestSlice, correct_type +from xtc.schedules.descript import ( + Annotations, + AxisDecl, + BufferDecl, + Descript, + FusionDecl, + LoopNest, + LoopNestSlice, + PackDecl, + ScheduleInterpretError, + ScheduleInterpreter, + ScheduleItem, + ScheduleParseError, + ScheduleParser, + ScheduleSpec, + SplitDecl, + TileDecl, +) @dataclass @@ -121,38 +139,542 @@ def descript_extend_scheduler( partial_tiles=partial_tiles, partial_unrolls=partial_unrolls, ) - descript.apply(node_name=node_name, spec=spec, scheduler=scheduler, sample=sample) + descript.apply(node_name=node_name, spec=spec, sample=sample, scheduler=scheduler) + + +class ScheduleParserExtend(ScheduleParser): + _SPLIT_PATTERN = re.compile(r"^(.*)\[(-\w+|\w*)?:(-\w+|\w*)?\]$") + _SPLIT_MIDDLE_PATTERN = re.compile(r"^(.*)\[:(\w*):\]$") + + @override + def __init__(self, abstract_axis: list[str], abstract_matrix: list[str]): + self.abstract_matrix = abstract_matrix + super().__init__(abstract_axis) + + @override + def _parse_declaration(self, declaration: str, value: Any) -> ScheduleItem: + if "fusion" == declaration: + return self._parse_fusion() + if "pack" == declaration: + return self._parse_pack(value) + if 
"buffer" == declaration: + return self._parse_buffer(value) + if declaration in self.abstract_matrix: + return self._parse_matrix(declaration, value) + + return super()._parse_declaration(declaration, value) + + @override + def _parse_split(self, declaration: str, value: dict) -> SplitDecl: + axis_name, start, end, size = self._parse_split_syntax_extend(declaration) + + body = self.parse(value) + return SplitDecl(axis=axis_name, start=start, end=end, body=body, size=size) + + @override + def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations: + """Parse annotation dict into Annotations object.""" + + unroll_factor: int | str | None = None + unroll_specified = False + vectorize = False + parallelize = False + partial = False + full = False + + for key, param in value.items(): + if key == "unroll": + if param is True or param is None: + unroll_factor = None + unroll_specified = True + elif param is False: + pass + elif isinstance(param, int) or isinstance(param, str): + unroll_factor = param + unroll_specified = True + else: + raise ScheduleParseError( + f'`{{"unroll" = {param}}}`: unroll parameter should be True, False, None, or an integer.' + ) + elif key == "vectorize": + if param is True or param is None: + vectorize = True + elif param is False: + pass + elif isinstance(param, str): + vectorize = param + else: + raise ScheduleParseError( + f'`{{"vectorize" = {param}}}`: parameterized vectorization not implemented.' + ) + elif key == "parallelize": + if isinstance(param, str): + parallelize = param + elif param is not None: + raise ScheduleParseError( + f'`{{"parallelize" = {param}}}`: parameterized parallelization not implemented.' 
+ ) + else: + parallelize = True + elif key == "partial": + if full: + raise ScheduleParseError("Tile cannot be full and partial.") + partial = True + elif key == "full": + if partial: + raise ScheduleParseError("Tile cannot be partial and full.") + full = True + else: + raise ScheduleParseError(f"Unknown annotation on {context}: {key}") + + return Annotations( + unroll_factor=unroll_factor, + unroll_specified=unroll_specified, + vectorize=vectorize, + parallelize=parallelize, + partial=partial, + full=full, + ) + + def _parse_split_syntax_extend( + self, declaration: str + ) -> tuple[str, int | str | None, int | str | None, int | str | None]: + """Parse the syntax of a split declaration.""" + match = self._SPLIT_PATTERN.match(declaration) + if not match: + match = self._SPLIT_MIDDLE_PATTERN.match(declaration) + if not match: + raise ScheduleParseError(f"Wrong format {declaration}") + prefix, z = match.groups() + z = int(z) if z.isnumeric() else z + return prefix, None, None, z + + prefix, x_str, y_str = match.groups() + x = int(x_str) if x_str.isnumeric() else x_str + y = int(y_str) if y_str.isnumeric() else y_str + x = x if x else None + y = y if y else None + return prefix, x, y, None + + def _parse_fusion(self) -> FusionDecl: + return FusionDecl() + + def _parse_pack(self, value: Any) -> PackDecl: + assert len(value) == 3 + param, input, pad = value + return PackDecl(param, input, pad) + + def _parse_buffer(self, value: Any) -> BufferDecl: + assert len(value == 2) + param, pad = value + return BufferDecl(param, pad) + + def _parse_matrix(self, declaration: str, value: Any) -> PackDecl | BufferDecl: + param = value.get("bufferize", False) + if not (param is None or param): + raise ScheduleParseError( + f"Declared matrix {declaration} without bufferization." 
+ ) + pad = value.get("pad", False) + if declaration == self.abstract_matrix[-1]: + return BufferDecl(param, pad) + return PackDecl(param, declaration, pad) + + +class ScheduleInterpreterExtend(ScheduleInterpreter): + @override + def __init__( + self, + abstract_axis: list[str], + abstract_axis_sizes: dict[str, int], + abstract_matrix: list[str], + partial_tiles: bool = False, + partial_unrolls: bool = False, + ): + self.abstract_matrix = abstract_matrix + self.abstract_axis_sizes = abstract_axis_sizes + self.partial_tiles = partial_tiles + self.partial_unrolls = partial_unrolls + super().__init__(abstract_axis) + + @override + def interpret(self, spec: ScheduleSpec, root: str) -> LoopNestExtend: + return self._interpret_spec(spec, root, head=[]) + + @override + def _interpret_spec( + self, + spec: ScheduleSpec, + root: str, + head: list[str], + tile_sizes: dict[str, list[int | str]] | None = None, + ) -> LoopNestExtend: + """Interpret a schedule spec recursively.""" + loop_nest = LoopNestExtend(abstract_dims=self.abstract_axis) + slice = loop_nest.build_slice(root) + + # Track state during interpretation + last_split: list[tuple[int | str, int | str]] = [] + previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} + interchange: list[str] = list(head) + axes_sizes: dict[str, list[int | str]] = {} + sizes: dict[str, int | str] = {} + if tile_sizes: + axes_sizes = tile_sizes + else: + axes_sizes = {a: [v] for a, v in self.abstract_axis_sizes.items()} + + for item in spec.items: + if isinstance(item, SplitDecl): + self._interpret_split( + item=item, + slice=slice, + loop_nest=loop_nest, + root=root, + interchange=interchange, + previous_cut=previous_cut, + axes_sizes=axes_sizes, + last_split=last_split, + ) + elif isinstance(item, TileDecl): + loop_name = self._interpret_tile( + item=item, + slice=slice, + interchange=interchange, + axes_sizes=axes_sizes, + sizes=sizes, + ) + self._apply_annotations(item.annotations, loop_name, sizes, slice) + 
elif isinstance(item, AxisDecl): + loop_name = self._interpret_axis(item, interchange) + self._apply_annotations(item.annotations, loop_name, sizes, slice) + elif isinstance(item, FusionDecl): + ... + elif isinstance(item, PackDecl): + self._interpret_pack(slice, interchange[-1], item) + elif isinstance(item, BufferDecl): + self._interpret_buffer(slice, interchange[-1], item) + + if len(last_split) > 0: + a, b = last_split[0] + if isinstance(a, int) and not isinstance(b, int): + a, b = b, a + a, b = str(a), str(b) + for c in slice.constraints: + slice.constraints.remove(c) + slice.constraints.add(c.replace(a, b)) + + # Check that all splits are complete + for axis, cut in previous_cut.items(): + if ( + cut is not None + and isinstance(cut, int) + and cut not in [0, axes_sizes[axis][-1]] + ): + raise ScheduleInterpretError( + f"Splitting of {axis} unachieved (stops at {cut})." + ) + + slice.interchange = interchange + return loop_nest + + @override + def _interpret_split( + self, + item: SplitDecl, + slice: LoopNestSlice, + loop_nest: LoopNest, + root: str, + interchange: list[str], + previous_cut: dict[str, int | str | None], + axes_sizes: dict[str, list[int | str]] = {}, + last_split: list[tuple[int | str, int | str]] = [], + ): + """Interpret a split declaration.""" + if not isinstance(slice, LoopNestSliceExtend): + return super()._interpret_split( + item, slice, loop_nest, root, interchange, previous_cut + ) + axis_name = item.axis + self._check_axis_existence(axis_name) + x = item.start + y = item.end + z = item.size + + # The only declaration where y (the cut) is None is the + # last one, so it cannot be the previous one. 
+ cut = previous_cut[axis_name] + current_size = axes_sizes[axis_name][-1] + + # When x (the starting point of the slice) is not specified, + # it is the previous cut + if x is None: + x = cut + inner_size = self._check_splitting_intervals(item, cut, x) + assert x is not None + + # Save the cutting points of the new dimensions + if axis_name not in slice.splits: + slice.splits[axis_name] = {} + new_dim_index = len(slice.splits[axis_name]) + new_dim_name = f"{axis_name}[{new_dim_index}]" + new_root_name = f"{root}/{new_dim_name}" + interchange.append(new_dim_name) + + if z is None: + # Update the previous cut + previous_cut[axis_name] = y + slice.splits[axis_name][new_dim_name] = x + inner_size = None + if y is None: + y = current_size + if isinstance(x, int): + if x == 0: + inner_size = y + elif isinstance(y, int): + inner_size = y - x + if inner_size is None: + inner_size = root[1:] + new_dim_name + inner_size = ( + inner_size.replace("/", "").replace("[", "_").replace("]", "_") + ) + slice.constraints.add(f"{inner_size} <= {y}") + if isinstance(x, str): + slice.constraints.add(f"{x} <= {y}") + slice.constraints.add(f"{inner_size} + {x} == {y}") + else: + inner_size = z + x = cut + y = current_size + assert x is not None + slice.splits[axis_name][new_dim_name] = x + if isinstance(z, int) and isinstance(x, int): + previous_cut[axis_name] = x + z + if not isinstance(y, int): + slice.constraints.add(f"{z + x} <= {y}") + elif isinstance(x, int) and x == 0: + previous_cut[axis_name] = z + if not isinstance(y, int): + slice.constraints.add(f"{z} <= {y}") + else: + new_cut = root[1:] + new_dim_name + new_cut = new_cut.replace("/", "").replace("[", "_").replace("]", "_") + previous_cut[axis_name] = new_cut + if len(last_split) > 0: + a, b = last_split[0] + slice.constraints.add(f"{a} <= {b}") + last_split.append((new_cut, y)) + slice.constraints.add(f"{z} + {x} == {new_cut}") + + # Recursively interpret the nested schedule + inner_nest = self._interpret_spec( + 
spec=item.body, + root=new_root_name, + head=[axis_name], + tile_sizes=deepcopy(axes_sizes), + ) + loop_nest.slices += inner_nest.slices + + @override + def _interpret_tile( + self, + item: TileDecl, + slice: LoopNestSlice, + interchange: list[str], + sizes: dict[str, int | str], + axes_sizes: dict[str, list[int | str]] = {}, + ) -> str: + """Interpret a tile declaration. Returns the loop name.""" + self._check_axis_existence(item.axis) + tile_num = len(slice.tiles[item.axis]) + loop_name = f"{item.axis}{tile_num}" + if isinstance(item.size, int) and item.size <= 0: + raise ScheduleInterpretError( + f"`{item}`: tile sizes should be strictly positive." + ) + slice.tiles[item.axis][loop_name] = item.size + size_list = axes_sizes[item.axis] + sizes[loop_name] = item.size + assert isinstance(size_list, list) + old_size = size_list[-1] + interchange.append(loop_name) + if isinstance(item.size, str): + assert isinstance(slice, LoopNestSliceExtend) + slice.variables.add(item.size) + partial = item.annotations.partial + full = item.annotations.full + if partial or (not full and self.partial_tiles): + slice.constraints.add(f"{item.size} <= {old_size}") + else: + s = ( + ", ".join(map(str, size_list)) + if len(size_list) > 1 + else str(size_list[0]) + ) + s = f"{item.size} || {{{s}}}" + slice.constraints.add(s) + size_list.append(item.size) + return loop_name + + @override + def _apply_annotations( + self, + annotations: Annotations, + loop_name: str, + sizes: dict[str, int | str], + slice: LoopNestSlice, + ) -> None: + """Apply annotations to a loop in the slice.""" + assert isinstance(slice, LoopNestSliceExtend) + if annotations.unroll_specified: + unroll_factor = annotations.unroll_factor + if unroll_factor is None: + # None means "unroll fully" - use the loop size + if loop_name not in sizes: + raise ScheduleInterpretError( + f"{loop_name}'s size being unknown, an unroll factor is needed." 
+ ) + unroll_factor = sizes[loop_name] + elif isinstance(unroll_factor, str): + slice.variables.add(unroll_factor) + if self.partial_unrolls: + slice.constraints.add(f"{unroll_factor} <= {sizes[loop_name]}") + else: + slice.constraints.add(f"{unroll_factor} || {sizes[loop_name]}") + elif unroll_factor <= 0: + raise ScheduleInterpretError( + f'`{{"unroll" = {unroll_factor}}}`: unroll parameter should be strictly positive.' + ) + slice.unroll[loop_name] = unroll_factor + + vectorize = annotations.vectorize + if vectorize: + slice.vectorize.append(loop_name) + if isinstance(vectorize, str): + slice.variables.add(vectorize) + slice.constraints.add(f"{vectorize} in {{0, 1}}") + + parallelize = annotations.parallelize + if parallelize: + slice.parallelize.append(loop_name) + if isinstance(parallelize, str): + slice.variables.add(parallelize) + slice.constraints.add(f"{parallelize} in {{0, 1}}") + + def _interpret_pack( + self, slice: LoopNestSliceExtend, loop_name: str, item: PackDecl + ): + param, input, pad = item.param, item.input, item.pad + if isinstance(param, str): + slice.variables.add(param) + slice.constraints.add(f"{param} in {{0, 1}}") + if isinstance(pad, str): + slice.variables.add(pad) + slice.constraints.add(f"{pad} in {{0, 1}}") + if loop_name in slice.packs: + slice.packs[loop_name].append((param, input, pad)) + else: + slice.packs[loop_name] = [(param, input, pad)] + + def _interpret_buffer( + self, slice: LoopNestSliceExtend, loop_name: str, item: BufferDecl + ): + param, pad = item.param, item.pad + if isinstance(param, str): + slice.variables.add(param) + slice.constraints.add(f"{param} in {{0, 1}}") + if isinstance(pad, str): + slice.variables.add(pad) + slice.constraints.add(f"{pad} in {{0, 1}}") + slice.buffers[loop_name].append((param, pad)) -@dataclass(frozen=True) +@dataclass(frozen=False) class DescriptExtend(Descript): abstract_axis_sizes: dict[str, int] abstract_matrix: list[str] partial_tiles: bool = False partial_unrolls: bool = False + 
_loop_nest: None | LoopNestExtend = None @override def apply( self, node_name: str, - spec: dict[str, dict] | str, + spec: str | dict[str, dict[str, Any]], scheduler: Scheduler, sample: dict[str, Any] = {}, - ): + ) -> None: + """Parse, interpret, validate, and apply a schedule specification. + + Args: + node_name: The name of the root node to schedule. + spec: The schedule specification as a nested dict. + Raises: + ScheduleParseError: If the spec cannot be parsed. + ScheduleInterpretError: If the spec cannot be interpreted. + ScheduleValidationError: If the resulting schedule is invalid. + """ + if isinstance(spec, str): - dict_spec = self.parse_yaml(spec) - else: - dict_spec = spec - flat_schedules = self._flatten_schedule(root=node_name, spec=dict_spec, head=[]) - variables = set() - constraints = set() - for schedule in flat_schedules.slices: - if isinstance(schedule, LoopNestSliceExtend): - variables.update(schedule.variables) - constraints.update(schedule.constraints) + spec = self.parse_yaml(spec) + + # Parse the specification into an AST + parser = ScheduleParserExtend(self.abstract_axis, self.abstract_matrix) + ast = parser.parse(spec) + + # Interpret the AST into a LoopNest + interpreter = ScheduleInterpreterExtend( + self.abstract_axis, self.abstract_axis_sizes, self.abstract_matrix + ) + loop_nest = interpreter.interpret(ast, root=node_name) + + if sample != {}: + loop_nest.apply_sample(sample) + + # Validate the loop nest + loop_nest.check() + for slice in loop_nest.slices: + assert isinstance(slice, LoopNestSliceExtend) + + # Apply the schedule to the scheduler + self._apply_loop_nest(loop_nest, scheduler) + + def loop_nest( + self, node_name: str, spec: str | dict[str, dict[str, Any]] + ) -> LoopNestExtend: + if self._loop_nest: + return self._loop_nest + + if isinstance(spec, str): + spec = self.parse_yaml(spec) + + parser = ScheduleParserExtend(self.abstract_axis, self.abstract_matrix) + ast = parser.parse(spec) + + # Interpret the AST into a 
LoopNest + interpreter = ScheduleInterpreterExtend( + abstract_axis=self.abstract_axis, + abstract_axis_sizes=self.abstract_axis_sizes, + abstract_matrix=self.abstract_matrix, + partial_tiles=self.partial_tiles, + partial_unrolls=self.partial_unrolls, + ) + self._loop_nest = interpreter.interpret(ast, root=node_name) + return self._loop_nest - flat_schedules = self.apply_sample(flat_schedules, sample) - self.apply_scheduler(flat_schedules, scheduler) + def apply_sample( + self, loop_nest: LoopNestExtend, scheduler: Scheduler, sample: dict[str, Any] + ): + loop_nest = deepcopy(loop_nest) + if sample != {}: + loop_nest.apply_sample(sample) + + # Validate the loop nest + loop_nest.check() + + # Apply the schedule to the scheduler + self._apply_loop_nest(loop_nest, scheduler) def parse_yaml(self, spec: str) -> dict[str, dict]: dspec = strictyaml.load(spec).data @@ -161,41 +683,26 @@ def parse_yaml(self, spec: str) -> dict[str, dict]: def _parse_yaml(self, spec: dict[str, dict]) -> dict[str, dict]: out_dict = {} - for level, d_level in spec.items(): - level_dict = {} - if not isinstance(d_level, dict): - continue - for a, v in d_level.items(): - if a == "explore": - assert isinstance(v, str) - if v == "": - tmp = None - else: - try: - tmp = eval(v) - except NameError: - tmp = v - level_dict["explore_axis_order"] = tmp - elif a in self.abstract_matrix: - assert isinstance(v, str) - level_dict[a] = self._split_yaml(v) + for a, v in spec.items(): + if a in self.abstract_matrix: + assert isinstance(v, str) + out_dict[a] = self._split_yaml(v) + else: + if isinstance(v, str): + d = self._split_yaml(v) else: - if isinstance(v, str): - d = self._split_yaml(v) - else: - assert isinstance(v, dict) - d = v - size = d.get("size", None) - if size: - d.pop("size") - a = f"{a}#{size}" - if ":" in a: - level_dict[a] = self._parse_yaml(d) - continue - level_dict[a] = {} - for axis_arg, arg_val in d.items(): - level_dict[a][axis_arg] = arg_val - out_dict[level] = level_dict + assert 
isinstance(v, dict) + d = v + size = d.get("size", None) + if size: + d.pop("size") + a = f"{a}#{size}" + if ":" in a: + out_dict[a] = self._parse_yaml(d) + continue + out_dict[a] = {} + for axis_arg, arg_val in d.items(): + out_dict[a][axis_arg] = arg_val return out_dict def _split_yaml(self, s: str) -> dict[str, Any]: @@ -211,488 +718,3 @@ def _split_yaml(self, s: str) -> dict[str, Any]: tmp = y d[x] = tmp return d - - def flatten_schedule(self, node_name: str, spec: dict[str, dict] | str): - if isinstance(spec, str): - dict_spec = self.parse_yaml(spec) - else: - dict_spec = spec - flat_schedules = self._flatten_schedule(root=node_name, spec=dict_spec, head=[]) - variables = [] - constraints = [] - axes = {} - orders = {} - for schedule in flat_schedules.slices: - if isinstance(schedule, LoopNestSliceExtend): - variables += schedule.variables - constraints += schedule.constraints - for axis, order in schedule.axes.items(): - axes[f"order_{axis}"] = order - axis_orders = schedule.axis_orders - for axis in axis_orders: - orders[axis] = schedule.axes[axis] - - variables = list(dict.fromkeys(variables)) - constraints = list(dict.fromkeys(constraints)) - return (flat_schedules, variables, constraints, axes, orders) - - def apply_sample( - self, flat_schedules: LoopNestExtend, sample: dict[str, Any] - ) -> LoopNestExtend: - flat_schedules = deepcopy(flat_schedules) - flat_schedules.apply_sample(sample) - return flat_schedules - - def apply_scheduler(self, flat_schedules: LoopNestExtend, scheduler: Scheduler): - flat_schedules.check() - for schedule in flat_schedules.slices: - assert isinstance(schedule, LoopNestSliceExtend) - root = schedule.root - interchange = [] - - for d, s in schedule.axes.items(): - s = list(s.values()) - for s in s: - interchange += s - - p = schedule.packs.get(d, None) - if p: - for _, input, pad in p: - scheduler.pack_at(s[-1], input, pad=pad) - - b = schedule.buffers.get(d, None) - if b: - scheduler.buffer_at(s[-1]) - - for d, s in 
schedule.splits.items(): - s = correct_type(s) - scheduler.split(d, s, root=root) - - for d, s in schedule.tiles.items(): - s = correct_type(s) - scheduler.tile(d, s, root=root) - - scheduler.interchange(interchange, root=root) - scheduler.vectorize(schedule.vectorize, root=root) - scheduler.parallelize(schedule.parallelize, root=root) - s = correct_type(schedule.unroll) - scheduler.unroll(s, root=root) - - @override - def _flatten_schedule( - self, - root: str, - spec: dict[str, dict], - head: list[str], - tile_sizes: dict[str, int | str] | None = None, - sched_sizes: dict[str, list] | None = None, - ) -> LoopNestExtend: - recursive_scheds = LoopNestExtend(abstract_dims=self.abstract_axis) - sched = recursive_scheds.build_slice(root) - # State of the schedule - if tile_sizes: - axes_sizes: dict[str, int | str] = tile_sizes - else: - axes_sizes = {a: v for a, v in self.abstract_axis_sizes.items()} - if sched_sizes is None: - sched_sizes = {} - for a, v in axes_sizes.items(): - sched_sizes[a] = [str(v)] - sizes: dict[str, int | str | None] = {} - previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} - interchange: list[str] = head - # Processing the schedule - for tree_declaration, tree_val in spec.items(): - assert isinstance(tree_val, dict) - tree_interchange = {} - tree_packs = [] - tree_fusion = [] - tree_buff = [] - last_split = None - for declaration, val in tree_val.items(): - if declaration == "fusion": - tree_fusion.append(val) - continue - elif declaration == "pack": - for val_ in val: - if len(val_) != 3: - raise Exception(f"Packing {val_} should have 3 parameters.") - param, input, pad = val_ - tree_packs.append((param, input, pad)) - if isinstance(param, str): - sched.variables.add(param) - sched.constraints.add(f"{param} in {{0, 1}}") - if isinstance(input, str): - input = self.abstract_matrix.index(input) - if isinstance(pad, str): - sched.variables.add(pad) - sched.constraints.add(f"{pad} in {{0, 1}}") - continue - elif 
declaration in "buffer": - for val_ in val: - if len(val_) != 2: - raise Exception( - f"Bufferisation {val_} should have 2 parameters." - ) - param, pad = val_ - tree_buff.append((param, pad)) - if isinstance(param, str): - sched.variables.add(param) - sched.constraints.add(f"{param} in {{0, 1}}") - if isinstance(pad, str): - sched.variables.add(pad) - sched.constraints.add(f"{pad} in {{0, 1}}") - continue - elif declaration == "explore_axis_order": - sched.axis_orders.append(tree_declaration) - continue - elif declaration in self.abstract_matrix: - matrix_index = self.abstract_matrix.index(declaration) - param = val.get("bufferize", False) - pad = val.get("pad", False) - if param is None or param: - if matrix_index == len(self.abstract_matrix) - 1: - tree_buff.append((param, pad)) - else: - tree_packs.append((param, matrix_index, pad)) - if isinstance(param, str): - sched.variables.add(param) - sched.constraints.add(f"{param} in {{0, 1}}") - if isinstance(pad, str): - sched.variables.add(pad) - sched.constraints.add(f"{pad} in {{0, 1}}") - continue - elif ":" in declaration: - axis_name, x, y, z = self.parse_split_declaration(declaration) - self._check_axis_existence(axis_name) - - # The only declaration where y (the cut) is None is the - # last one, so it cannot be the previous one. 
- cut = previous_cut[axis_name] - - current_size = axes_sizes[axis_name] - # Update the previous cut - # Save the cutting points of the new dimensions - if axis_name not in sched.splits: - sched.splits[axis_name] = {} - new_dim_index = len(sched.splits[axis_name]) - new_dim_name = f"{axis_name}[{new_dim_index}]" - new_axes_root_name = f"{root}/{new_dim_name}" - if axis_name in tree_interchange: - tree_interchange[axis_name].append(new_dim_name) - else: - tree_interchange[axis_name] = [new_dim_name] - - if z is None: - previous_cut[axis_name] = y - # When x (the starting point of the slice), is not - # specified, it is the previous cut - if x is None: - x = cut - assert isinstance(x, int | str) - sched.splits[axis_name][new_dim_name] = x - - # assert isinstance(x, int) - inner_size = self._extended_check_splitting_intervals( - declaration, axis_name, cut, x, y - ) - inner_size = None - if y is None: - y = current_size - if isinstance(x, int): - if x == 0: - inner_size = y - elif isinstance(y, int): - inner_size = y - x - - if inner_size is None: - inner_size = root[1:] + new_dim_name - inner_size = ( - inner_size.replace("/", "") - .replace("[", "_") - .replace("]", "_") - ) - sched.constraints.add(f"{inner_size} <= {y}") - if isinstance(x, str): - sched.constraints.add(f"{x} <= {y}") - sched.constraints.add(f"{inner_size} + {x} == {y}") - else: - inner_size = z - x = cut - y = current_size - if isinstance(z, int) and isinstance(x, int): - previous_cut[axis_name] = x + z - if not isinstance(y, int): - sched.constraints.add(f"{z + x} <= {y}") - elif isinstance(x, int) and x == 0: - previous_cut[axis_name] = z - if not isinstance(y, int): - sched.constraints.add(f"{z} <= {y}") - else: - new_cut = root[1:] + new_dim_name - new_cut = ( - new_cut.replace("/", "") - .replace("[", "_") - .replace("]", "_") - ) - previous_cut[axis_name] = new_cut - if last_split is not None: - a, b = last_split - sched.constraints.add(f"{a} <= {b}") - last_split = (new_cut, y) - 
sched.constraints.add(f"{z} + {x} == {new_cut}") - - axes_sizes[axis_name] = inner_size - - # Fetch the schedule associated with the new dimension - next_schedule = val - assert isinstance(next_schedule, dict) - inner_scheds = self._flatten_schedule( - spec=next_schedule, - root=new_axes_root_name, - tile_sizes=axes_sizes.copy(), - head=[axis_name], - sched_sizes=deepcopy(sched_sizes), - ) - axes_sizes[axis_name] = current_size - - recursive_scheds.slices += inner_scheds.slices - continue - - elif "#" in declaration: - axis_name, tile_size = declaration.split("#") - self._check_axis_existence(axis_name) - assert isinstance(tile_size, str) - if tile_size.isdecimal(): - loop_size = int(tile_size) - else: - loop_size = tile_size - sched.variables.add(tile_size) - if not loop_size: - raise Exception( - f"Invalid tile size: '{tile_size}' in {declaration}" - ) - - if isinstance(loop_size, str): - partial = "partial" in val - full = "full" in val - if partial and full: - raise Exception( - f"Tile {declaration} cannot be partial and full" - ) - if partial or (not full and self.partial_tiles): - sched.constraints.add( - f"{loop_size} <= {axes_sizes[axis_name]}" - ) - else: - s = ( - ", ".join(sched_sizes[axis_name]) - if len(sched_sizes[axis_name]) > 1 - else sched_sizes[axis_name][0] - ) - s = f"{loop_size} || {{{s}}}" - sched.constraints.add(s) - sched_sizes[axis_name].insert(0, str(loop_size)) - axes_sizes[axis_name] = loop_size - tile_num = len(sched.tiles[axis_name]) - loop_name = f"{axis_name}{tile_num}" - sched.tiles[axis_name][loop_name] = loop_size - sizes[loop_name] = loop_size - if axis_name in tree_interchange: - raise Exception( - f"axis {axis_name} already is used in level {tree_declaration}." - ) - tree_interchange[axis_name] = [loop_name] - elif declaration in self.abstract_axis: - loop_name = declaration - axis_name = loop_name - if loop_name in tree_interchange: - raise Exception( - f""" - Axis {declaration} is scheduled twice (or more). 
- """ - ) - tree_interchange[loop_name] = [loop_name] - else: - raise Exception( - f""" - Axis {declaration} is not a defined axis. - Known axis are: {self.abstract_axis}") - """ - ) - - self.annotate( - loop_name=loop_name, - sizes=sizes, - annotations=val, - sched=sched, - ) - sched.axes[tree_declaration] = tree_interchange - if len(tree_packs) > 0: - sched.packs[tree_declaration] = tree_packs - if len(tree_fusion) > 0: - sched.fusions[tree_declaration] = tree_fusion - if len(tree_buff) > 0: - sched.buffers[tree_declaration] = tree_buff - for v in tree_interchange.values(): - interchange += v - - if last_split is not None: - a, b = last_split - if isinstance(a, int) and not isinstance(b, int): - a, b = b, a - a, b = str(a), str(b) - for c in sched.constraints: - sched.constraints.remove(c) - sched.constraints.add(c.replace(a, b)) - last_split = None - - # Check if the last cut of each axis is either 0 or None. - # None correspond to "until the end of the loop". 0 is the - # default value, if it has 0 then it means the axis isn't splitted. - # Any other value means the split is let in a partial state. - for axis, cut in previous_cut.items(): - if cut is not None and isinstance(cut, int) and cut != 0: - raise Exception( - f"Splitting on axis {axis} should end but stops at {cut}" - ) - - sched.interchange = interchange - return recursive_scheds - - def _extended_check_splitting_intervals( - self, - declaration: str, - axis_name: str, - cut: int | str | None, - x: int | str | None, - y: int | str | None, - ) -> int | str | None: - if cut is None: - raise Exception( - f""" - {declaration} is defined on an already covered axis. 
- This might be caused by a missing endpoint: {axis_name} - """ - ) - - assert isinstance(x, int | str) - - if isinstance(cut, int) and isinstance(x, int): - if x > cut: - raise Exception( - f""" - Splitting doesn't cover the whole axis - (jumps from {cut} to {x} on axis {axis_name}) - """ - ) - elif x < cut: - raise Exception( - f""" - Splitting are overlapping on axis {axis_name} - (covered until {cut} but restart at {x}) - """ - ) - else: - if x != cut: - raise Exception( - f""" - Splitting should use the same variables between an end and a start - ({cut} and {x} on axis {axis_name}) - """ - ) - assert x == cut - if y is None: - return None - - if isinstance(x, int): - if isinstance(y, int): - if x >= y: - raise Exception( - f""" - Starting point in the splitting cannot be greater or equal to - the ending point in: {declaration} - """ - ) - else: - return y - x - if x == 0: - return y - return None - - def annotate( - self, - loop_name: str, - sizes: dict[str, int | str | None], - annotations: dict[str, Any], - sched: LoopNestSliceExtend, - ): - for instr, param in annotations.items(): - assert isinstance(instr, str) - match instr: - case "unroll": - if param is None and loop_name in sizes: - ufactor = sizes[loop_name] - else: - ufactor = param - if isinstance(param, str): - sched.variables.add(param) - leq = "<=" if self.partial_unrolls else "||" - sched.constraints.add(f"{ufactor} {leq} {sizes[loop_name]}") - assert isinstance(ufactor, int | str) - sched.unroll[loop_name] = ufactor - - case "vectorize": - if isinstance(param, str): - sched.variables.add(param) - sched.constraints.add(f"{param} in {{0, 1}}") - sched.vectorize_bool.add((param, loop_name)) - continue - if param is None: - sched.vectorize.append(loop_name) - continue - raise Exception( - "Vectorize should not have a parameter (Feature not implemented)" - ) - - case "parallelize": - if isinstance(param, str): - sched.variables.add(param) - sched.constraints.add(f"{param} in {{0, 1}}") - 
sched.parallelize_bool.add((param, loop_name)) - continue - if param is None: - sched.parallelize.append(loop_name) - continue - if param is not None: - raise Exception( - "Parallelize should not have a parameter (Feature not implemented)" - ) - case "partial": - continue - case "full": - continue - case _: - raise Exception(f"Unknown annotation on {loop_name}: {instr}") - - def parse_split_declaration( - self, - declaration: str, - ) -> Tuple[str, int | str | None, int | str | None, int | str | None]: - pattern = r"^(.*)\[(?:(-\w+|\w*)?):(?:(-\w+|\w*)?)\]$" - match = re.match(pattern, declaration) - if not match: - pattern = r"^(.*)\[:(\w*):\]$" - match = re.match(pattern, declaration) - if not match: - raise Exception(f"Wrong format {declaration}") - prefix, z = match.groups() - z = int(z) if z.isnumeric() else z - return prefix, None, None, z - - prefix, x_str, y_str = match.groups() - x = int(x_str) if x_str.isnumeric() else x_str - y = int(y_str) if y_str.isnumeric() else y_str - x = x if x else None - y = y if y else None - return prefix, x, y, None diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index bddb1f53a..73b5cc638 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -20,7 +20,7 @@ from xtc.itf.schd import Scheduler from xtc.itf.schd.scheduler import DEFAULT_ROOT from xtc.itf.search import Sample, Strategy -from xtc.schedules.descript_extend import DescriptExtend +from xtc.schedules.descript_extend import DescriptExtend, LoopNestSliceExtend from xtc.utils.math import ( factors_to_sizes, factors_enumeration, @@ -953,7 +953,7 @@ class Strategy_Descript(Strategy): def __init__( self, graph: Graph, - spec: dict[str, dict] | str, + spec: dict[str, dict[str, Any]] | str, constraints: list[str] = [], partial_tiles: bool = False, partial_unrolls: bool = False, @@ -964,6 +964,7 @@ def __init__( self._stats: dict[str, int] = {} self._axes = list(self._op.dims) self._sizes = self._constant_sizes() + 
self._sample_names: list[str] = [] descript = DescriptExtend( abstract_axis=self._axes, abstract_axis_sizes=dict(self._sizes), @@ -973,26 +974,19 @@ def __init__( ) self._descript = descript self._initialized = False - input_constraints = constraints - self._flat_schedules, self._sample_names, constraints, axes, orders = ( - descript.flatten_schedule(node_name=DEFAULT_ROOT, spec=spec) - ) + loop_nest = descript.loop_nest(node_name=DEFAULT_ROOT, spec=spec) + self._loop_nest = loop_nest + input_constraints: list[str] = [] + for slice in loop_nest.slices: + assert isinstance(slice, LoopNestSliceExtend) + input_constraints += slice.constraints + self._sample_names += slice.variables for a, v in self._sizes.items(): for i, s in enumerate(constraints): assert isinstance(s, str) constraints[i] = s.replace(f"[{a}]", str(v)) - self._axes_names = {} - for a, v in axes.items(): - self._axes_names[a] = v self._orders: dict[str, list] = {} - order_constraints: list[str] = [] - for a, v in orders.items(): - assert isinstance(v, dict) - permutation = list(itertools.permutations(v)) - a_holder = f"order_{a}" - self._orders[a_holder] = permutation - order_constraints.append(f"{a_holder} in {set(range(len(permutation)))}") - self._constraints = constraints + input_constraints + order_constraints + self._constraints = constraints + input_constraints self._constraints.sort() if initialize: self._initialize() @@ -1025,15 +1019,13 @@ def graph(self) -> Graph: @override def generate(self, scheduler: Scheduler, sample: Sample) -> None: descript = self._descript - flat_schedules = self._flat_schedules - for a, p in self._orders.items(): - if a in sample: - if isinstance(sample[a], int): - sample[a] = p[sample[a]] - flat_schedules = descript.apply_sample( - flat_schedules=flat_schedules, sample=sample + # for a, p in self._orders.items(): + # if a in sample: + # if isinstance(sample[a], int): + # sample[a] = p[sample[a]] + descript.apply_sample( + loop_nest=self._loop_nest, 
scheduler=scheduler, sample=sample ) - descript.apply_scheduler(flat_schedules, scheduler) @override def sample(self, num: int, seed: int | None = 0) -> Iterator[Sample]: diff --git a/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir b/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir index cb29c2ca2..67bc2f16d 100644 --- a/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir +++ b/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir @@ -4,34 +4,35 @@ func.func @matmul(%A: memref<256x512xf64>, %B: memref<512x256xf64>, %C: memref<2 linalg.matmul { loop.dims = ["i", "j"], loop.schedule = { - "One" = { - "i[:5]" = { "Two" = {"j"} }, - "i[5:]" = { "Two" = {"j"} }, - "fusion" + "i[:5]" = { "j" }, + "i[5:]" = { "j" } } - } } ins(%A, %B : memref<256x512xf64>, memref<512x256xf64>) outs(%C: memref<256x256xf64>) return -}// CHECK: // -----// IR Dump Before transform //----- // -// CHECK-NEXT: module attributes {transform.with_named_sequence} { -// CHECK-NEXT: func.func @matmul(%arg0: memref<256x512xf64> {llvm.noalias}, %arg1: memref<512x256xf64> {llvm.noalias}, %arg2: memref<256x256xf64> {llvm.noalias}) { -// CHECK-NEXT: linalg.matmul {__node0__} ins(%arg0, %arg1 : memref<256x512xf64>, memref<512x256xf64>) outs(%arg2 : memref<256x256xf64>) -// CHECK-NEXT: return -// CHECK-NEXT: } -// CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { -// CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op -// CHECK-NEXT: transform.yield -// CHECK-NEXT: } -// CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { -// CHECK-NEXT: %0 = transform.structured.match attributes {__node0__} in %arg0 : (!transform.any_op) -> !transform.any_op -// CHECK-NEXT: %1 = transform.structured.split %0 after 5 {dimension = 0 : i64} : !transform.any_op -// CHECK-NEXT: %2:2 = transform.split_handle %1 : 
(!transform.any_op) -> (!transform.any_op, !transform.any_op) -// CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %2#0 tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -// CHECK-NEXT: transform.annotate %loops "__node0__/i[0]/j" : !transform.any_op -// CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %2#1 tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -// CHECK-NEXT: transform.annotate %loops_1 "__node0__/i[1]/j" : !transform.any_op -// CHECK-NEXT: transform.yield -// CHECK-NEXT: } +} +// CHECK: module attributes {transform.with_named_sequence} { +// CHECK-NEXT: func.func @matmul(%arg0: memref<256x512xf64> {llvm.noalias}, %arg1: memref<512x256xf64> {llvm.noalias}, %arg2: memref<256x256xf64> {llvm.noalias}) { +// CHECK-NEXT: linalg.matmul {__node0__} ins(%arg0, %arg1 : memref<256x512xf64>, memref<512x256xf64>) outs(%arg2 : memref<256x256xf64>) +// CHECK-NEXT: return // CHECK-NEXT: } +// CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +// CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +// CHECK-NEXT: transform.yield +// CHECK-NEXT: } +// CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +// CHECK-NEXT: %0 = transform.structured.match attributes {__node0__} in %arg0 : (!transform.any_op) -> !transform.any_op +// CHECK-NEXT: %1 = transform.structured.split %0 after 5 {dimension = 0 : i64} : !transform.any_op +// CHECK-NEXT: %2:2 = transform.split_handle %1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %2#0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: transform.annotate %loops "__node0__/i[0]/i" : !transform.any_op +// CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = 
transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: transform.annotate %loops_1 "__node0__/i[0]/j" : !transform.any_op +// CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2#1 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: transform.annotate %loops_3 "__node0__/i[1]/i" : !transform.any_op +// CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: transform.annotate %loops_5 "__node0__/i[1]/j" : !transform.any_op +// CHECK-NEXT: transform.yield +// CHECK-NEXT: } +// CHECK-NEXT:} diff --git a/tests/filecheck/mlir_loop/descript_syntax/tiling/i_invalide_argument.mlir b/tests/filecheck/mlir_loop/descript_syntax/tiling/i_invalide_argument.mlir index dff01ca7d..b2171be06 100644 --- a/tests/filecheck/mlir_loop/descript_syntax/tiling/i_invalide_argument.mlir +++ b/tests/filecheck/mlir_loop/descript_syntax/tiling/i_invalide_argument.mlir @@ -15,4 +15,5 @@ func.func @matmul(%A: memref<256x512xf64>, %B: memref<512x256xf64>, %C: memref<2 outs(%C: memref<256x256xf64>) return } -// CHECK: `j#a`: a is not an integer. +// CHECK: xtc.schedules.descript.ScheduleInterpretError: `j#a`: a is not an integer. 
+ diff --git a/tests/filecheck/mlir_loop/descript_syntax/tiling/i_one_axis_positive_negative_tiling.mlir b/tests/filecheck/mlir_loop/descript_syntax/tiling/i_one_axis_positive_negative_tiling.mlir index 4c6eae3bd..ab21ab0ec 100644 --- a/tests/filecheck/mlir_loop/descript_syntax/tiling/i_one_axis_positive_negative_tiling.mlir +++ b/tests/filecheck/mlir_loop/descript_syntax/tiling/i_one_axis_positive_negative_tiling.mlir @@ -14,4 +14,4 @@ func.func @matmul(%A: memref<256x512xf64>, %B: memref<512x256xf64>, %C: memref<2 outs(%C: memref<256x256xf64>) return } -// CHECK: `k#-1`: tile sizes should be strictly positive. +// CHECK: xtc.schedules.descript.ScheduleInterpretError: `k#-1`: -1 is not an integer. diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py index 3246d7022..a226cff38 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py @@ -23,15 +23,11 @@ node_name="C", abstract_axis=["i", "j", "k"], spec={ - "DDR": { "k": {}, "i": {}, "j": {}, - }, - "R": { "i#i_inner": {"unroll": "i_unroll"}, "j#j_inner": {"vectorize": "j_vectorize"}, - }, }, abstract_axis_sizes=axes_sizes, sample={"i_inner": 2, "j_inner": 16, "i_unroll": None, "j_vectorize": None}, diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py index b1d41aba2..0e21fa5c8 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py @@ -24,24 +24,16 @@ abstract_axis=["i", "j", "k"], abstract_axis_sizes=axes_sizes, spec={ - "DDR": { "j": {}, "k": {}, - }, - "L2": { "j#jDDR": {}, "i[:4]": { - "R": { "i#iR1": {"unroll": None}, "j#jR": {"vectorize": None}, - }, }, "i[4:]": { - "R": { "i#iR2": {"unroll": None}, "j#jR": 
{"vectorize": None}, - }, - }, }, }, sample={"jDDR": 16, "jR": 4, "iR1": 2, "iR2": 4}, @@ -87,14 +79,18 @@ #CHECK-NEXT: transform.annotate %loops_7 "C/j0" : !transform.any_op #CHECK-NEXT: %2 = transform.structured.split %tiled_linalg_op_6 after 4 {dimension = 0 : i64} : !transform.any_op #CHECK-NEXT: %3:2 = transform.split_handle %2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -#CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %3#0 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -#CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/i0" : !transform.any_op -#CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_8) : (!transform.any_op) -> () -#CHECK-NEXT: transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op -#CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %3#1 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) -#CHECK-NEXT: transform.annotate %loops_11 "C/i[1]/i0" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %3#0 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %tiled_linalg_op_8 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_11 "C/i[0]/i0" : !transform.any_op #CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_10) : (!transform.any_op) -> () -#CHECK-NEXT: transform.loop.unroll %loops_11 {factor = 4 : i64} : !transform.any_op +#CHECK-NEXT: transform.loop.unroll %loops_11 {factor = 2 : i64} : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_12, %loops_13 = transform.structured.tile_using_for %3#1 tile_sizes [4, 0, 0] : (!transform.any_op) 
-> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_13 "C/i[1]/i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_14, %loops_15 = transform.structured.tile_using_for %tiled_linalg_op_12 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_15 "C/i[1]/i0" : !transform.any_op +#CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_14) : (!transform.any_op) -> () +#CHECK-NEXT: transform.loop.unroll %loops_15 {factor = 4 : i64} : !transform.any_op #CHECK-NEXT: %4 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op #CHECK-NEXT: transform.apply_patterns to %4 { #CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract @@ -143,87 +139,87 @@ #CHECK-NEXT: %subview_7 = memref.subview %subview_3[0, 0] [4, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<4x1xf32, strided<[512, 1], offset: ?>> #CHECK-NEXT: %subview_8 = memref.subview %subview_6[0, 0] [4, 4] [1, 1] : memref<16x4xf32, strided<[32, 1], offset: ?>> to memref<4x4xf32, strided<[32, 1], offset: ?>> #CHECK-NEXT: scf.for %arg6 = %c0 to %c4 step %c2 { -#CHECK-NEXT: %subview_11 = memref.subview %subview_7[%arg6, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_12 = memref.subview %subview_8[%arg6, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %1 = vector.transfer_read %subview_11[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %subview_11 = memref.subview %subview_7[%arg6, 0] [2, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_12 = memref.subview %subview_8[%arg6, 0] [2, 4] [1, 1] : memref<4x4xf32, 
strided<[32, 1], offset: ?>> to memref<2x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_13 = memref.subview %subview_11[%c0, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_14 = memref.subview %subview_12[%c0, 0] [1, 4] [1, 1] : memref<2x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %1 = vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> #CHECK-NEXT: %2 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %3 = vector.transfer_read %subview_12[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %3 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> #CHECK-NEXT: %4 = vector.extract %2[0] : vector<4xf32> from vector<1x4xf32> #CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> #CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<4xf32> #CHECK-NEXT: %7 = vector.extract %3[0] : vector<4xf32> from vector<1x4xf32> #CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<4xf32> #CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %9, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %10 = arith.addi %arg6, %c1 : index -#CHECK-NEXT: %subview_13 = memref.subview %subview_7[%10, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_14 = memref.subview %subview_8[%10, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, 
strided<[32, 1], offset: ?>> -#CHECK-NEXT: %11 = vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %12 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %13 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %14 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %15 = vector.extract %11[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %16 = vector.broadcast %15 : f32 to vector<4xf32> -#CHECK-NEXT: %17 = vector.extract %13[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %18 = vector.fma %16, %14, %17 : vector<4xf32> -#CHECK-NEXT: %19 = vector.insert %18, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %19, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: } {"C/i[0]/i0"} +#CHECK-NEXT: vector.transfer_write %9, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_15 = memref.subview %subview_11[%c1, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_16 = memref.subview %subview_12[%c1, 0] [1, 4] [1, 1] : memref<2x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = vector.transfer_read %subview_15[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %11 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_16[%c0, 
%c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %13 = vector.extract %11[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<4xf32> +#CHECK-NEXT: %16 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<4xf32> +#CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %18, %subview_16[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: } {"C/i[0]/i"} #CHECK-NEXT: %subview_9 = memref.subview %subview_3[4, 0] [12, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<12x1xf32, strided<[512, 1], offset: ?>> #CHECK-NEXT: %subview_10 = memref.subview %subview_6[4, 0] [12, 4] [1, 1] : memref<16x4xf32, strided<[32, 1], offset: ?>> to memref<12x4xf32, strided<[32, 1], offset: ?>> #CHECK-NEXT: scf.for %arg6 = %c0 to %c12 step %c4 { -#CHECK-NEXT: %subview_11 = memref.subview %subview_9[%arg6, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_12 = memref.subview %subview_10[%arg6, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %1 = vector.transfer_read %subview_11[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %subview_11 = memref.subview %subview_9[%arg6, 0] [4, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<4x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_12 = memref.subview %subview_10[%arg6, 0] [4, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<4x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: 
%subview_13 = memref.subview %subview_11[%c0, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_14 = memref.subview %subview_12[%c0, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %1 = vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> #CHECK-NEXT: %2 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %3 = vector.transfer_read %subview_12[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %3 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> #CHECK-NEXT: %4 = vector.extract %2[0] : vector<4xf32> from vector<1x4xf32> #CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> #CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<4xf32> #CHECK-NEXT: %7 = vector.extract %3[0] : vector<4xf32> from vector<1x4xf32> #CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<4xf32> #CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %9, %subview_12[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %10 = arith.addi %arg6, %c1 : index -#CHECK-NEXT: %subview_13 = memref.subview %subview_9[%10, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_14 = memref.subview %subview_10[%10, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %11 = vector.transfer_read %subview_13[%c0, %c0], %0 
{in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %12 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %13 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %14 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %15 = vector.extract %11[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %16 = vector.broadcast %15 : f32 to vector<4xf32> -#CHECK-NEXT: %17 = vector.extract %13[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %18 = vector.fma %16, %14, %17 : vector<4xf32> -#CHECK-NEXT: %19 = vector.insert %18, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %19, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %20 = arith.addi %arg6, %c2 : index -#CHECK-NEXT: %subview_15 = memref.subview %subview_9[%20, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_16 = memref.subview %subview_10[%20, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %21 = vector.transfer_read %subview_15[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %22 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %23 = vector.transfer_read %subview_16[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %24 = vector.extract %22[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %25 = vector.extract %21[0, 0] : 
f32 from vector<1x1xf32> -#CHECK-NEXT: %26 = vector.broadcast %25 : f32 to vector<4xf32> -#CHECK-NEXT: %27 = vector.extract %23[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %28 = vector.fma %26, %24, %27 : vector<4xf32> -#CHECK-NEXT: %29 = vector.insert %28, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %29, %subview_16[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %30 = arith.addi %arg6, %c3 : index -#CHECK-NEXT: %subview_17 = memref.subview %subview_9[%30, 0] [1, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> -#CHECK-NEXT: %subview_18 = memref.subview %subview_10[%30, 0] [1, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: %31 = vector.transfer_read %subview_17[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> -#CHECK-NEXT: %32 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %33 = vector.transfer_read %subview_18[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> -#CHECK-NEXT: %34 = vector.extract %32[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %35 = vector.extract %31[0, 0] : f32 from vector<1x1xf32> -#CHECK-NEXT: %36 = vector.broadcast %35 : f32 to vector<4xf32> -#CHECK-NEXT: %37 = vector.extract %33[0] : vector<4xf32> from vector<1x4xf32> -#CHECK-NEXT: %38 = vector.fma %36, %34, %37 : vector<4xf32> -#CHECK-NEXT: %39 = vector.insert %38, %cst [0] : vector<4xf32> into vector<1x4xf32> -#CHECK-NEXT: vector.transfer_write %39, %subview_18[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> -#CHECK-NEXT: } {"C/i[1]/i0"} +#CHECK-NEXT: vector.transfer_write 
%9, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_15 = memref.subview %subview_11[%c1, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_16 = memref.subview %subview_12[%c1, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = vector.transfer_read %subview_15[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %11 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_16[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %13 = vector.extract %11[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<4xf32> +#CHECK-NEXT: %16 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<4xf32> +#CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %18, %subview_16[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_17 = memref.subview %subview_11[%c2, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_18 = memref.subview %subview_12[%c2, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %19 = vector.transfer_read %subview_17[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, 
strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %20 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %21 = vector.transfer_read %subview_18[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %22 = vector.extract %20[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %23 = vector.extract %19[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %24 = vector.broadcast %23 : f32 to vector<4xf32> +#CHECK-NEXT: %25 = vector.extract %21[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %26 = vector.fma %24, %22, %25 : vector<4xf32> +#CHECK-NEXT: %27 = vector.insert %26, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %27, %subview_18[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_19 = memref.subview %subview_11[%c3, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_20 = memref.subview %subview_12[%c3, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %28 = vector.transfer_read %subview_19[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %29 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %30 = vector.transfer_read %subview_20[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %31 = vector.extract %29[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %32 = vector.extract %28[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %33 = vector.broadcast %32 : f32 to vector<4xf32> 
+#CHECK-NEXT: %34 = vector.extract %30[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %35 = vector.fma %33, %31, %34 : vector<4xf32> +#CHECK-NEXT: %36 = vector.insert %35, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %36, %subview_20[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: } {"C/i[1]/i"} #CHECK-NEXT: } {"C/j0"} #CHECK-NEXT: } {"C/k"} #CHECK-NEXT: } {"C/j"} diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py index 4e6aedc89..cbe63694c 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py @@ -24,28 +24,17 @@ node_name="C", abstract_axis=["i", "j", "k"], abstract_axis_sizes=axes_sizes, + abstract_matrix=["A", "B", "C"], spec={ - "DDR": { "j": {"parallelize": "par"}, "k": {}, "i": {}, - # "explore_axis_order": True, - "pack": [("pack_B", 1, True), ("pack_A", 0, True)], - }, - # "DDRk": { - # }, - # "DDRi": { - # }, - "L3": { + "B": {"bufferize": "pack_B"}, + "A": {"bufferize": "pack_A"}, "j#jL3": {}, - }, - "L2": { "i#iL2": {}, - }, - "L1": { "k#kL1": {"unroll": "k_unroll"}, - }, - "R": {"i#iR": {"unroll": None}, "j#jR": {"vectorize": None}}, + "i#iR": {"unroll": None}, "j#jR": {"vectorize": None}, }, sample={ "par": None, @@ -100,9 +89,7 @@ #CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) #CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) #CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_1[cse_var_2 + k] * _1_1[k * 512 + j] -#CHECK-NEXT:INPS = list(obj.values())[:-1] #CHECK-NEXT:O = obj['C'] -#CHECK-NEXT:I_R0 = sch.cache_read(INPS[0], "local", [O]) #CHECK-NEXT:i, j, = O.op.axis #CHECK-NEXT:k, = O.op.reduce_axis #CHECK-NEXT:j, j0 = sch[O].split(j, factor=36) @@ -113,8 +100,6 @@ #CHECK-NEXT:j0, j1 = sch[O].split(j0, factor=6) #CHECK-NEXT:j1, __v_j1 = 
sch[O].split(j1, factor=2) #CHECK-NEXT:sch[O].reorder(j, k, i, j0, i0, k0, __u_k0, i1, j1, __v_j1) -#CHECK-NEXT:sch[I_R0].compute_at(sch[O], i) -#CHECK-NEXT:sch[I_R0].storage_align(I_R0.op.axis[-2], factor=1024, offset=16) #CHECK-NEXT:sch[O].unroll(__u_k0) #CHECK-NEXT:sch[O].unroll(i1) #CHECK-NEXT:sch[O].unroll(j1) @@ -130,7 +115,6 @@ #CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): #CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) #CHECK-NEXT: for j_outer in T.parallel(15): -#CHECK-NEXT: _0_local = T.allocate([2048], "float32", "local") #CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) #CHECK-NEXT: for i_outer_init, j_inner_outer_init, i_inner_outer_init in T.grid(4, 6, 64): #CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer_init * 3 // 2 < 128): @@ -145,72 +129,79 @@ #CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 514:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 514 + 2] = T.Broadcast(T.float32(0.0), 2) #CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer_init * 3 // 2 < 127): #CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 516:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 516 + 2] = T.Broadcast(T.float32(0.0), 2) -#CHECK-NEXT: for k_outer, i_outer in T.grid(32, 4): -#CHECK-NEXT: _0_local_1 = T.Buffer((2048,), data=_0_local, scope="local") -#CHECK-NEXT: for ax0, ax1 in T.grid(128, 16): -#CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) -#CHECK-NEXT: _0_local_1[ax0 * 16 + ax1] = _0_1[i_outer * 65536 + ax0 * 512 + k_outer * 16 + ax1] -#CHECK-NEXT: for j_inner_outer, i_inner_outer, k_inner_outer in T.grid(6, 64, 8): -#CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) -#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer 
* 3 // 2 < 128): -#CHECK-NEXT: cse_var_3: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_2: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_1: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_3 + cse_var_2 -#CHECK-NEXT: C_1[cse_var_1:cse_var_1 + 2] = C_1[cse_var_1:cse_var_1 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_3 + cse_var_2:k_outer * 8192 + k_inner_outer * 1024 + cse_var_3 + cse_var_2 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): -#CHECK-NEXT: cse_var_6: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_5: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_4: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_6 + cse_var_5 + 2 -#CHECK-NEXT: C_1[cse_var_4:cse_var_4 + 2] = C_1[cse_var_4:cse_var_4 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_6 + cse_var_5 + 2:k_outer * 8192 + k_inner_outer * 1024 + cse_var_6 + cse_var_5 + 2 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): -#CHECK-NEXT: cse_var_9: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_8: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_7: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_9 + cse_var_8 + 4 -#CHECK-NEXT: C_1[cse_var_7:cse_var_7 + 2] = C_1[cse_var_7:cse_var_7 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_9 + cse_var_8 + 4:k_outer * 8192 + k_inner_outer * 1024 + cse_var_9 + cse_var_8 + 4 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): -#CHECK-NEXT: cse_var_12: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_11: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_10: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_12 + cse_var_11 + 512 -#CHECK-NEXT: C_1[cse_var_10:cse_var_10 + 2] = C_1[cse_var_10:cse_var_10 + 2] + 
T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 16], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11:k_outer * 8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): -#CHECK-NEXT: cse_var_15: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_14: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_13: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_15 + cse_var_14 + 514 -#CHECK-NEXT: C_1[cse_var_13:cse_var_13 + 2] = C_1[cse_var_13:cse_var_13 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 16], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_15 + cse_var_14 + 2:k_outer * 8192 + k_inner_outer * 1024 + cse_var_15 + cse_var_14 + 2 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): -#CHECK-NEXT: cse_var_18: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_17: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_16: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_18 + cse_var_17 + 516 -#CHECK-NEXT: C_1[cse_var_16:cse_var_16 + 2] = C_1[cse_var_16:cse_var_16 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 16], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_18 + cse_var_17 + 4:k_outer * 8192 + k_inner_outer * 1024 + cse_var_18 + cse_var_17 + 4 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): -#CHECK-NEXT: cse_var_21: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_20: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_19: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_21 + cse_var_20 -#CHECK-NEXT: C_1[cse_var_19:cse_var_19 + 2] = C_1[cse_var_19:cse_var_19 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_21 + cse_var_20 + 512:k_outer * 8192 + k_inner_outer * 1024 + cse_var_21 + cse_var_20 + 512 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + 
(j_inner_outer * 3 + 1) // 2 < 128): -#CHECK-NEXT: cse_var_24: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_23: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_22: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_24 + cse_var_23 + 2 -#CHECK-NEXT: C_1[cse_var_22:cse_var_22 + 2] = C_1[cse_var_22:cse_var_22 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_24 + cse_var_23 + 514:k_outer * 8192 + k_inner_outer * 1024 + cse_var_24 + cse_var_23 + 514 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): -#CHECK-NEXT: cse_var_27: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_26: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_25: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_27 + cse_var_26 + 4 -#CHECK-NEXT: C_1[cse_var_25:cse_var_25 + 2] = C_1[cse_var_25:cse_var_25 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_27 + cse_var_26 + 516:k_outer * 8192 + k_inner_outer * 1024 + cse_var_27 + cse_var_26 + 516 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): -#CHECK-NEXT: cse_var_30: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_29: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_28: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_30 + cse_var_29 + 512 -#CHECK-NEXT: C_1[cse_var_28:cse_var_28 + 2] = C_1[cse_var_28:cse_var_28 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 17], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_30 + cse_var_29 + 512:k_outer * 8192 + k_inner_outer * 1024 + cse_var_30 + cse_var_29 + 512 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): -#CHECK-NEXT: cse_var_33: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_32: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_31: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_33 + cse_var_32 + 
514 -#CHECK-NEXT: C_1[cse_var_31:cse_var_31 + 2] = C_1[cse_var_31:cse_var_31 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 17], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_33 + cse_var_32 + 514:k_outer * 8192 + k_inner_outer * 1024 + cse_var_33 + cse_var_32 + 514 + 2] -#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): -#CHECK-NEXT: cse_var_36: T.int32 = j_outer * 36 -#CHECK-NEXT: cse_var_35: T.int32 = j_inner_outer * 6 -#CHECK-NEXT: cse_var_34: T.int32 = i_outer * 65536 + i_inner_outer * 1024 + cse_var_36 + cse_var_35 + 516 -#CHECK-NEXT: C_1[cse_var_34:cse_var_34 + 2] = C_1[cse_var_34:cse_var_34 + 2] + T.Broadcast(_0_local_1[i_inner_outer * 32 + k_inner_outer * 2 + 17], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_36 + cse_var_35 + 516:k_outer * 8192 + k_inner_outer * 1024 + cse_var_36 + cse_var_35 + 516 + 2] +#CHECK-NEXT: for k_outer, i_outer, j_inner_outer, i_inner_outer, k_inner_outer in T.grid(32, 4, 6, 64, 8): +#CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) +#CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_4: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_3: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_2: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_1: T.int32 = cse_var_2 + cse_var_4 + cse_var_3 +#CHECK-NEXT: C_1[cse_var_1:cse_var_1 + 2] = C_1[cse_var_1:cse_var_1 + 2] + T.Broadcast(_0_1[cse_var_2 + k_outer * 16 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_4 + cse_var_3:k_outer * 8192 + k_inner_outer * 1024 + cse_var_4 + cse_var_3 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_8: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_7: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_6: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_5: T.int32 = 
cse_var_6 + cse_var_8 + cse_var_7 + 2 +#CHECK-NEXT: C_1[cse_var_5:cse_var_5 + 2] = C_1[cse_var_5:cse_var_5 + 2] + T.Broadcast(_0_1[cse_var_6 + k_outer * 16 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_8 + cse_var_7 + 2:k_outer * 8192 + k_inner_outer * 1024 + cse_var_8 + cse_var_7 + 2 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_12: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_11: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_10: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_9: T.int32 = cse_var_10 + cse_var_12 + cse_var_11 + 4 +#CHECK-NEXT: C_1[cse_var_9:cse_var_9 + 2] = C_1[cse_var_9:cse_var_9 + 2] + T.Broadcast(_0_1[cse_var_10 + k_outer * 16 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11 + 4:k_outer * 8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11 + 4 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_16: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_15: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_14: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_13: T.int32 = cse_var_14 + cse_var_16 + cse_var_15 + 512 +#CHECK-NEXT: C_1[cse_var_13:cse_var_13 + 2] = C_1[cse_var_13:cse_var_13 + 2] + T.Broadcast(_0_1[cse_var_14 + k_outer * 16 + k_inner_outer * 2 + 512], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_16 + cse_var_15:k_outer * 8192 + k_inner_outer * 1024 + cse_var_16 + cse_var_15 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_20: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_19: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_18: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_17: T.int32 = cse_var_18 + cse_var_20 + cse_var_19 + 514 +#CHECK-NEXT: C_1[cse_var_17:cse_var_17 + 2] = C_1[cse_var_17:cse_var_17 + 2] + 
T.Broadcast(_0_1[cse_var_18 + k_outer * 16 + k_inner_outer * 2 + 512], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_20 + cse_var_19 + 2:k_outer * 8192 + k_inner_outer * 1024 + cse_var_20 + cse_var_19 + 2 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_24: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_23: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_22: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_21: T.int32 = cse_var_22 + cse_var_24 + cse_var_23 + 516 +#CHECK-NEXT: C_1[cse_var_21:cse_var_21 + 2] = C_1[cse_var_21:cse_var_21 + 2] + T.Broadcast(_0_1[cse_var_22 + k_outer * 16 + k_inner_outer * 2 + 512], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_24 + cse_var_23 + 4:k_outer * 8192 + k_inner_outer * 1024 + cse_var_24 + cse_var_23 + 4 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_28: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_27: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_26: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_25: T.int32 = cse_var_26 + cse_var_28 + cse_var_27 +#CHECK-NEXT: C_1[cse_var_25:cse_var_25 + 2] = C_1[cse_var_25:cse_var_25 + 2] + T.Broadcast(_0_1[cse_var_26 + k_outer * 16 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_28 + cse_var_27 + 512:k_outer * 8192 + k_inner_outer * 1024 + cse_var_28 + cse_var_27 + 512 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_32: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_31: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_30: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_29: T.int32 = cse_var_30 + cse_var_32 + cse_var_31 + 2 +#CHECK-NEXT: C_1[cse_var_29:cse_var_29 + 2] = C_1[cse_var_29:cse_var_29 + 2] + T.Broadcast(_0_1[cse_var_30 + k_outer * 16 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer 
* 1024 + cse_var_32 + cse_var_31 + 514:k_outer * 8192 + k_inner_outer * 1024 + cse_var_32 + cse_var_31 + 514 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_36: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_35: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_34: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_33: T.int32 = cse_var_34 + cse_var_36 + cse_var_35 + 4 +#CHECK-NEXT: C_1[cse_var_33:cse_var_33 + 2] = C_1[cse_var_33:cse_var_33 + 2] + T.Broadcast(_0_1[cse_var_34 + k_outer * 16 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_36 + cse_var_35 + 516:k_outer * 8192 + k_inner_outer * 1024 + cse_var_36 + cse_var_35 + 516 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_40: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_39: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_38: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_37: T.int32 = cse_var_38 + cse_var_40 + cse_var_39 + 512 +#CHECK-NEXT: C_1[cse_var_37:cse_var_37 + 2] = C_1[cse_var_37:cse_var_37 + 2] + T.Broadcast(_0_1[cse_var_38 + k_outer * 16 + k_inner_outer * 2 + 513], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_40 + cse_var_39 + 512:k_outer * 8192 + k_inner_outer * 1024 + cse_var_40 + cse_var_39 + 512 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_44: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_43: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_42: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_41: T.int32 = cse_var_42 + cse_var_44 + cse_var_43 + 514 +#CHECK-NEXT: C_1[cse_var_41:cse_var_41 + 2] = C_1[cse_var_41:cse_var_41 + 2] + T.Broadcast(_0_1[cse_var_42 + k_outer * 16 + k_inner_outer * 2 + 513], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_44 + cse_var_43 + 514:k_outer * 8192 + k_inner_outer * 1024 + cse_var_44 + 
cse_var_43 + 514 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_48: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_47: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_46: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_45: T.int32 = cse_var_46 + cse_var_48 + cse_var_47 + 516 +#CHECK-NEXT: C_1[cse_var_45:cse_var_45 + 2] = C_1[cse_var_45:cse_var_45 + 2] + T.Broadcast(_0_1[cse_var_46 + k_outer * 16 + k_inner_outer * 2 + 513], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_48 + cse_var_47 + 516:k_outer * 8192 + k_inner_outer * 1024 + cse_var_48 + cse_var_47 + 516 + 2] #CHECK:CODE: 0 - diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py index ad7dd95ae..0429b0606 100644 --- a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py +++ b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py @@ -21,15 +21,11 @@ sch = impl.get_scheduler() spec = { - "DDR": { "k": {}, "i": {}, "j": {}, - }, - "R": { "i#i_inner": {"unroll": "i_unroll"}, "j#j_inner": {"vectorize": "j_vectorize"}, - }, } sample = {"i_inner": 2, "j_inner": 16, "i_unroll": 2, "j_vectorize": None, "order_R": ["j", "i"]} @@ -80,9 +76,9 @@ #CHECK-NEXT:O = obj['C'] #CHECK-NEXT:i, j, = O.op.axis #CHECK-NEXT:k, = O.op.reduce_axis -#CHECK-NEXT:j, j0 = sch[O].split(j, factor=16) #CHECK-NEXT:i, i0 = sch[O].split(i, factor=2) -#CHECK-NEXT:sch[O].reorder(k, i, j, j0, i0) +#CHECK-NEXT:j, j0 = sch[O].split(j, factor=16) +#CHECK-NEXT:sch[O].reorder(k, i, j, i0, j0) #CHECK-NEXT:sch[O].unroll(i0) #CHECK-NEXT:sch[O].vectorize(j0) #CHECK-EMPTY: diff --git a/tests/filecheck/search/test_matmul_descript_3axes.py b/tests/filecheck/search/test_matmul_descript_3axes.py index 0b1049662..b50d5b371 100644 --- a/tests/filecheck/search/test_matmul_descript_3axes.py +++ b/tests/filecheck/search/test_matmul_descript_3axes.py @@ 
-9,21 +9,15 @@ graph = utils.get_graph_matmul() backend = utils.get_backend(graph, backend="tvm") spec = { - "DDR": { "j": {}, "k": {}, "i": {}, - "explore_axis_order": None, - }, - "R": { "j#jR": {}, "k#kR": {}, "i#iR": {}, - "explore_axis_order": None, - }, } strategy = Strategy(graph, spec, initialize=False) print(strategy._constraints) -# CHECK: ['iR || {21}', 'jR || {32}', 'kR || {12}', 'order_DDR in {0, 1, 2, 3, 4, 5}', 'order_R in {0, 1, 2, 3, 4, 5}'] +# CHECK:['iR || {21}', 'jR || {32}', 'kR || {12}'] diff --git a/tests/filecheck/search/test_matmul_descript_goto.py b/tests/filecheck/search/test_matmul_descript_goto.py index ba6f16bf5..0c45e76c6 100644 --- a/tests/filecheck/search/test_matmul_descript_goto.py +++ b/tests/filecheck/search/test_matmul_descript_goto.py @@ -9,29 +9,19 @@ graph = utils.get_graph_matmul() backend = utils.get_backend(graph) spec = { - "DDRj": { "j": {"parallelize": "j_parallel"}, - }, - "DDR": { "k": {}, "i": {}, - "explore_axis_order": None, - "pack": [("pack_B", 1, True), ("pack_A", 0, True)], - }, - "L3": { + "pack": ("pack_B", 1, True), + "pack": ("pack_A", 0, True), "j#jL3": {}, - }, - "L2": { "i#iL2": {}, - }, - "L1": { "k#kL1": {"unroll": "k_unroll"}, - }, - "R": {"i#iR": {"unroll": None}, "j#jR": {"vectorize": "j_vectorise"}}, + "i#iR": {"unroll": None}, "j#jR": {"vectorize": "j_vectorise"} } constraint = ["iR * jR <= 56"] strategy = Strategy(graph, spec, constraints=constraint, initialize=False) print(strategy._constraints) -# CHECK: ['iL2 || {21}', 'iR * jR <= 56', 'iR || {iL2, 21}', 'jL3 || {32}', 'jR || {jL3, 32}', 'j_parallel in {0, 1}', 'j_vectorise in {0, 1}', 'kL1 || {12}', 'k_unroll || kL1', 'order_DDR in {0, 1}', 'pack_A in {0, 1}', 'pack_B in {0, 1}'] +# CHECK: ['iL2 || {21}', 'iR * jR <= 56', 'iR || {21, iL2}', 'jL3 || {32}', 'jR || {32, jL3}', 'j_parallel in {0, 1}', 'j_vectorise in {0, 1}', 'kL1 || {12}', 'k_unroll || kL1', 'pack_A in {0, 1}'] diff --git a/tests/filecheck/search/test_matmul_descript_simple.py 
b/tests/filecheck/search/test_matmul_descript_simple.py index 039f8d40a..5afcb94b4 100644 --- a/tests/filecheck/search/test_matmul_descript_simple.py +++ b/tests/filecheck/search/test_matmul_descript_simple.py @@ -5,23 +5,20 @@ import utils from xtc.search.strategies import Strategy_Descript as Strategy - +from xtc.schedules.descript_extend import DescriptExtend graph = utils.get_graph_matmul() backend = utils.get_backend(graph) spec = { - "L3": { - "k": {}, - "i": {}, - "j": {}, - }, - "L2": { - "i#i1": {}, - "j#j1": {}, - }, - "L1": {"j#j2": {}}, + "k": {}, + "i": {}, + "j": {}, + "i#i1": {}, + "j#j1": {}, + "j#j2": {} } + strategy = Strategy(graph, spec, initialize=False) print(strategy._constraints) -# CHECK: ['i1 || {21}', 'j1 || {32}', 'j2 || {j1, 32}'] +# CHECK: ['i1 || {21}', 'j1 || {32}', 'j2 || {32, j1}'] diff --git a/tests/filecheck/search/test_matmul_descript_split.py b/tests/filecheck/search/test_matmul_descript_split.py index 735583ef7..881111ff3 100644 --- a/tests/filecheck/search/test_matmul_descript_split.py +++ b/tests/filecheck/search/test_matmul_descript_split.py @@ -9,33 +9,23 @@ graph = utils.get_graph_matmul() backend = utils.get_backend(graph) spec = { - "DDR": { "j": {}, "k": {}, "i": {}, - }, - "L3": {"i#iL3": {}}, - "L2": { + "i#iL3": {}, "i#7": {}, - }, - "L1": { "j#jDDR": {}, "i[:5]": { - "R1": { - "i#iR1": {"unroll": None}, - "j#jR1": {"vectorize": None}, - }, + "i#iR1": {"unroll": None}, + "j#jR1": {"parallelize": None}, }, "i[5:]": { - "R2": { - "i#iR2": {"unroll": None}, - "j#jR2": {"vectorize": None}, - }, + "i#iR2": {"unroll": None}, + "j#jR2": {"parallelize": None}, }, - }, } strategy = Strategy(graph, spec, initialize=False) print(strategy._constraints) -# CHECK: ['iL3 || {21}', 'iR1 || {7, iL3, 21}', 'iR2 || {7, iL3, 21}', 'jDDR || {32}', 'jR1 || {jDDR, 32}', 'jR2 || {jDDR, 32}'] +# CHECK: ['iL3 || {21}', 'iR1 || {21, iL3, 7}', 'iR2 || {21, iL3, 7}', 'jDDR || {32}', 'jR1 || {32, jDDR}', 'jR2 || {32, jDDR}'] diff --git 
a/tests/filecheck/search/test_matmul_descript_split_in_split.py b/tests/filecheck/search/test_matmul_descript_split_in_split.py index 1e4942c4f..40569f97c 100644 --- a/tests/filecheck/search/test_matmul_descript_split_in_split.py +++ b/tests/filecheck/search/test_matmul_descript_split_in_split.py @@ -9,38 +9,26 @@ graph = utils.get_graph_matmul() backend = utils.get_backend(graph) spec = { - "DDR": { "j": {}, "k": {}, "i": {}, - }, - "L3": { "i#iL2": {}, - }, - "L2": { "j#jDDR": {}, "i[:6]": { - "L1": {"i#3": {}}, - "R": { - "i[:2:]": { - "RR": { - "i#iR1": {"unroll": None}, - "j#jR1": {"vectorize": None}, - } - }, - "i[:iS:]": {"RR": {"i#iR3": {}, "j#jR3": {}}}, + "i#3": {}, + "i[:2:]": { + "i#iR1": {"unroll": None}, + "j#jR1": {"vectorize": None}, }, + "i[:iS:]": {"i#iR3": {}, "j#jR3": {}}, }, "i[6:]": { - "R": { - "i#iR2": {"unroll": None}, - "j#jR2": {"vectorize": None}, - }, - }, + "i#iR2": {"unroll": None}, + "j#jR2": {"vectorize": None}, }, } strategy = Strategy(graph, spec, initialize=False) print(strategy._constraints) -# CHECK: ['iL2 || {21}', 'iR1 || {3, iL2, 21}', 'iR2 || {iL2, 21}', 'iR3 || {3, iL2, 21}', 'iS + 2 == 3', 'i_1_ + 6 == iL2', 'i_1_ <= iL2', 'jDDR || {32}', 'jR1 || {jDDR, 32}', 'jR2 || {jDDR, 32}', 'jR3 || {jDDR, 32}'] +# CHECK: ['iL2 || {21}', 'iR1 || {21, iL2, 3}', 'iR2 || {21, iL2}', 'iR3 || {21, iL2, 3}', 'iS + 2 == 3', 'i_1_ + 6 == iL2', 'i_1_ <= iL2', 'jDDR || {32}', 'jR1 || {32, jDDR}', 'jR2 || {32, jDDR}', 'jR3 || {32, jDDR}'] diff --git a/tests/filecheck/search/test_matmul_descript_yaml_goto.py b/tests/filecheck/search/test_matmul_descript_yaml_goto.py index ba6070d76..ebe57ee10 100644 --- a/tests/filecheck/search/test_matmul_descript_yaml_goto.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_goto.py @@ -20,19 +20,14 @@ backend = utils.get_backend(graph, "tvm") spec = """ -Memory: j: k: -L3: B: bufferize i: -L2: A: bufferize j#nc: i#mc: -L1: k#kc: unroll=kr -Register: i#mr: unroll full j#nr: vectorize full """ @@ -62,5 
+57,5 @@ print(strategy._constraints) print(len(list(strategy.sample(100)))) -# CHECK: ['1 + nvr + nvr * mr <= 32', 'kc * mc <= 262144', 'kc * nc <= 9437184', 'kc * nr <= 8192', 'kc <= 1024', 'kr <= kc', 'mc <= 1024', 'mr || {mc, 1024}', 'nc <= 1024', 'nr == 16 * nvr', 'nr || {nc, 1024}', 'nvr * mr * kr <= 256', 'nvr * mr >= 8'] +# CHECK: ['1 + nvr + nvr * mr <= 32', 'kc * mc <= 262144', 'kc * nc <= 9437184', 'kc * nr <= 8192', 'kc <= 1024', 'kr <= kc', 'mc <= 1024', 'mr || {1024, mc}', 'nc <= 1024', 'nr == 16 * nvr', 'nr || {1024, nc}', 'nvr * mr * kr <= 256', 'nvr * mr >= 8'] #CHECK-NEXT: 100 diff --git a/tests/filecheck/search/test_matmul_descript_yaml_simple.py b/tests/filecheck/search/test_matmul_descript_yaml_simple.py index ee6ef4713..e2fa5e172 100644 --- a/tests/filecheck/search/test_matmul_descript_yaml_simple.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_simple.py @@ -9,14 +9,11 @@ graph = utils.get_graph_matmul() backend = utils.get_backend(graph) spec = """ -L3: k: i: j: -L2: i#i1: j#j1: -L1: j#j2: """ strategy = Strategy(graph, spec) @@ -24,5 +21,5 @@ print(strategy._constraints) print(len(list(strategy.sample(100)))) -# CHECK: ['i1 || {21}', 'j1 || {32}', 'j2 || {j1, 32}'] +# CHECK: ['i1 || {21}', 'j1 || {32}', 'j2 || {32, j1}'] # CHECK-NEXT: 84 diff --git a/tests/filecheck/search/test_matmul_descript_yaml_split.py b/tests/filecheck/search/test_matmul_descript_yaml_split.py index 82d441285..d07957c5d 100644 --- a/tests/filecheck/search/test_matmul_descript_yaml_split.py +++ b/tests/filecheck/search/test_matmul_descript_yaml_split.py @@ -9,24 +9,17 @@ graph = utils.get_graph_matmul() backend = utils.get_backend(graph) spec = """ -DDR: j: k: i: -L3: i#iL3: -L2: i#iL2: -L1: j#jDDR: i[:iS]: - R1: i#iR1: unroll j#jR1: vectorize - SR1: k#SR: i[iS:]: - R2: i#iR2: unroll j#jR2: unroll """ @@ -35,5 +28,5 @@ print(strategy._constraints) print(len(list(strategy.sample(100)))) -# CHECK: ['SR || {12}', 'iL2 || {iL3, 21}', 'iL3 || {21}', 'iR1 || {iL2, 
iL3, 21}', 'iR2 || {iL2, iL3, 21}', 'iS <= iL2', 'i_1_ + iS == iL2', 'i_1_ <= iL2', 'jDDR || {32}', 'jR1 || {jDDR, 32}', 'jR2 || {jDDR, 32}'] +# CHECK: ['SR || {12}', 'iL2 || {21, iL3}', 'iL3 || {21}', 'iR1 || {21, iL3, iL2}', 'iR2 || {21, iL3, iL2}', 'iS <= iL2', 'i_1_ + iS == iL2', 'i_1_ <= iL2', 'jDDR || {32}', 'jR1 || {32, jDDR}', 'jR2 || {32, jDDR}'] # CHECK-NEXT: 100