diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..64637230 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "src/sampler"] + path = src/sampler + url = git@gitlab.inria.fr:CORSE/sampler.git diff --git a/requirements.txt b/requirements.txt index 38c09caf..e3fdb852 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,9 @@ ordered-set py-cpuinfo tqdm typing_extensions -xdsl~=0.50.0 pyyaml scikit-learn +networkx +sympy +strictyaml +types-PyYAML diff --git a/src/sampler b/src/sampler new file mode 160000 index 00000000..896a6108 --- /dev/null +++ b/src/sampler @@ -0,0 +1 @@ +Subproject commit 896a6108c6c62147c1c57ffa36a942cdd4e11ad0 diff --git a/src/xtc/__init__.py b/src/xtc/__init__.py index 0c0fdf80..118e14cc 100644 --- a/src/xtc/__init__.py +++ b/src/xtc/__init__.py @@ -3,5 +3,10 @@ # Copyright (c) 2024-2026 The XTC Project Authors # import importlib.metadata +import os +import sys __version__ = importlib.metadata.version("xtc") +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../sampler")) +) diff --git a/src/xtc/cli/mlir_loop.py b/src/xtc/cli/mlir_loop.py index f28b8243..0d4eb71f 100644 --- a/src/xtc/cli/mlir_loop.py +++ b/src/xtc/cli/mlir_loop.py @@ -13,6 +13,7 @@ from xtc.itf.schd.scheduler import Scheduler from xtc.schedules.descript import descript_scheduler +from xtc.schedules.descript_extend import descript_extend_scheduler from xtc.utils.xdsl_aux import parse_xdsl_module from xtc.backends.mlir.MlirNodeBackend import MlirNodeBackend from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend @@ -49,6 +50,7 @@ def main(): always_vectorize=args.always_vectorize, concluding_passes=args.concluding_passes, no_alias=not args.alias, + extend=args.extend, ) schedulers.append(sched) @@ -116,6 +118,7 @@ def build_node_scheduler( always_vectorize: bool, concluding_passes: list[str], no_alias: bool, + extend: bool, ) -> Scheduler: backend = build_mlir_node_backend( op=op, @@ -131,18 
+134,73 @@ def build_node_scheduler( if "loop.schedule" in op.attributes: schedule_attribute = op.attributes.get("loop.schedule") assert isinstance(schedule_attribute, builtin.DictionaryAttr) - normal_schedule = normalize_schedule(schedule_attribute) - descript_scheduler( - scheduler=scheduler, - node_name=node_name, - abstract_axis=scheduler.backend.dims, - spec=normal_schedule, - ) + if extend: + normal_schedule = normalize_extend_schedule(schedule_attribute) + descript_extend_scheduler( + scheduler=scheduler, + node_name=node_name, + abstract_axis=scheduler.backend.dims, + abstract_axis_sizes={a: 1 for a in scheduler.backend.dims}, + spec=normal_schedule, + ) + else: + normal_schedule = normalize_schedule(schedule_attribute) + descript_scheduler( + scheduler=scheduler, + node_name=node_name, + abstract_axis=scheduler.backend.dims, + spec=normal_schedule, + ) op.attributes.pop("loop.schedule", None) return scheduler +def normalize_extend_schedule( + raw_schedule: builtin.DictionaryAttr, +) -> dict[str, dict]: + schedule: dict[str, Any] = {} + for declaration_, val_ in raw_schedule.data.items(): + assert isinstance(declaration_, str) + assert isinstance(val_, builtin.DictionaryAttr) + sub_schedule: dict[str, Any] = {} + for declaration, val in val_.data.items(): + assert isinstance(declaration, str) + if ":" in declaration: + if not isinstance(val, builtin.DictionaryAttr): + raise Exception( + f"The schedule within a split should be a dictionnary or void but got {declaration}" + ) + + assert isinstance(val, builtin.DictionaryAttr) + inner_schedule = normalize_extend_schedule(val) + sub_schedule[str(declaration)] = inner_schedule + else: + annotations: dict[str, int | None] = {} + if isinstance(val, builtin.DictionaryAttr): + for instr, param in val.data.items(): + assert isinstance(instr, str) + if isinstance(param, builtin.UnitAttr): + annotations[instr] = None + elif isinstance(param, builtin.IntegerAttr) or isinstance( + param, builtin.StringAttr + ): + 
annotations[instr] = param.value.data + else: + raise Exception( + "Annotation parameter should be void, int, or str." + ) + + elif not isinstance(val, builtin.UnitAttr): + raise Exception( + f"Annotation parameter should be a dict or void but got {type(val)}" + ) + + sub_schedule[declaration] = annotations + schedule[declaration_] = sub_schedule + return schedule + + def normalize_schedule( raw_schedule: builtin.DictionaryAttr, ) -> dict[str, dict]: @@ -328,6 +386,12 @@ def parse_args() -> argparse.Namespace: default=False, help="Print debug messages.", ) + parser.add_argument( + "--extend", + action="store_true", + default=False, + help="Use descript_extend instead of default", + ) args = parser.parse_args() diff --git a/src/xtc/schedules/descript.py b/src/xtc/schedules/descript.py index 9433b98a..62297f72 100644 --- a/src/xtc/schedules/descript.py +++ b/src/xtc/schedules/descript.py @@ -41,10 +41,12 @@ class Annotations: parallelize: True if parallelization was requested. """ - unroll_factor: int | None = None + unroll_factor: int | str | None = None unroll_specified: bool = False - vectorize: bool = False - parallelize: bool = False + vectorize: bool | str = False + parallelize: bool | str = False + partial: bool = False + full: bool = False @dataclass(frozen=True) @@ -52,12 +54,15 @@ class SplitDecl: """AST Type: a split declaration like 'axis[start:end]'.""" axis: str - start: int | None - end: int | None + start: int | str | None + end: int | str | None body: ScheduleSpec + size: int | str | None = None @override def __str__(self) -> str: + if self.size is not None: + return f"{self.axis}[:{self.size}:]" start_str = "" if self.start is None else str(self.start) end_str = "" if self.end is None else str(self.end) decl = f"{self.axis}[{start_str}:{end_str}]" @@ -69,7 +74,7 @@ class TileDecl: """AST Type: a tile declaration like 'axis#size'.""" axis: str - size: int + size: int | str annotations: Annotations @override @@ -85,7 +90,36 @@ class AxisDecl: 
annotations: Annotations -ScheduleItem = SplitDecl | TileDecl | AxisDecl +@dataclass(frozen=True) +class FusionDecl: + """AST Type: a fusion declaration""" + + +@dataclass(frozen=True) +class PackDecl: + """AST Type: a packing declaration""" + + param: str | bool + input: str + pad: str | bool + + +@dataclass(frozen=True) +class BufferDecl: + """AST Type: a bufferisation declaration""" + + param: str | bool + pad: str + + +@dataclass(frozen=True) +class ExploreDecl: + level: str + + +ScheduleItem = ( + SplitDecl | TileDecl | AxisDecl | FusionDecl | PackDecl | BufferDecl | ExploreDecl +) @dataclass(frozen=True) @@ -144,10 +178,12 @@ def _parse_tile(self, declaration: str, value: dict) -> TileDecl: axis_name, size_str = parts - try: - size = int(size_str) - except ValueError: - raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.") + size = int(size_str) if size_str.isnumeric() else size_str + + # try: + # size = int(size_str) + # except ValueError: + # raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.") annotations = self._parse_annotations(value, declaration) return TileDecl(axis=axis_name, size=size, annotations=annotations) @@ -234,8 +270,8 @@ def _interpret_spec( slice = loop_nest.build_slice(root) # Track state during interpretation - sizes: dict[str, int] = {} - previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis} + sizes: dict[str, int | str] = {} + previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} interchange: list[str] = list(head) for item in spec.items: @@ -267,7 +303,7 @@ def _interpret_split( loop_nest: LoopNest, root: str, interchange: list[str], - previous_cut: dict[str, int | None], + previous_cut: dict[str, int | str | None], ) -> None: """Interpret a split declaration.""" axis_name = item.axis @@ -283,10 +319,8 @@ def _interpret_split( # it is the previous cut if x is None: x = cut - assert x is not None - self._check_splitting_intervals(item, cut, x) - + 
assert x is not None # Update the previous cut previous_cut[axis_name] = y @@ -308,12 +342,15 @@ def _interpret_tile( item: TileDecl, slice: LoopNestSlice, interchange: list[str], - sizes: dict[str, int], + sizes: dict[str, int | str], ) -> str: """Interpret a tile declaration. Returns the loop name.""" self._check_axis_existence(item.axis) tile_num = len(slice.tiles[item.axis]) loop_name = f"{item.axis}{tile_num}" + if not isinstance(item.size, int): + raise ScheduleInterpretError(f"`{item}`: {item.size} is not an integer.") + assert isinstance(item.size, int) if item.size <= 0: raise ScheduleInterpretError( f"`{item}`: tile sizes should be strictly positive." @@ -354,7 +391,7 @@ def _apply_annotations( self, annotations: Annotations, loop_name: str, - sizes: dict[str, int], + sizes: dict[str, int | str], slice: LoopNestSlice, ) -> None: """Apply annotations to a loop in the slice.""" @@ -367,7 +404,7 @@ def _apply_annotations( f"{loop_name}'s size being unknown, an unroll factor is needed." ) unroll_factor = sizes[loop_name] - elif unroll_factor <= 0: + elif isinstance(unroll_factor, int) and unroll_factor <= 0: raise ScheduleInterpretError( f'`{{"unroll" = {unroll_factor}}}`: unroll parameter should be strictly positive.' ) @@ -382,27 +419,46 @@ def _apply_annotations( def _check_splitting_intervals( self, item: SplitDecl, - cut: int | None, - x: int, - ) -> None: + cut: int | str | None, + x: int | str | None, + ) -> int | str | None: """Check that split intervals are valid and contiguous.""" - + y = item.end if cut is None: raise ScheduleInterpretError(f"{item}: {item.axis} already covered.") - if x > cut: - raise ScheduleInterpretError( - f"{item}: splitting doesn't fully cover {item.axis} (jumps from {cut} to {x})." - ) - elif x < cut: + if x is None: raise ScheduleInterpretError( - f"{item}: the segment begins at {x} but the previous one ends at {cut}." + f"x is None, but cut: {cut} is not, this should be unreachable." 
) + if isinstance(x, int) and isinstance(cut, int): + if x > cut: + raise ScheduleInterpretError( + f"{item}: splitting doesn't fully cover {item.axis} (jumps from {cut} to {x})." + ) + elif x < cut: + raise ScheduleInterpretError( + f"{item}: the segment begins at {x} but the previous one ends at {cut}." + ) + else: + if x != cut: + raise ScheduleInterpretError( + f"{item}: Splitting ends at {cut} and begins at {x}. These need to be the same." + ) + if y is None: + return None - if item.end is not None and x >= item.end: - raise ScheduleInterpretError( - f"{item}: the ending point should be greater than the starting point." - ) + if isinstance(x, int): + if isinstance(y, int): + if x >= y: + raise ScheduleInterpretError( + f"{item}: the ending point should be greater than the starting point." + ) + else: + return y - x + if x == 0: + return y + return None @dataclass @@ -477,12 +533,12 @@ class LoopNestSlice: """ root: str - tiles: dict[str, dict[str, int]] - splits: dict[str, dict[str, int]] = field(default_factory=dict) + tiles: dict[str, dict[str, int | str]] + splits: dict[str, dict[str, int | str]] = field(default_factory=dict) interchange: list[str] = field(default_factory=list) vectorize: list[str] = field(default_factory=list) parallelize: list[str] = field(default_factory=list) - unroll: dict[str, int] = field(default_factory=dict) + unroll: dict[str, int | str] = field(default_factory=dict) @property def splits_to_sizes(self) -> dict[str, int]: @@ -490,6 +546,7 @@ def splits_to_sizes(self) -> dict[str, int]: for axis in self.splits: last_start = None for loop_name, start in reversed(self.splits[axis].items()): + assert isinstance(start, int) if last_start is not None: size_of_split = last_start - start splits_to_sizes[loop_name] = size_of_split @@ -501,9 +558,38 @@ def tiles_to_sizes(self) -> dict[str, int]: tiles_to_sizes: dict[str, int] = {} for tiles in self.tiles.values(): for loop, size in tiles.items(): + assert isinstance(size, int) 
tiles_to_sizes[loop] = size return tiles_to_sizes + @property + def int_tiles(self) -> dict[str, dict[str, int]]: + return self._int_dict(self.tiles) + + @property + def int_splits(self) -> dict[str, dict[str, int]]: + return self._int_dict(self.splits) + + @property + def int_unroll(self) -> dict[str, int]: + out = {} + for x, v in self.unroll.items(): + if isinstance(v, str) and v.isnumeric(): + v = int(v) + assert isinstance(v, int) + out[x] = v + return out + + def _int_dict(self, input: dict[str, dict[str, Any]]) -> dict[str, dict[str, int]]: + out: dict[str, dict[str, int]] = {} + for x, v in input.items(): + v_dict: dict[str, int] = {} + for x_v, v_v in v.items(): + assert isinstance(v_v, int) + v_dict[x_v] = v_v + out[x] = v_dict + return out + @dataclass class LoopNest: @@ -568,7 +654,9 @@ def _check_tiling_consistency(self) -> None: `{axis}#{size}`: {axis} has not been materialized yet. """ ) - seen_axes[axis] = sched.tiles[axis][loop_name] + loop_size = sched.tiles[axis][loop_name] + assert isinstance(loop_size, int) + seen_axes[axis] = loop_size def _check_sizes(self): mapper = LoopsDimsMapper.build_from_slices(self.slices) @@ -589,6 +677,7 @@ def _check_sizes(self): current_size_of_split[loop_name] = None elif loop_name in mapper.tiles_to_axis: loop_size = sched.tiles[axis][loop_name] + assert isinstance(loop_size, int) LoopNest._must_be_smaller_routine( new_size=loop_size, current_sizes=current_sizes, @@ -611,6 +700,7 @@ def _check_sizes(self): if loop_name in sched.unroll: unroll_factor = sched.unroll[loop_name] + assert isinstance(unroll_factor, int) if loop_size and loop_size < unroll_factor: raise ScheduleValidationError( f'`{{"unroll" = {unroll_factor}}}`: unroll factor should be smaller than {loop_size}.' @@ -645,11 +735,11 @@ def descript_scheduler( abstract_axis: The list of abstract axis names (e.g., ["m", "n", "k"]). spec: The schedule specification as a nested dict. 
""" - descript = Descript(scheduler=scheduler, abstract_axis=abstract_axis) - descript.apply(node_name=node_name, spec=spec) + descript = Descript(abstract_axis=abstract_axis) + descript.apply(scheduler=scheduler, node_name=node_name, spec=spec) -@dataclass(frozen=True) +@dataclass(frozen=False) class Descript: """Applies a parsed and interpreted schedule to a Scheduler. @@ -661,10 +751,11 @@ class Descript: 4. Apply: LoopNest -> Scheduler """ - scheduler: Scheduler abstract_axis: list[str] - def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: + def apply( + self, node_name: str, spec: dict[str, dict[str, Any]], scheduler: Scheduler + ) -> None: """Parse, interpret, validate, and apply a schedule specification. Args: @@ -688,22 +779,22 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: loop_nest.check() # Apply the schedule to the scheduler - self._apply_loop_nest(loop_nest) + self._apply_loop_nest(loop_nest, scheduler) - def _apply_loop_nest(self, loop_nest: LoopNest) -> None: + def _apply_loop_nest(self, loop_nest: LoopNest, scheduler: Scheduler) -> None: """Apply a LoopNest to the scheduler.""" - self.scheduler.set_dims(self.abstract_axis) + scheduler.set_dims(self.abstract_axis) for slice in loop_nest.slices: root = slice.root - for d, s in slice.splits.items(): - self.scheduler.split(d, s, root=root) + for d, s in slice.int_splits.items(): + scheduler.split(d, s, root=root) - for d, s in slice.tiles.items(): - self.scheduler.tile(d, s, root=root) + for d, s in slice.int_tiles.items(): + scheduler.tile(d, s, root=root) - self.scheduler.interchange(slice.interchange, root=root) - self.scheduler.vectorize(slice.vectorize, root=root) - self.scheduler.parallelize(slice.parallelize, root=root) - self.scheduler.unroll(slice.unroll, root=root) + scheduler.interchange(slice.interchange, root=root) + scheduler.vectorize(slice.vectorize, root=root) + scheduler.parallelize(slice.parallelize, root=root) + 
scheduler.unroll(slice.int_unroll, root=root) diff --git a/src/xtc/schedules/descript_extend.py b/src/xtc/schedules/descript_extend.py new file mode 100644 index 00000000..a443eba8 --- /dev/null +++ b/src/xtc/schedules/descript_extend.py @@ -0,0 +1,720 @@ +# +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024-2026 The XTC Project Authors +# +from typing import Any +from copy import deepcopy +from dataclasses import dataclass, field +import re + +import strictyaml +from typing_extensions import override + +from xtc.itf.schd.scheduler import Scheduler + +from xtc.schedules.descript import ( + Annotations, + AxisDecl, + BufferDecl, + Descript, + FusionDecl, + LoopNest, + LoopNestSlice, + PackDecl, + ScheduleInterpretError, + ScheduleInterpreter, + ScheduleItem, + ScheduleParseError, + ScheduleParser, + ScheduleSpec, + SplitDecl, + TileDecl, +) + + +@dataclass +class LoopNestSliceExtend(LoopNestSlice): + axis_orders: list[str] = field(default_factory=list) + axes: dict[str, dict] = field(default_factory=dict) + packs: dict[str, list] = field(default_factory=dict) + buffers: dict[str, list] = field(default_factory=dict) + fusions: dict[str, list] = field(default_factory=dict) + variables: set[str] = field(default_factory=set) + constraints: set[str] = field(default_factory=set) + vectorize_bool: set[tuple[str, str]] = field(default_factory=set) + parallelize_bool: set[tuple[str, str]] = field(default_factory=set) + + +@dataclass +class LoopNestExtend(LoopNest): + @override + def build_slice(self, root: str) -> LoopNestSliceExtend: + slice = LoopNestSliceExtend( + root=root, tiles={a: {} for a in self.abstract_dims} + ) + self.slices = [slice] + self.slices + return slice + + def apply_sample(self, sample: dict[str, Any]): + for schedule in self.slices: + for dim, axes in schedule.splits.items(): + for level, size in axes.items(): + if isinstance(size, str): + schedule.splits[dim][level] = sample[size] + for dim, axes in schedule.tiles.items(): + for level, 
size in axes.items(): + if isinstance(size, str): + schedule.tiles[dim][level] = sample[size] + for axis, size in schedule.unroll.items(): + if isinstance(size, str): + val = sample[size] + if val is None: + for s__ in schedule.tiles.values(): + for level, size in s__.items(): + if axis == level: + val = size + break + if val is not None: + break + schedule.unroll[axis] = val + if isinstance(schedule, LoopNestSliceExtend): + for axis, loop in schedule.vectorize_bool: + axis = sample.get(axis, False) + if axis is None or axis: + schedule.vectorize.append(loop) + for axis, loop in schedule.parallelize_bool: + axis = sample.get(axis, False) + if axis is None or axis: + schedule.parallelize.append(loop) + for dim, packs in schedule.packs.items(): + for i, (flag, input, pad) in enumerate(packs): + sample_flag = False + if isinstance(flag, str): + flag = sample.get(flag, False) + sample_flag = True + if not flag: + schedule.packs[dim].pop(i) + continue + if isinstance(input, str): + input = sample.get(input, input) + sample_flag = True + if sample_flag: + schedule.packs[dim][i] = (flag, input, pad) + for dim, buffs in schedule.buffers.items(): + for i, (flag, pad) in enumerate(buffs): + sample_flag = False + if isinstance(flag, str): + flag = sample.get(flag, False) + sample_flag = True + if not flag: + schedule.buffers[dim].pop(i) + continue + if sample_flag: + schedule.buffers[dim][i] = (flag, pad) + for dim, axes in schedule.axes.items(): + d_holder = f"order_{dim}" + s = sample.get(d_holder, None) + if s: + sch = {} + for a in s: + sch[a] = axes[a] + schedule.axes[dim] = sch + + +def descript_extend_scheduler( + scheduler: Scheduler, + node_name: str, + abstract_axis: list[str], + abstract_axis_sizes: dict[str, int], + spec: dict[str, dict], + abstract_matrix: list[str] = [], + sample: dict[str, Any] = {}, + partial_tiles: bool = False, + partial_unrolls: bool = False, +): + descript = DescriptExtend( + abstract_axis=abstract_axis, + 
abstract_axis_sizes=abstract_axis_sizes, + abstract_matrix=abstract_matrix, + partial_tiles=partial_tiles, + partial_unrolls=partial_unrolls, + ) + descript.apply(node_name=node_name, spec=spec, sample=sample, scheduler=scheduler) + + +class ScheduleParserExtend(ScheduleParser): + _SPLIT_PATTERN = re.compile(r"^(.*)\[(-\w+|\w*)?:(-\w+|\w*)?\]$") + _SPLIT_MIDDLE_PATTERN = re.compile(r"^(.*)\[:(\w*):\]$") + + @override + def __init__(self, abstract_axis: list[str], abstract_matrix: list[str]): + self.abstract_matrix = abstract_matrix + super().__init__(abstract_axis) + + @override + def _parse_declaration(self, declaration: str, value: Any) -> ScheduleItem: + if "fusion" == declaration: + return self._parse_fusion() + if "pack" == declaration: + return self._parse_pack(value) + if "buffer" == declaration: + return self._parse_buffer(value) + if declaration in self.abstract_matrix: + return self._parse_matrix(declaration, value) + + return super()._parse_declaration(declaration, value) + + @override + def _parse_split(self, declaration: str, value: dict) -> SplitDecl: + axis_name, start, end, size = self._parse_split_syntax_extend(declaration) + + body = self.parse(value) + return SplitDecl(axis=axis_name, start=start, end=end, body=body, size=size) + + @override + def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations: + """Parse annotation dict into Annotations object.""" + + unroll_factor: int | str | None = None + unroll_specified = False + vectorize = False + parallelize = False + partial = False + full = False + + for key, param in value.items(): + if key == "unroll": + if param is True or param is None: + unroll_factor = None + unroll_specified = True + elif param is False: + pass + elif isinstance(param, int) or isinstance(param, str): + unroll_factor = param + unroll_specified = True + else: + raise ScheduleParseError( + f'`{{"unroll" = {param}}}`: unroll parameter should be True, False, None, or an integer.' 
+ ) + elif key == "vectorize": + if param is True or param is None: + vectorize = True + elif param is False: + pass + elif isinstance(param, str): + vectorize = param + else: + raise ScheduleParseError( + f'`{{"vectorize" = {param}}}`: parameterized vectorization not implemented.' + ) + elif key == "parallelize": + if isinstance(param, str): + parallelize = param + elif param is not None: + raise ScheduleParseError( + f'`{{"parallelize" = {param}}}`: parameterized parallelization not implemented.' + ) + else: + parallelize = True + elif key == "partial": + if full: + raise ScheduleParseError("Tile cannot be full and partial.") + partial = True + elif key == "full": + if partial: + raise ScheduleParseError("Tile cannot be partial and full.") + full = True + else: + raise ScheduleParseError(f"Unknown annotation on {context}: {key}") + + return Annotations( + unroll_factor=unroll_factor, + unroll_specified=unroll_specified, + vectorize=vectorize, + parallelize=parallelize, + partial=partial, + full=full, + ) + + def _parse_split_syntax_extend( + self, declaration: str + ) -> tuple[str, int | str | None, int | str | None, int | str | None]: + """Parse the syntax of a split declaration.""" + match = self._SPLIT_PATTERN.match(declaration) + if not match: + match = self._SPLIT_MIDDLE_PATTERN.match(declaration) + if not match: + raise ScheduleParseError(f"Wrong format {declaration}") + prefix, z = match.groups() + z = int(z) if z.isnumeric() else z + return prefix, None, None, z + + prefix, x_str, y_str = match.groups() + x = int(x_str) if x_str.isnumeric() else x_str + y = int(y_str) if y_str.isnumeric() else y_str + x = x if x else None + y = y if y else None + return prefix, x, y, None + + def _parse_fusion(self) -> FusionDecl: + return FusionDecl() + + def _parse_pack(self, value: Any) -> PackDecl: + assert len(value) == 3 + param, input, pad = value + return PackDecl(param, input, pad) + + def _parse_buffer(self, value: Any) -> BufferDecl: + assert len(value == 2) + 
param, pad = value + return BufferDecl(param, pad) + + def _parse_matrix(self, declaration: str, value: Any) -> PackDecl | BufferDecl: + param = value.get("bufferize", False) + if not (param is None or param): + raise ScheduleParseError( + f"Declared matrix {declaration} without bufferization." + ) + pad = value.get("pad", False) + if declaration == self.abstract_matrix[-1]: + return BufferDecl(param, pad) + return PackDecl(param, declaration, pad) + + +class ScheduleInterpreterExtend(ScheduleInterpreter): + @override + def __init__( + self, + abstract_axis: list[str], + abstract_axis_sizes: dict[str, int], + abstract_matrix: list[str], + partial_tiles: bool = False, + partial_unrolls: bool = False, + ): + self.abstract_matrix = abstract_matrix + self.abstract_axis_sizes = abstract_axis_sizes + self.partial_tiles = partial_tiles + self.partial_unrolls = partial_unrolls + super().__init__(abstract_axis) + + @override + def interpret(self, spec: ScheduleSpec, root: str) -> LoopNestExtend: + return self._interpret_spec(spec, root, head=[]) + + @override + def _interpret_spec( + self, + spec: ScheduleSpec, + root: str, + head: list[str], + tile_sizes: dict[str, list[int | str]] | None = None, + ) -> LoopNestExtend: + """Interpret a schedule spec recursively.""" + loop_nest = LoopNestExtend(abstract_dims=self.abstract_axis) + slice = loop_nest.build_slice(root) + + # Track state during interpretation + last_split: list[tuple[int | str, int | str]] = [] + previous_cut: dict[str, int | str | None] = {a: 0 for a in self.abstract_axis} + interchange: list[str] = list(head) + axes_sizes: dict[str, list[int | str]] = {} + sizes: dict[str, int | str] = {} + if tile_sizes: + axes_sizes = tile_sizes + else: + axes_sizes = {a: [v] for a, v in self.abstract_axis_sizes.items()} + + for item in spec.items: + if isinstance(item, SplitDecl): + self._interpret_split( + item=item, + slice=slice, + loop_nest=loop_nest, + root=root, + interchange=interchange, + previous_cut=previous_cut, 
+ axes_sizes=axes_sizes, + last_split=last_split, + ) + elif isinstance(item, TileDecl): + loop_name = self._interpret_tile( + item=item, + slice=slice, + interchange=interchange, + axes_sizes=axes_sizes, + sizes=sizes, + ) + self._apply_annotations(item.annotations, loop_name, sizes, slice) + elif isinstance(item, AxisDecl): + loop_name = self._interpret_axis(item, interchange) + self._apply_annotations(item.annotations, loop_name, sizes, slice) + elif isinstance(item, FusionDecl): + ... + elif isinstance(item, PackDecl): + self._interpret_pack(slice, interchange[-1], item) + elif isinstance(item, BufferDecl): + self._interpret_buffer(slice, interchange[-1], item) + + if len(last_split) > 0: + a, b = last_split[0] + if isinstance(a, int) and not isinstance(b, int): + a, b = b, a + a, b = str(a), str(b) + for c in slice.constraints: + slice.constraints.remove(c) + slice.constraints.add(c.replace(a, b)) + + # Check that all splits are complete + for axis, cut in previous_cut.items(): + if ( + cut is not None + and isinstance(cut, int) + and cut not in [0, axes_sizes[axis][-1]] + ): + raise ScheduleInterpretError( + f"Splitting of {axis} unachieved (stops at {cut})." + ) + + slice.interchange = interchange + return loop_nest + + @override + def _interpret_split( + self, + item: SplitDecl, + slice: LoopNestSlice, + loop_nest: LoopNest, + root: str, + interchange: list[str], + previous_cut: dict[str, int | str | None], + axes_sizes: dict[str, list[int | str]] = {}, + last_split: list[tuple[int | str, int | str]] = [], + ): + """Interpret a split declaration.""" + if not isinstance(slice, LoopNestSliceExtend): + return super()._interpret_split( + item, slice, loop_nest, root, interchange, previous_cut + ) + axis_name = item.axis + self._check_axis_existence(axis_name) + x = item.start + y = item.end + z = item.size + + # The only declaration where y (the cut) is None is the + # last one, so it cannot be the previous one. 
+ cut = previous_cut[axis_name] + current_size = axes_sizes[axis_name][-1] + + # When x (the starting point of the slice) is not specified, + # it is the previous cut + if x is None: + x = cut + inner_size = self._check_splitting_intervals(item, cut, x) + assert x is not None + + # Save the cutting points of the new dimensions + if axis_name not in slice.splits: + slice.splits[axis_name] = {} + new_dim_index = len(slice.splits[axis_name]) + new_dim_name = f"{axis_name}[{new_dim_index}]" + new_root_name = f"{root}/{new_dim_name}" + interchange.append(new_dim_name) + + if z is None: + # Update the previous cut + previous_cut[axis_name] = y + slice.splits[axis_name][new_dim_name] = x + inner_size = None + if y is None: + y = current_size + if isinstance(x, int): + if x == 0: + inner_size = y + elif isinstance(y, int): + inner_size = y - x + if inner_size is None: + inner_size = root[1:] + new_dim_name + inner_size = ( + inner_size.replace("/", "").replace("[", "_").replace("]", "_") + ) + slice.constraints.add(f"{inner_size} <= {y}") + if isinstance(x, str): + slice.constraints.add(f"{x} <= {y}") + slice.constraints.add(f"{inner_size} + {x} == {y}") + else: + inner_size = z + x = cut + y = current_size + assert x is not None + slice.splits[axis_name][new_dim_name] = x + if isinstance(z, int) and isinstance(x, int): + previous_cut[axis_name] = x + z + if not isinstance(y, int): + slice.constraints.add(f"{z + x} <= {y}") + elif isinstance(x, int) and x == 0: + previous_cut[axis_name] = z + if not isinstance(y, int): + slice.constraints.add(f"{z} <= {y}") + else: + new_cut = root[1:] + new_dim_name + new_cut = new_cut.replace("/", "").replace("[", "_").replace("]", "_") + previous_cut[axis_name] = new_cut + if len(last_split) > 0: + a, b = last_split[0] + slice.constraints.add(f"{a} <= {b}") + last_split.append((new_cut, y)) + slice.constraints.add(f"{z} + {x} == {new_cut}") + + # Recursively interpret the nested schedule + inner_nest = self._interpret_spec( + 
spec=item.body, + root=new_root_name, + head=[axis_name], + tile_sizes=deepcopy(axes_sizes), + ) + loop_nest.slices += inner_nest.slices + + @override + def _interpret_tile( + self, + item: TileDecl, + slice: LoopNestSlice, + interchange: list[str], + sizes: dict[str, int | str], + axes_sizes: dict[str, list[int | str]] = {}, + ) -> str: + """Interpret a tile declaration. Returns the loop name.""" + self._check_axis_existence(item.axis) + tile_num = len(slice.tiles[item.axis]) + loop_name = f"{item.axis}{tile_num}" + if isinstance(item.size, int) and item.size <= 0: + raise ScheduleInterpretError( + f"`{item}`: tile sizes should be strictly positive." + ) + slice.tiles[item.axis][loop_name] = item.size + size_list = axes_sizes[item.axis] + sizes[loop_name] = item.size + assert isinstance(size_list, list) + old_size = size_list[-1] + interchange.append(loop_name) + if isinstance(item.size, str): + assert isinstance(slice, LoopNestSliceExtend) + slice.variables.add(item.size) + partial = item.annotations.partial + full = item.annotations.full + if partial or (not full and self.partial_tiles): + slice.constraints.add(f"{item.size} <= {old_size}") + else: + s = ( + ", ".join(map(str, size_list)) + if len(size_list) > 1 + else str(size_list[0]) + ) + s = f"{item.size} || {{{s}}}" + slice.constraints.add(s) + size_list.append(item.size) + return loop_name + + @override + def _apply_annotations( + self, + annotations: Annotations, + loop_name: str, + sizes: dict[str, int | str], + slice: LoopNestSlice, + ) -> None: + """Apply annotations to a loop in the slice.""" + assert isinstance(slice, LoopNestSliceExtend) + if annotations.unroll_specified: + unroll_factor = annotations.unroll_factor + if unroll_factor is None: + # None means "unroll fully" - use the loop size + if loop_name not in sizes: + raise ScheduleInterpretError( + f"{loop_name}'s size being unknown, an unroll factor is needed." 
+ ) + unroll_factor = sizes[loop_name] + elif isinstance(unroll_factor, str): + slice.variables.add(unroll_factor) + if self.partial_unrolls: + slice.constraints.add(f"{unroll_factor} <= {sizes[loop_name]}") + else: + slice.constraints.add(f"{unroll_factor} || {sizes[loop_name]}") + elif unroll_factor <= 0: + raise ScheduleInterpretError( + f'`{{"unroll" = {unroll_factor}}}`: unroll parameter should be strictly positive.' + ) + slice.unroll[loop_name] = unroll_factor + + vectorize = annotations.vectorize + if vectorize: + slice.vectorize.append(loop_name) + if isinstance(vectorize, str): + slice.variables.add(vectorize) + slice.constraints.add(f"{vectorize} in {{0, 1}}") + + parallelize = annotations.parallelize + if parallelize: + slice.parallelize.append(loop_name) + if isinstance(parallelize, str): + slice.variables.add(parallelize) + slice.constraints.add(f"{parallelize} in {{0, 1}}") + + def _interpret_pack( + self, slice: LoopNestSliceExtend, loop_name: str, item: PackDecl + ): + param, input, pad = item.param, item.input, item.pad + if isinstance(param, str): + slice.variables.add(param) + slice.constraints.add(f"{param} in {{0, 1}}") + if isinstance(pad, str): + slice.variables.add(pad) + slice.constraints.add(f"{pad} in {{0, 1}}") + if loop_name in slice.packs: + slice.packs[loop_name].append((param, input, pad)) + else: + slice.packs[loop_name] = [(param, input, pad)] + + def _interpret_buffer( + self, slice: LoopNestSliceExtend, loop_name: str, item: BufferDecl + ): + param, pad = item.param, item.pad + if isinstance(param, str): + slice.variables.add(param) + slice.constraints.add(f"{param} in {{0, 1}}") + if isinstance(pad, str): + slice.variables.add(pad) + slice.constraints.add(f"{pad} in {{0, 1}}") + slice.buffers[loop_name].append((param, pad)) + + +@dataclass(frozen=False) +class DescriptExtend(Descript): + abstract_axis_sizes: dict[str, int] + abstract_matrix: list[str] + partial_tiles: bool = False + partial_unrolls: bool = False + _loop_nest: 
None | LoopNestExtend = None + + @override + def apply( + self, + node_name: str, + spec: str | dict[str, dict[str, Any]], + scheduler: Scheduler, + sample: dict[str, Any] = {}, + ) -> None: + """Parse, interpret, validate, and apply a schedule specification. + + Args: + node_name: The name of the root node to schedule. + spec: The schedule specification as a nested dict. + Raises: + ScheduleParseError: If the spec cannot be parsed. + ScheduleInterpretError: If the spec cannot be interpreted. + ScheduleValidationError: If the resulting schedule is invalid. + """ + + if isinstance(spec, str): + spec = self.parse_yaml(spec) + + # Parse the specification into an AST + parser = ScheduleParserExtend(self.abstract_axis, self.abstract_matrix) + ast = parser.parse(spec) + + # Interpret the AST into a LoopNest + interpreter = ScheduleInterpreterExtend( + self.abstract_axis, self.abstract_axis_sizes, self.abstract_matrix + ) + loop_nest = interpreter.interpret(ast, root=node_name) + + if sample != {}: + loop_nest.apply_sample(sample) + + # Validate the loop nest + loop_nest.check() + for slice in loop_nest.slices: + assert isinstance(slice, LoopNestSliceExtend) + + # Apply the schedule to the scheduler + self._apply_loop_nest(loop_nest, scheduler) + + def loop_nest( + self, node_name: str, spec: str | dict[str, dict[str, Any]] + ) -> LoopNestExtend: + if self._loop_nest: + return self._loop_nest + + if isinstance(spec, str): + spec = self.parse_yaml(spec) + + parser = ScheduleParserExtend(self.abstract_axis, self.abstract_matrix) + ast = parser.parse(spec) + + # Interpret the AST into a LoopNest + interpreter = ScheduleInterpreterExtend( + abstract_axis=self.abstract_axis, + abstract_axis_sizes=self.abstract_axis_sizes, + abstract_matrix=self.abstract_matrix, + partial_tiles=self.partial_tiles, + partial_unrolls=self.partial_unrolls, + ) + self._loop_nest = interpreter.interpret(ast, root=node_name) + return self._loop_nest + + def apply_sample( + self, loop_nest: 
LoopNestExtend, scheduler: Scheduler, sample: dict[str, Any] + ): + loop_nest = deepcopy(loop_nest) + if sample != {}: + loop_nest.apply_sample(sample) + + # Validate the loop nest + loop_nest.check() + + # Apply the schedule to the scheduler + self._apply_loop_nest(loop_nest, scheduler) + + def parse_yaml(self, spec: str) -> dict[str, dict]: + dspec = strictyaml.load(spec).data + assert isinstance(dspec, dict) + return self._parse_yaml(dspec) + + def _parse_yaml(self, spec: dict[str, dict]) -> dict[str, dict]: + out_dict = {} + for a, v in spec.items(): + if a in self.abstract_matrix: + assert isinstance(v, str) + out_dict[a] = self._split_yaml(v) + else: + if isinstance(v, str): + d = self._split_yaml(v) + else: + assert isinstance(v, dict) + d = v + size = d.get("size", None) + if size: + d.pop("size") + a = f"{a}#{size}" + if ":" in a: + out_dict[a] = self._parse_yaml(d) + continue + out_dict[a] = {} + for axis_arg, arg_val in d.items(): + out_dict[a][axis_arg] = arg_val + return out_dict + + def _split_yaml(self, s: str) -> dict[str, Any]: + d = {} + for s in s.split(): + if "=" not in s: + d[s] = None + else: + x, y = s.split("=") + try: + tmp = eval(y) + except (NameError, SyntaxError): + tmp = y + d[x] = tmp + return d diff --git a/src/xtc/search/strategies.py b/src/xtc/search/strategies.py index 72d10bf9..73b5cc63 100644 --- a/src/xtc/search/strategies.py +++ b/src/xtc/search/strategies.py @@ -9,9 +9,18 @@ import itertools import numpy as np +from xvs.properties import constraints_from_str, hypergraph, Context +from xvs.strategy import ( + execute_dynamic, + execute_static, + solve_with_z3, + pretty_print_methods, +) from xtc.itf.graph import Graph from xtc.itf.schd import Scheduler +from xtc.itf.schd.scheduler import DEFAULT_ROOT from xtc.itf.search import Sample, Strategy +from xtc.schedules.descript_extend import DescriptExtend, LoopNestSliceExtend from xtc.utils.math import ( factors_to_sizes, factors_enumeration, @@ -20,7 +29,6 @@ sample_uniques, ) - 
__all__ = [ "Strategies", ] @@ -941,6 +949,141 @@ def _filter(self, samples: Iterator[VecSample]) -> Iterator[VecSample]: yield x +class Strategy_Descript(Strategy): + def __init__( + self, + graph: Graph, + spec: dict[str, dict[str, Any]] | str, + constraints: list[str] = [], + partial_tiles: bool = False, + partial_unrolls: bool = False, + initialize: bool = True, + ) -> None: + self._graph = graph + self._op = graph.outputs_nodes[0].operation + self._stats: dict[str, int] = {} + self._axes = list(self._op.dims) + self._sizes = self._constant_sizes() + self._sample_names: list[str] = [] + descript = DescriptExtend( + abstract_axis=self._axes, + abstract_axis_sizes=dict(self._sizes), + abstract_matrix=["A", "B", "C"], + partial_tiles=partial_tiles, + partial_unrolls=partial_unrolls, + ) + self._descript = descript + self._initialized = False + loop_nest = descript.loop_nest(node_name=DEFAULT_ROOT, spec=spec) + self._loop_nest = loop_nest + input_constraints: list[str] = [] + for slice in loop_nest.slices: + assert isinstance(slice, LoopNestSliceExtend) + input_constraints += slice.constraints + self._sample_names += slice.variables + for a, v in self._sizes.items(): + for i, s in enumerate(constraints): + assert isinstance(s, str) + constraints[i] = s.replace(f"[{a}]", str(v)) + self._orders: dict[str, list] = {} + self._constraints = constraints + input_constraints + self._constraints.sort() + if initialize: + self._initialize() + + def _initialize(self): + if self._initialized: + return + max_enum = int(1 + np.log2(max(self._sizes.values()))) + context = Context() + constraints, self.constrants = constraints_from_str( + self._constraints, context=context + ) + properties, constraints = hypergraph( + constraints, max_enum=max_enum, context=context + ) + methods = solve_with_z3(list(context.variables.keys()), properties, constraints) + enumerations = execute_static(methods, properties, constraints) + self._context = context + self._properties = properties + 
self._z3_constraints = constraints + self._methods = methods + self._enumerations = enumerations + self._initialized = True + + @property + @override + def graph(self) -> Graph: + return self._graph + + @override + def generate(self, scheduler: Scheduler, sample: Sample) -> None: + descript = self._descript + # for a, p in self._orders.items(): + # if a in sample: + # if isinstance(sample[a], int): + # sample[a] = p[sample[a]] + descript.apply_sample( + loop_nest=self._loop_nest, scheduler=scheduler, sample=sample + ) + + @override + def sample(self, num: int, seed: int | None = 0) -> Iterator[Sample]: + samples = sample_uniques(self._sample_once_tuple, num) + for x in samples: + yield dict(zip(self.sample_names, x)) + + def sample_once(self, num: int) -> Iterator[Sample]: + self._initialize() + draw = execute_dynamic( + self._methods, + self._properties, + self._z3_constraints, + self._enumerations, + k=num, + ) + return draw + + def pretty_print_methods(self, tab: str = "\t"): + self._initialize() + pretty_print_methods( + self._methods, self._properties, self._z3_constraints, tab=tab + ) + + def _sample_once_tuple(self, num: int) -> Iterator[tuple]: + draw = self.sample_once(num) + for d in draw: + yield tuple([d[x] for x in self.sample_names]) + + @override + def exhaustive(self) -> Iterator[Sample]: + return self.sample(1000) + + @override + def default_schedule(self, opt_level: int = 2) -> Sample: + while True: + for x in self.sample(1): + if x: + return x + + def _constant_sizes(self) -> Mapping[str, int]: + sizes = {a: v for a, v in self._op.dims.items() if isinstance(v, int)} + return sizes + + @override + def dict_to_sample(self, sample: dict[str, Any]) -> Sample: + return sample + + @override + def sample_to_dict(self, sample: Sample) -> dict[str, int]: + return sample + + @property + @override + def sample_names(self) -> list[str]: + return self._sample_names + + class Strategies: @classmethod def names(cls) -> Sequence[str]: diff --git 
a/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir b/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir new file mode 100644 index 00000000..67bc2f16 --- /dev/null +++ b/tests/filecheck/mlir_loop/descript_syntax/splitting/v_splitting_extend.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-loop --extend --no-alias --print-source-ir %s 2>&1 | filecheck %s + +func.func @matmul(%A: memref<256x512xf64>, %B: memref<512x256xf64>, %C: memref<256x256xf64>){ + linalg.matmul { + loop.dims = ["i", "j"], + loop.schedule = { + "i[:5]" = { "j" }, + "i[5:]" = { "j" } + } + } + ins(%A, %B : memref<256x512xf64>, memref<512x256xf64>) + outs(%C: memref<256x256xf64>) + return +} +// CHECK: module attributes {transform.with_named_sequence} { +// CHECK-NEXT: func.func @matmul(%arg0: memref<256x512xf64> {llvm.noalias}, %arg1: memref<512x256xf64> {llvm.noalias}, %arg2: memref<256x256xf64> {llvm.noalias}) { +// CHECK-NEXT: linalg.matmul {__node0__} ins(%arg0, %arg1 : memref<256x512xf64>, memref<512x256xf64>) outs(%arg2 : memref<256x256xf64>) +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +// CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +// CHECK-NEXT: transform.yield +// CHECK-NEXT: } +// CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +// CHECK-NEXT: %0 = transform.structured.match attributes {__node0__} in %arg0 : (!transform.any_op) -> !transform.any_op +// CHECK-NEXT: %1 = transform.structured.split %0 after 5 {dimension = 0 : i64} : !transform.any_op +// CHECK-NEXT: %2:2 = transform.split_handle %1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %2#0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: transform.annotate %loops 
"__node0__/i[0]/i" : !transform.any_op +// CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: transform.annotate %loops_1 "__node0__/i[0]/j" : !transform.any_op +// CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %2#1 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: transform.annotate %loops_3 "__node0__/i[1]/i" : !transform.any_op +// CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +// CHECK-NEXT: transform.annotate %loops_5 "__node0__/i[1]/j" : !transform.any_op +// CHECK-NEXT: transform.yield +// CHECK-NEXT: } +// CHECK-NEXT:} diff --git a/tests/filecheck/mlir_loop/descript_syntax/tiling/i_invalide_argument.mlir b/tests/filecheck/mlir_loop/descript_syntax/tiling/i_invalide_argument.mlir index dff01ca7..b2171be0 100644 --- a/tests/filecheck/mlir_loop/descript_syntax/tiling/i_invalide_argument.mlir +++ b/tests/filecheck/mlir_loop/descript_syntax/tiling/i_invalide_argument.mlir @@ -15,4 +15,5 @@ func.func @matmul(%A: memref<256x512xf64>, %B: memref<512x256xf64>, %C: memref<2 outs(%C: memref<256x256xf64>) return } -// CHECK: `j#a`: a is not an integer. +// CHECK: xtc.schedules.descript.ScheduleInterpretError: `j#a`: a is not an integer. 
+ diff --git a/tests/filecheck/mlir_loop/descript_syntax/tiling/i_one_axis_positive_negative_tiling.mlir b/tests/filecheck/mlir_loop/descript_syntax/tiling/i_one_axis_positive_negative_tiling.mlir index 4c6eae3b..ab21ab0e 100644 --- a/tests/filecheck/mlir_loop/descript_syntax/tiling/i_one_axis_positive_negative_tiling.mlir +++ b/tests/filecheck/mlir_loop/descript_syntax/tiling/i_one_axis_positive_negative_tiling.mlir @@ -14,4 +14,4 @@ func.func @matmul(%A: memref<256x512xf64>, %B: memref<512x256xf64>, %C: memref<2 outs(%C: memref<256x256xf64>) return } -// CHECK: `k#-1`: tile sizes should be strictly positive. +// CHECK: xtc.schedules.descript.ScheduleInterpretError: `k#-1`: -1 is not an integer. diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py new file mode 100644 index 00000000..a226cff3 --- /dev/null +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_sample.py @@ -0,0 +1,163 @@ +# RUN: python %s 2>&1 | filecheck %s + +import xtc.graphs.xtc.op as O +from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend +from xtc.schedules.descript_extend import descript_extend_scheduler + +I, J, K, dtype = 4, 32, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph, always_vectorize=False, no_alias=True) + +sch = impl.get_scheduler() +axes_sizes = {"i": I, "j": J, "k": K} +descript_extend_scheduler( + scheduler=sch, + node_name="C", + abstract_axis=["i", "j", "k"], + spec={ + "k": {}, + "i": {}, + "j": {}, + "i#i_inner": {"unroll": "i_unroll"}, + "j#j_inner": {"vectorize": "j_vectorize"}, + }, + abstract_axis_sizes=axes_sizes, + sample={"i_inner": 2, "j_inner": 16, "i_unroll": None, "j_vectorize": None}, +) + +sched = sch.schedule() + +comp = impl.get_compiler( + shared_lib=True, + 
dump_file="matmul_descript_extend_mlir_sample", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sched) +executor = module.get_executor(validate=True) +res = executor.execute() +print(f"CODE: {res}") + +#CHECK:// -----// IR Dump Before transform //----- // +#CHECK-NEXT: module attributes {transform.with_named_sequence} { +#CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<4x32xf32>) +#CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<4x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<4x32xf32>) +#CHECK-NEXT: return +#CHECK-NEXT: } +#CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +#CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op +#CHECK-NEXT: transform.yield +#CHECK-NEXT: } +#CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +#CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op +#CHECK-NEXT: %1 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %1 tile_sizes [0, 0, 1] : (!transform.any_op) -> 
(!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_3 "C/k" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_5 "C/i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_7 "C/j" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_9 "C/i0" : !transform.any_op +#CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_8) : (!transform.any_op) -> () +#CHECK-NEXT: transform.loop.unroll %loops_9 {factor = 2 : i64} : !transform.any_op +#CHECK-NEXT: %2 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %2 { +#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %2 { +#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.yield +#CHECK-NEXT: } +#CHECK-NEXT: } +#CHECK-EMPTY: +#CHECK-NEXT: // -----// IR Dump After transform //----- // +#CHECK-NEXT: module attributes {transform.with_named_sequence} { +#CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<4x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %cst = 
arith.constant dense<0.000000e+00> : vector<1x16xf32> +#CHECK-NEXT: %0 = ub.poison : f32 +#CHECK-NEXT: %c16 = arith.constant 16 : index +#CHECK-NEXT: %c2 = arith.constant 2 : index +#CHECK-NEXT: %c512 = arith.constant 512 : index +#CHECK-NEXT: %c32 = arith.constant 32 : index +#CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +#CHECK-NEXT: %c0 = arith.constant 0 : index +#CHECK-NEXT: %c4 = arith.constant 4 : index +#CHECK-NEXT: %c1 = arith.constant 1 : index +#CHECK-NEXT: scf.for %arg3 = %c0 to %c4 step %c1 { +#CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<4x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { +#CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) +#CHECK-NEXT: } {"./j"} +#CHECK-NEXT: } {"./i"} +#CHECK-NEXT: scf.for %arg3 = %c0 to %c512 step %c1 { +#CHECK-NEXT: %subview = memref.subview %arg0[0, %arg3] [4, 1] [1, 1] : memref<4x512xf32> to memref<4x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_1 = memref.subview %arg1[%arg3, 0] [1, 32] [1, 1] : memref<512x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_2 = memref.subview %arg2[0, 0] [4, 32] [1, 1] : memref<4x32xf32> to memref<4x32xf32, strided<[32, 1]>> +#CHECK-NEXT: scf.for %arg4 = %c0 to %c4 step %c2 { +#CHECK-NEXT: %subview_3 = memref.subview %subview[%arg4, 0] [2, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_4 = memref.subview %subview_2[%arg4, 0] [2, 32] [1, 1] : memref<4x32xf32, strided<[32, 1]>> to memref<2x32xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg5 = %c0 to %c32 step %c16 { +#CHECK-NEXT: %subview_5 = 
memref.subview %subview_1[0, %arg5] [1, 16] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_6 = memref.subview %subview_4[0, %arg5] [2, 16] [1, 1] : memref<2x32xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_7 = memref.subview %subview_3[%c0, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_8 = memref.subview %subview_6[%c0, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %1 = vector.transfer_read %subview_7[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %2 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %3 = vector.transfer_read %subview_8[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %4 = vector.extract %2[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<16xf32> +#CHECK-NEXT: %7 = vector.extract %3[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<16xf32> +#CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<16xf32> into vector<1x16xf32> +#CHECK-NEXT: vector.transfer_write %9, %subview_8[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_9 = memref.subview %subview_3[%c1, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_10 = memref.subview %subview_6[%c1, 0] [1, 16] [1, 1] : memref<2x16xf32, 
strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = vector.transfer_read %subview_9[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %11 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_10[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x16xf32, strided<[32, 1], offset: ?>>, vector<1x16xf32> +#CHECK-NEXT: %13 = vector.extract %11[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<16xf32> +#CHECK-NEXT: %16 = vector.extract %12[0] : vector<16xf32> from vector<1x16xf32> +#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<16xf32> +#CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<16xf32> into vector<1x16xf32> +#CHECK-NEXT: vector.transfer_write %18, %subview_10[%c0, %c0] {in_bounds = [true, true]} : vector<1x16xf32>, memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: } {"C/j"} +#CHECK-NEXT: } {"C/i"} +#CHECK-NEXT: } {"C/k"} +#CHECK-NEXT: return +#CHECK-NEXT: } +#CHECK-NEXT: } +#CHECK-EMPTY: +#CHECK-NEXT: graph: +#CHECK-NEXT: name: matmul +#CHECK-NEXT: inputs: +#CHECK-NEXT: - %0 : 4x512xfloat32 +#CHECK-NEXT: - %1 : 512x32xfloat32 +#CHECK-NEXT: outputs: +#CHECK-NEXT: - %2 : 4x32xfloat32 +#CHECK-NEXT: nodes: +#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] +#CHECK-EMPTY: +#CHECK-NEXT: CODE: 0 + diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py new file mode 100644 index 00000000..0e21fa5c --- /dev/null +++ b/tests/filecheck/schedules/test_matmul_descript_extend_mlir_split.py @@ -0,0 +1,241 @@ +# RUN: python %s 2>&1 | filecheck %s + 
+import xtc.graphs.xtc.op as O +from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend +from xtc.schedules.descript_extend import descript_extend_scheduler + +I, J, K, dtype = 16, 32, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph, always_vectorize=False, no_alias=True) + +sch = impl.get_scheduler() +axes_sizes = {"i": I, "j": J, "k": K} +descript_extend_scheduler( + scheduler=sch, + node_name="C", + abstract_axis=["i", "j", "k"], + abstract_axis_sizes=axes_sizes, + spec={ + "j": {}, + "k": {}, + "j#jDDR": {}, + "i[:4]": { + "i#iR1": {"unroll": None}, + "j#jR": {"vectorize": None}, + }, + "i[4:]": { + "i#iR2": {"unroll": None}, + "j#jR": {"vectorize": None}, + }, + }, + sample={"jDDR": 16, "jR": 4, "iR1": 2, "iR2": 4}, +) + +sched = sch.schedule() + +comp = impl.get_compiler( + shared_lib=True, + dump_file="matmul_descript_extend_mlir_split", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sched) +executor = module.get_executor(validate=True) +res = executor.execute() +print(f"CODE: {res}") + +#CHECK: // -----// IR Dump Before transform //----- // +#CHECK-NEXT: module attributes {transform.with_named_sequence} { +#CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<16x32xf32>) +#CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<16x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<16x32xf32>) +#CHECK-NEXT: return +#CHECK-NEXT: } +#CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) { +#CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op 
+#CHECK-NEXT: transform.yield +#CHECK-NEXT: } +#CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { +#CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op +#CHECK-NEXT: %1 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %1 tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_3 "C/j" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_5 "C/k" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [0, 4, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_7 "C/j0" : !transform.any_op +#CHECK-NEXT: %2 = transform.structured.split %tiled_linalg_op_6 after 4 {dimension = 0 : i64} : !transform.any_op +#CHECK-NEXT: %3:2 = transform.split_handle %2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %3#0 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, 
!transform.any_op) +#CHECK-NEXT: transform.annotate %loops_9 "C/i[0]/i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_10, %loops_11 = transform.structured.tile_using_for %tiled_linalg_op_8 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_11 "C/i[0]/i0" : !transform.any_op +#CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_10) : (!transform.any_op) -> () +#CHECK-NEXT: transform.loop.unroll %loops_11 {factor = 2 : i64} : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_12, %loops_13 = transform.structured.tile_using_for %3#1 tile_sizes [4, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_13 "C/i[1]/i" : !transform.any_op +#CHECK-NEXT: %tiled_linalg_op_14, %loops_15 = transform.structured.tile_using_for %tiled_linalg_op_12 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +#CHECK-NEXT: transform.annotate %loops_15 "C/i[1]/i0" : !transform.any_op +#CHECK-NEXT: transform.include @_vecto failures(suppress) (%tiled_linalg_op_14) : (!transform.any_op) -> () +#CHECK-NEXT: transform.loop.unroll %loops_15 {factor = 4 : i64} : !transform.any_op +#CHECK-NEXT: %4 = transform.get_parent_op %loops_3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %4 { +#CHECK-NEXT: transform.apply_patterns.vector.reduction_to_contract +#CHECK-NEXT: transform.apply_patterns.vector.transfer_permutation_patterns +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.apply_patterns to %4 { +#CHECK-NEXT: transform.apply_patterns.vector.lower_outerproduct +#CHECK-NEXT: transform.apply_patterns.vector.lower_contraction +#CHECK-NEXT: } : !transform.any_op +#CHECK-NEXT: transform.yield +#CHECK-NEXT: } +#CHECK-NEXT: } +#CHECK-EMPTY: +#CHECK-NEXT: // -----// IR Dump After transform //----- // +#CHECK-NEXT: module attributes 
{transform.with_named_sequence} { +#CHECK-NEXT: func.func @matmul(%arg0: memref<16x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias}, %arg2: memref<16x32xf32> {llvm.noalias}) { +#CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : vector<1x4xf32> +#CHECK-NEXT: %c3 = arith.constant 3 : index +#CHECK-NEXT: %c12 = arith.constant 12 : index +#CHECK-NEXT: %0 = ub.poison : f32 +#CHECK-NEXT: %c2 = arith.constant 2 : index +#CHECK-NEXT: %c4 = arith.constant 4 : index +#CHECK-NEXT: %c512 = arith.constant 512 : index +#CHECK-NEXT: %c32 = arith.constant 32 : index +#CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +#CHECK-NEXT: %c0 = arith.constant 0 : index +#CHECK-NEXT: %c16 = arith.constant 16 : index +#CHECK-NEXT: %c1 = arith.constant 1 : index +#CHECK-NEXT: scf.for %arg3 = %c0 to %c16 step %c1 { +#CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<16x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg4 = %c0 to %c32 step %c1 { +#CHECK-NEXT: %subview_1 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst_0 : f32) outs(%subview_1 : memref<1x1xf32, strided<[32, 1], offset: ?>>) +#CHECK-NEXT: } {"./j"} +#CHECK-NEXT: } {"./i"} +#CHECK-NEXT: scf.for %arg3 = %c0 to %c32 step %c16 { +#CHECK-NEXT: %subview = memref.subview %arg0[0, 0] [16, 512] [1, 1] : memref<16x512xf32> to memref<16x512xf32, strided<[512, 1]>> +#CHECK-NEXT: %subview_1 = memref.subview %arg1[0, %arg3] [512, 16] [1, 1] : memref<512x32xf32> to memref<512x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_2 = memref.subview %arg2[0, %arg3] [16, 16] [1, 1] : memref<16x32xf32> to memref<16x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg4 = %c0 to %c512 step %c1 { +#CHECK-NEXT: %subview_3 = memref.subview %subview[0, %arg4] [16, 1] [1, 1] : memref<16x512xf32, 
strided<[512, 1]>> to memref<16x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_4 = memref.subview %subview_1[%arg4, 0] [1, 16] [1, 1] : memref<512x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg5 = %c0 to %c16 step %c4 { +#CHECK-NEXT: %subview_5 = memref.subview %subview_4[0, %arg5] [1, 4] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_6 = memref.subview %subview_2[0, %arg5] [16, 4] [1, 1] : memref<16x16xf32, strided<[32, 1], offset: ?>> to memref<16x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_7 = memref.subview %subview_3[0, 0] [4, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<4x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_8 = memref.subview %subview_6[0, 0] [4, 4] [1, 1] : memref<16x4xf32, strided<[32, 1], offset: ?>> to memref<4x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg6 = %c0 to %c4 step %c2 { +#CHECK-NEXT: %subview_11 = memref.subview %subview_7[%arg6, 0] [2, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_12 = memref.subview %subview_8[%arg6, 0] [2, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<2x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_13 = memref.subview %subview_11[%c0, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_14 = memref.subview %subview_12[%c0, 0] [1, 4] [1, 1] : memref<2x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %1 = vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %2 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : 
memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %3 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %4 = vector.extract %2[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %6 = vector.broadcast %5 : f32 to vector<4xf32> +#CHECK-NEXT: %7 = vector.extract %3[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<4xf32> +#CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %9, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_15 = memref.subview %subview_11[%c1, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_16 = memref.subview %subview_12[%c1, 0] [1, 4] [1, 1] : memref<2x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = vector.transfer_read %subview_15[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %11 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_16[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %13 = vector.extract %11[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<4xf32> +#CHECK-NEXT: %16 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<4xf32> +#CHECK-NEXT: %18 = 
vector.insert %17, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %18, %subview_16[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: } {"C/i[0]/i"} +#CHECK-NEXT: %subview_9 = memref.subview %subview_3[4, 0] [12, 1] [1, 1] : memref<16x1xf32, strided<[512, 1], offset: ?>> to memref<12x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_10 = memref.subview %subview_6[4, 0] [12, 4] [1, 1] : memref<16x4xf32, strided<[32, 1], offset: ?>> to memref<12x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: scf.for %arg6 = %c0 to %c12 step %c4 { +#CHECK-NEXT: %subview_11 = memref.subview %subview_9[%arg6, 0] [4, 1] [1, 1] : memref<12x1xf32, strided<[512, 1], offset: ?>> to memref<4x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_12 = memref.subview %subview_10[%arg6, 0] [4, 4] [1, 1] : memref<12x4xf32, strided<[32, 1], offset: ?>> to memref<4x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_13 = memref.subview %subview_11[%c0, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_14 = memref.subview %subview_12[%c0, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %1 = vector.transfer_read %subview_13[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %2 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %3 = vector.transfer_read %subview_14[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %4 = vector.extract %2[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %5 = vector.extract %1[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %6 = vector.broadcast %5 
: f32 to vector<4xf32> +#CHECK-NEXT: %7 = vector.extract %3[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %8 = vector.fma %6, %4, %7 : vector<4xf32> +#CHECK-NEXT: %9 = vector.insert %8, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %9, %subview_14[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_15 = memref.subview %subview_11[%c1, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_16 = memref.subview %subview_12[%c1, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %10 = vector.transfer_read %subview_15[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %11 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %12 = vector.transfer_read %subview_16[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %13 = vector.extract %11[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %14 = vector.extract %10[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %15 = vector.broadcast %14 : f32 to vector<4xf32> +#CHECK-NEXT: %16 = vector.extract %12[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %17 = vector.fma %15, %13, %16 : vector<4xf32> +#CHECK-NEXT: %18 = vector.insert %17, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %18, %subview_16[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_17 = memref.subview %subview_11[%c2, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: 
%subview_18 = memref.subview %subview_12[%c2, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %19 = vector.transfer_read %subview_17[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %20 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %21 = vector.transfer_read %subview_18[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %22 = vector.extract %20[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %23 = vector.extract %19[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %24 = vector.broadcast %23 : f32 to vector<4xf32> +#CHECK-NEXT: %25 = vector.extract %21[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %26 = vector.fma %24, %22, %25 : vector<4xf32> +#CHECK-NEXT: %27 = vector.insert %26, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %27, %subview_18[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %subview_19 = memref.subview %subview_11[%c3, 0] [1, 1] [1, 1] : memref<4x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>> +#CHECK-NEXT: %subview_20 = memref.subview %subview_12[%c3, 0] [1, 4] [1, 1] : memref<4x4xf32, strided<[32, 1], offset: ?>> to memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: %28 = vector.transfer_read %subview_19[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x1xf32, strided<[512, 1], offset: ?>>, vector<1x1xf32> +#CHECK-NEXT: %29 = vector.transfer_read %subview_5[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %30 = vector.transfer_read %subview_20[%c0, %c0], %0 {in_bounds = [true, true]} : memref<1x4xf32, 
strided<[32, 1], offset: ?>>, vector<1x4xf32> +#CHECK-NEXT: %31 = vector.extract %29[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %32 = vector.extract %28[0, 0] : f32 from vector<1x1xf32> +#CHECK-NEXT: %33 = vector.broadcast %32 : f32 to vector<4xf32> +#CHECK-NEXT: %34 = vector.extract %30[0] : vector<4xf32> from vector<1x4xf32> +#CHECK-NEXT: %35 = vector.fma %33, %31, %34 : vector<4xf32> +#CHECK-NEXT: %36 = vector.insert %35, %cst [0] : vector<4xf32> into vector<1x4xf32> +#CHECK-NEXT: vector.transfer_write %36, %subview_20[%c0, %c0] {in_bounds = [true, true]} : vector<1x4xf32>, memref<1x4xf32, strided<[32, 1], offset: ?>> +#CHECK-NEXT: } {"C/i[1]/i"} +#CHECK-NEXT: } {"C/j0"} +#CHECK-NEXT: } {"C/k"} +#CHECK-NEXT: } {"C/j"} +#CHECK-NEXT: return +#CHECK-NEXT: } +#CHECK-NEXT: } +#CHECK-EMPTY: +#CHECK-NEXT: graph: +#CHECK-NEXT: name: matmul +#CHECK-NEXT: inputs: +#CHECK-NEXT: - %0 : 16x512xfloat32 +#CHECK-NEXT: - %1 : 512x32xfloat32 +#CHECK-NEXT: outputs: +#CHECK-NEXT: - %2 : 16x32xfloat32 +#CHECK-NEXT: nodes: +#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [16x512xfloat32, 512x32xfloat32] -> [16x32xfloat32] +#CHECK-EMPTY: +#CHECK-NEXT: CODE: 0 + diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py new file mode 100644 index 00000000..cbe63694 --- /dev/null +++ b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_goto.py @@ -0,0 +1,207 @@ +# RUN: python %s 2>&1 | filecheck %s +# REQUIRES: module_tvm + +import xtc.graphs.xtc.op as O +from xtc.backends.tvm import Backend +from xtc.schedules.descript_extend import descript_extend_scheduler + +I, J, K, dtype = 512, 512, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph, always_vectorize=False, no_alias=True) + +sch = impl.get_scheduler() +axes_sizes 
= {"i": I, "j": J, "k": K} +descript_extend_scheduler( + scheduler=sch, + node_name="C", + abstract_axis=["i", "j", "k"], + abstract_axis_sizes=axes_sizes, + abstract_matrix=["A", "B", "C"], + spec={ + "j": {"parallelize": "par"}, + "k": {}, + "i": {}, + "B": {"bufferize": "pack_B"}, + "A": {"bufferize": "pack_A"}, + "j#jL3": {}, + "i#iL2": {}, + "k#kL1": {"unroll": "k_unroll"}, + "i#iR": {"unroll": None}, "j#jR": {"vectorize": None}, + }, + sample={ + "par": None, + "jL3": 36, + "iL2": 128, + "kL1": 16, + "k_unroll": 2, + "iR": 2, + "jR": 6, + "pack_B": None, + "pack_A": None, + }, +) + +sched = sch.schedule() + +comp = impl.get_compiler( + shared_lib=True, + dump_file="matmul_descript_extend_tvm_goto", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sched) +executor = module.get_executor(validate=True) +res = executor.execute() +print(f"CODE: {res}") + +#CHECK: graph: +#CHECK-NEXT: name: matmul +#CHECK-NEXT: inputs: +#CHECK-NEXT: - %0 : 512x512xfloat32 +#CHECK-NEXT: - %1 : 512x512xfloat32 +#CHECK-NEXT: outputs: +#CHECK-NEXT: - %2 : 512x512xfloat32 +#CHECK-NEXT: nodes: +#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [512x512xfloat32, 512x512xfloat32] -> [512x512xfloat32] +#CHECK-EMPTY: +#CHECK-NEXT:# from tvm.script import ir as I +#CHECK-NEXT:# from tvm.script import tir as T +#CHECK-EMPTY: +#CHECK-NEXT:@I.ir_module +#CHECK-NEXT:class Module: +#CHECK-NEXT: @T.prim_func +#CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): +#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +#CHECK-NEXT: for i, j in T.grid(512, 512): +#CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) +#CHECK-NEXT: C_1[i * 512 + j] = T.float32(0.0) +#CHECK-NEXT: for k in range(512): +#CHECK-NEXT: cse_var_2: T.int32 = i * 512 +#CHECK-NEXT: cse_var_1: T.int32 = cse_var_2 + j +#CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) +#CHECK-NEXT: 
_1_1 = T.Buffer((262144,), data=_1.data) +#CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_1[cse_var_2 + k] * _1_1[k * 512 + j] +#CHECK-NEXT:O = obj['C'] +#CHECK-NEXT:i, j, = O.op.axis +#CHECK-NEXT:k, = O.op.reduce_axis +#CHECK-NEXT:j, j0 = sch[O].split(j, factor=36) +#CHECK-NEXT:i, i0 = sch[O].split(i, factor=128) +#CHECK-NEXT:k, k0 = sch[O].split(k, factor=16) +#CHECK-NEXT:k0, __u_k0 = sch[O].split(k0, factor=2) +#CHECK-NEXT:i0, i1 = sch[O].split(i0, factor=2) +#CHECK-NEXT:j0, j1 = sch[O].split(j0, factor=6) +#CHECK-NEXT:j1, __v_j1 = sch[O].split(j1, factor=2) +#CHECK-NEXT:sch[O].reorder(j, k, i, j0, i0, k0, __u_k0, i1, j1, __v_j1) +#CHECK-NEXT:sch[O].unroll(__u_k0) +#CHECK-NEXT:sch[O].unroll(i1) +#CHECK-NEXT:sch[O].unroll(j1) +#CHECK-NEXT:sch[O].vectorize(__v_j1) +#CHECK-NEXT:sch[O].parallel(j) +#CHECK-EMPTY: +#CHECK-NEXT:# from tvm.script import ir as I +#CHECK-NEXT:# from tvm.script import tir as T +#CHECK-EMPTY: +#CHECK-NEXT:@I.ir_module +#CHECK-NEXT:class Module: +#CHECK-NEXT: @T.prim_func +#CHECK-NEXT: def main(_0: T.Buffer((512, 512), "float32"), _1: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): +#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +#CHECK-NEXT: for j_outer in T.parallel(15): +#CHECK-NEXT: C_1 = T.Buffer((262144,), data=C.data) +#CHECK-NEXT: for i_outer_init, j_inner_outer_init, i_inner_outer_init in T.grid(4, 6, 64): +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer_init * 3 // 2 < 128): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + 1) // 2 < 128): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 2:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 
36 + j_inner_outer_init * 6 + 2 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer_init * 3 // 2 < 127): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 4:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 4 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer_init * 3 // 2 < 128): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 512:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 512 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer_init * 3 + 1) // 2 < 128): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 514:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 514 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer_init * 3 // 2 < 127): +#CHECK-NEXT: C_1[i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 516:i_outer_init * 65536 + i_inner_outer_init * 1024 + j_outer * 36 + j_inner_outer_init * 6 + 516 + 2] = T.Broadcast(T.float32(0.0), 2) +#CHECK-NEXT: for k_outer, i_outer, j_inner_outer, i_inner_outer, k_inner_outer in T.grid(32, 4, 6, 64, 8): +#CHECK-NEXT: _0_1 = T.Buffer((262144,), data=_0.data) +#CHECK-NEXT: _1_1 = T.Buffer((262144,), data=_1.data) +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_4: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_3: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_2: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_1: T.int32 = cse_var_2 + cse_var_4 + cse_var_3 +#CHECK-NEXT: C_1[cse_var_1:cse_var_1 + 2] = C_1[cse_var_1:cse_var_1 + 2] + T.Broadcast(_0_1[cse_var_2 + 
k_outer * 16 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_4 + cse_var_3:k_outer * 8192 + k_inner_outer * 1024 + cse_var_4 + cse_var_3 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_8: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_7: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_6: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_5: T.int32 = cse_var_6 + cse_var_8 + cse_var_7 + 2 +#CHECK-NEXT: C_1[cse_var_5:cse_var_5 + 2] = C_1[cse_var_5:cse_var_5 + 2] + T.Broadcast(_0_1[cse_var_6 + k_outer * 16 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_8 + cse_var_7 + 2:k_outer * 8192 + k_inner_outer * 1024 + cse_var_8 + cse_var_7 + 2 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_12: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_11: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_10: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_9: T.int32 = cse_var_10 + cse_var_12 + cse_var_11 + 4 +#CHECK-NEXT: C_1[cse_var_9:cse_var_9 + 2] = C_1[cse_var_9:cse_var_9 + 2] + T.Broadcast(_0_1[cse_var_10 + k_outer * 16 + k_inner_outer * 2], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11 + 4:k_outer * 8192 + k_inner_outer * 1024 + cse_var_12 + cse_var_11 + 4 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_16: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_15: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_14: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_13: T.int32 = cse_var_14 + cse_var_16 + cse_var_15 + 512 +#CHECK-NEXT: C_1[cse_var_13:cse_var_13 + 2] = C_1[cse_var_13:cse_var_13 + 2] + T.Broadcast(_0_1[cse_var_14 + k_outer * 16 + k_inner_outer * 2 + 512], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_16 + cse_var_15:k_outer * 8192 + k_inner_outer * 1024 + 
cse_var_16 + cse_var_15 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_20: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_19: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_18: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_17: T.int32 = cse_var_18 + cse_var_20 + cse_var_19 + 514 +#CHECK-NEXT: C_1[cse_var_17:cse_var_17 + 2] = C_1[cse_var_17:cse_var_17 + 2] + T.Broadcast(_0_1[cse_var_18 + k_outer * 16 + k_inner_outer * 2 + 512], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_20 + cse_var_19 + 2:k_outer * 8192 + k_inner_outer * 1024 + cse_var_20 + cse_var_19 + 2 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_24: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_23: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_22: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_21: T.int32 = cse_var_22 + cse_var_24 + cse_var_23 + 516 +#CHECK-NEXT: C_1[cse_var_21:cse_var_21 + 2] = C_1[cse_var_21:cse_var_21 + 2] + T.Broadcast(_0_1[cse_var_22 + k_outer * 16 + k_inner_outer * 2 + 512], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_24 + cse_var_23 + 4:k_outer * 8192 + k_inner_outer * 1024 + cse_var_24 + cse_var_23 + 4 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_28: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_27: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_26: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_25: T.int32 = cse_var_26 + cse_var_28 + cse_var_27 +#CHECK-NEXT: C_1[cse_var_25:cse_var_25 + 2] = C_1[cse_var_25:cse_var_25 + 2] + T.Broadcast(_0_1[cse_var_26 + k_outer * 16 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_28 + cse_var_27 + 512:k_outer * 8192 + k_inner_outer * 1024 + cse_var_28 + cse_var_27 + 512 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): 
+#CHECK-NEXT: cse_var_32: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_31: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_30: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_29: T.int32 = cse_var_30 + cse_var_32 + cse_var_31 + 2 +#CHECK-NEXT: C_1[cse_var_29:cse_var_29 + 2] = C_1[cse_var_29:cse_var_29 + 2] + T.Broadcast(_0_1[cse_var_30 + k_outer * 16 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_32 + cse_var_31 + 514:k_outer * 8192 + k_inner_outer * 1024 + cse_var_32 + cse_var_31 + 514 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_36: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_35: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_34: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_33: T.int32 = cse_var_34 + cse_var_36 + cse_var_35 + 4 +#CHECK-NEXT: C_1[cse_var_33:cse_var_33 + 2] = C_1[cse_var_33:cse_var_33 + 2] + T.Broadcast(_0_1[cse_var_34 + k_outer * 16 + k_inner_outer * 2 + 1], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_36 + cse_var_35 + 516:k_outer * 8192 + k_inner_outer * 1024 + cse_var_36 + cse_var_35 + 516 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 128): +#CHECK-NEXT: cse_var_40: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_39: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_38: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_37: T.int32 = cse_var_38 + cse_var_40 + cse_var_39 + 512 +#CHECK-NEXT: C_1[cse_var_37:cse_var_37 + 2] = C_1[cse_var_37:cse_var_37 + 2] + T.Broadcast(_0_1[cse_var_38 + k_outer * 16 + k_inner_outer * 2 + 513], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_40 + cse_var_39 + 512:k_outer * 8192 + k_inner_outer * 1024 + cse_var_40 + cse_var_39 + 512 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + (j_inner_outer * 3 + 1) // 2 < 128): +#CHECK-NEXT: cse_var_44: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_43: T.int32 = j_inner_outer 
* 6 +#CHECK-NEXT: cse_var_42: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_41: T.int32 = cse_var_42 + cse_var_44 + cse_var_43 + 514 +#CHECK-NEXT: C_1[cse_var_41:cse_var_41 + 2] = C_1[cse_var_41:cse_var_41 + 2] + T.Broadcast(_0_1[cse_var_42 + k_outer * 16 + k_inner_outer * 2 + 513], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_44 + cse_var_43 + 514:k_outer * 8192 + k_inner_outer * 1024 + cse_var_44 + cse_var_43 + 514 + 2] +#CHECK-NEXT: if T.likely(j_outer * 9 + j_inner_outer * 3 // 2 < 127): +#CHECK-NEXT: cse_var_48: T.int32 = j_outer * 36 +#CHECK-NEXT: cse_var_47: T.int32 = j_inner_outer * 6 +#CHECK-NEXT: cse_var_46: T.int32 = i_outer * 65536 + i_inner_outer * 1024 +#CHECK-NEXT: cse_var_45: T.int32 = cse_var_46 + cse_var_48 + cse_var_47 + 516 +#CHECK-NEXT: C_1[cse_var_45:cse_var_45 + 2] = C_1[cse_var_45:cse_var_45 + 2] + T.Broadcast(_0_1[cse_var_46 + k_outer * 16 + k_inner_outer * 2 + 513], 2) * _1_1[k_outer * 8192 + k_inner_outer * 1024 + cse_var_48 + cse_var_47 + 516:k_outer * 8192 + k_inner_outer * 1024 + cse_var_48 + cse_var_47 + 516 + 2] +#CHECK:CODE: 0 diff --git a/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py new file mode 100644 index 00000000..0429b060 --- /dev/null +++ b/tests/filecheck/schedules/test_matmul_descript_extend_tvm_strategy.py @@ -0,0 +1,108 @@ +# RUN: python -O %s 2>&1 | filecheck %s + +import xtc.graphs.xtc.op as O +from xtc.backends.tvm import Backend +# from xtc.itf.search import strategy +# from xtc.schedules.descript import descript_tree_scheduler +from xtc.search.strategies import Strategy_Descript as Strategy + +I, J, K, dtype = 4, 32, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph +print(graph) + +impl = Backend(graph, always_vectorize=False, no_alias=True) + +sch = 
impl.get_scheduler() + +spec = { + "k": {}, + "i": {}, + "j": {}, + "i#i_inner": {"unroll": "i_unroll"}, + "j#j_inner": {"vectorize": "j_vectorize"}, +} + +sample = {"i_inner": 2, "j_inner": 16, "i_unroll": 2, "j_vectorize": None, "order_R": ["j", "i"]} + +strategy = Strategy(graph, spec) + +strategy.generate(sch, sample) + +sched = sch.schedule() + +comp = impl.get_compiler( + shared_lib=True, + dump_file="matmul_descript_extend_tvm_strategy", + print_source_ir=True, + print_transformed_ir=True, +) +module = comp.compile(sched) +executor = module.get_executor(validate=True) +res = executor.execute() +print(f"CODE: {res}") + +#CHECK:graph: +#CHECK-NEXT: name: matmul +#CHECK-NEXT: inputs: +#CHECK-NEXT: - %0 : 4x512xfloat32 +#CHECK-NEXT: - %1 : 512x32xfloat32 +#CHECK-NEXT: outputs: +#CHECK-NEXT: - %2 : 4x32xfloat32 +#CHECK-NEXT: nodes: +#CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32] +#CHECK-EMPTY: +#CHECK-NEXT:# from tvm.script import ir as I +#CHECK-NEXT:# from tvm.script import tir as T +#CHECK-EMPTY: +#CHECK-NEXT:@I.ir_module +#CHECK-NEXT:class Module: +#CHECK-NEXT: @T.prim_func +#CHECK-NEXT: def main(_0: T.Buffer((4, 512), "float32"), _1: T.Buffer((512, 32), "float32"), C: T.Buffer((4, 32), "float32")): +#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +#CHECK-NEXT: for i, j in T.grid(4, 32): +#CHECK-NEXT: C_1 = T.Buffer((128,), data=C.data) +#CHECK-NEXT: C_1[i * 32 + j] = T.float32(0.0) +#CHECK-NEXT: for k in range(512): +#CHECK-NEXT: cse_var_1: T.int32 = i * 32 + j +#CHECK-NEXT: _0_1 = T.Buffer((2048,), data=_0.data) +#CHECK-NEXT: _1_1 = T.Buffer((16384,), data=_1.data) +#CHECK-NEXT: C_1[cse_var_1] = C_1[cse_var_1] + _0_1[i * 512 + k] * _1_1[k * 32 + j] +#CHECK-NEXT:O = obj['C'] +#CHECK-NEXT:i, j, = O.op.axis +#CHECK-NEXT:k, = O.op.reduce_axis +#CHECK-NEXT:i, i0 = sch[O].split(i, factor=2) +#CHECK-NEXT:j, j0 = sch[O].split(j, factor=16) 
+#CHECK-NEXT:sch[O].reorder(k, i, j, i0, j0) +#CHECK-NEXT:sch[O].unroll(i0) +#CHECK-NEXT:sch[O].vectorize(j0) +#CHECK-EMPTY: +#CHECK-NEXT:# from tvm.script import ir as I +#CHECK-NEXT:# from tvm.script import tir as T +#CHECK-EMPTY: +#CHECK-NEXT:@I.ir_module +#CHECK-NEXT:class Module: +#CHECK-NEXT: @T.prim_func +#CHECK-NEXT: def main(_0: T.Buffer((4, 512), "float32"), _1: T.Buffer((512, 32), "float32"), C: T.Buffer((4, 32), "float32")): +#CHECK-NEXT: T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)}) +#CHECK-NEXT: C_1 = T.Buffer((128,), data=C.data) +#CHECK-NEXT: for i_outer_init, j_outer_init in T.grid(2, 2): +#CHECK-NEXT: cse_var_1: T.int32 = i_outer_init * 64 + j_outer_init * 16 +#CHECK-NEXT: C_1[cse_var_1:cse_var_1 + 16] = T.Broadcast(T.float32(0.0), 16) +#CHECK-NEXT: C_1[cse_var_1 + 32:cse_var_1 + 32 + 16] = T.Broadcast(T.float32(0.0), 16) +#CHECK-NEXT: for k, i_outer, j_outer in T.grid(512, 2, 2): +#CHECK-NEXT: cse_var_6: T.int32 = j_outer * 16 +#CHECK-NEXT: cse_var_5: T.int32 = i_outer * 1024 + k +#CHECK-NEXT: cse_var_4: T.int32 = k * 32 + cse_var_6 +#CHECK-NEXT: cse_var_3: T.int32 = i_outer * 64 + cse_var_6 +#CHECK-NEXT: cse_var_2: T.int32 = cse_var_3 + 32 +#CHECK-NEXT: _0_1 = T.Buffer((2048,), data=_0.data) +#CHECK-NEXT: _1_1 = T.Buffer((16384,), data=_1.data) +#CHECK-NEXT: C_1[cse_var_3:cse_var_3 + 16] = C_1[cse_var_3:cse_var_3 + 16] + T.Broadcast(_0_1[cse_var_5], 16) * _1_1[cse_var_4:cse_var_4 + 16] +#CHECK-NEXT: C_1[cse_var_2:cse_var_2 + 16] = C_1[cse_var_2:cse_var_2 + 16] + T.Broadcast(_0_1[cse_var_5 + 512], 16) * _1_1[cse_var_4:cse_var_4 + 16] +#CHECK-NEXT:CODE: 0 diff --git a/tests/filecheck/search/test_matmul_descript_3axes.py b/tests/filecheck/search/test_matmul_descript_3axes.py new file mode 100644 index 00000000..b50d5b37 --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_3axes.py @@ -0,0 +1,23 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test strategy 3-axis on matmul +""" + +import utils +from 
xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph, backend="tvm") +spec = { +    "j": {}, +    "k": {}, +    "i": {}, +    "j#jR": {}, +    "k#kR": {}, +    "i#iR": {}, +} +strategy = Strategy(graph, spec, initialize=False) + +print(strategy._constraints) + +# CHECK:['iR || {21}', 'jR || {32}', 'kR || {12}'] diff --git a/tests/filecheck/search/test_matmul_descript_goto.py b/tests/filecheck/search/test_matmul_descript_goto.py new file mode 100644 index 00000000..0c45e76c --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_goto.py @@ -0,0 +1,27 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test strategy Goto on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = { +    "j": {"parallelize": "j_parallel"}, +    "k": {}, +    "i": {}, +    "pack": ("pack_B", 1, True), +    "pack": ("pack_A", 0, True), +    "j#jL3": {}, +    "i#iL2": {}, +    "k#kL1": {"unroll": "k_unroll"}, +    "i#iR": {"unroll": None}, "j#jR": {"vectorize": "j_vectorise"} +} +constraint = ["iR * jR <= 56"] +strategy = Strategy(graph, spec, constraints=constraint, initialize=False) + +print(strategy._constraints) + +# CHECK: ['iL2 || {21}', 'iR * jR <= 56', 'iR || {21, iL2}', 'jL3 || {32}', 'jR || {32, jL3}', 'j_parallel in {0, 1}', 'j_vectorise in {0, 1}', 'kL1 || {12}', 'k_unroll || kL1', 'pack_A in {0, 1}'] diff --git a/tests/filecheck/search/test_matmul_descript_simple.py b/tests/filecheck/search/test_matmul_descript_simple.py new file mode 100644 index 00000000..5afcb94b --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_simple.py @@ -0,0 +1,24 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test strategy simple on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy +from xtc.schedules.descript_extend import DescriptExtend +graph = utils.get_graph_matmul() +backend = 
utils.get_backend(graph) +spec = { + "k": {}, + "i": {}, + "j": {}, + "i#i1": {}, + "j#j1": {}, + "j#j2": {} +} + +strategy = Strategy(graph, spec, initialize=False) + +print(strategy._constraints) + +# CHECK: ['i1 || {21}', 'j1 || {32}', 'j2 || {32, j1}'] diff --git a/tests/filecheck/search/test_matmul_descript_split.py b/tests/filecheck/search/test_matmul_descript_split.py new file mode 100644 index 00000000..881111ff --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_split.py @@ -0,0 +1,31 @@ +# RUN: python %s 2>&1 | filecheck %s +""" +Test splits on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = { + "j": {}, + "k": {}, + "i": {}, + "i#iL3": {}, + "i#7": {}, + "j#jDDR": {}, + "i[:5]": { + "i#iR1": {"unroll": None}, + "j#jR1": {"parallelize": None}, + }, + "i[5:]": { + "i#iR2": {"unroll": None}, + "j#jR2": {"parallelize": None}, + }, +} +strategy = Strategy(graph, spec, initialize=False) + +print(strategy._constraints) + +# CHECK: ['iL3 || {21}', 'iR1 || {21, iL3, 7}', 'iR2 || {21, iL3, 7}', 'jDDR || {32}', 'jR1 || {32, jDDR}', 'jR2 || {32, jDDR}'] diff --git a/tests/filecheck/search/test_matmul_descript_split_in_split.py b/tests/filecheck/search/test_matmul_descript_split_in_split.py new file mode 100644 index 00000000..40569f97 --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_split_in_split.py @@ -0,0 +1,34 @@ +# RUN: python -O %s 2>&1 | filecheck %s +""" +Test splits on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = { + "j": {}, + "k": {}, + "i": {}, + "i#iL2": {}, + "j#jDDR": {}, + "i[:6]": { + "i#3": {}, + "i[:2:]": { + "i#iR1": {"unroll": None}, + "j#jR1": {"vectorize": None}, + }, + "i[:iS:]": {"i#iR3": {}, "j#jR3": {}}, + }, + "i[6:]": { + "i#iR2": {"unroll": None}, + "j#jR2": 
{"vectorize": None}, + }, +} +strategy = Strategy(graph, spec, initialize=False) + +print(strategy._constraints) + +# CHECK: ['iL2 || {21}', 'iR1 || {21, iL2, 3}', 'iR2 || {21, iL2}', 'iR3 || {21, iL2, 3}', 'iS + 2 == 3', 'i_1_ + 6 == iL2', 'i_1_ <= iL2', 'jDDR || {32}', 'jR1 || {32, jDDR}', 'jR2 || {32, jDDR}', 'jR3 || {32, jDDR}'] diff --git a/tests/filecheck/search/test_matmul_descript_yaml_goto.py b/tests/filecheck/search/test_matmul_descript_yaml_goto.py new file mode 100644 index 00000000..ebe57ee1 --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_yaml_goto.py @@ -0,0 +1,61 @@ +# RUN: python -O %s 2>&1 | filecheck %s +""" +Test strategy Goto on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +import xtc.graphs.xtc.op as O + +graph = utils.get_graph_matmul() +I, J, K, dtype = 1024, 1024, 1024, "float32" + +a = O.tensor((I, K), dtype) +b = O.tensor((K, J), dtype) + +with O.graph(name="matmul") as gb: + O.matmul(a, b) +graph = gb.graph +backend = utils.get_backend(graph, "tvm") + +spec = """ + j: + k: + B: bufferize + i: + A: bufferize + j#nc: + i#mc: + k#kc: unroll=kr + i#mr: unroll full + j#nr: vectorize full +""" + +nb_registers = 32 +nb_fma = 2 +fma_latency = 4 +ilp = nb_fma*fma_latency +vector_size = 16 +elt_size = 4 +reorder_buffer = 256 +nb_words_L1 = 32*1024//elt_size +nb_words_L2 = 1024*1024//elt_size +nb_words_L3 = 36*1024*1024//elt_size + +constraints = [ +f"1 + nvr + nvr * mr <= {nb_registers}", +f"nr == {vector_size} * nvr", +f"nvr * mr >= {ilp}", +f"nvr * mr * kr <= {reorder_buffer}", +f"kc * nr <= {nb_words_L1}", +f"kc * mc <= {nb_words_L2}", +f"kc * nc <= {nb_words_L3}", +] +strategy = Strategy(graph, spec, constraints=constraints, partial_tiles=True, partial_unrolls=True, initialize=False) + +print(strategy._constraints) +print(len(list(strategy.sample(100)))) + +# CHECK: ['1 + nvr + nvr * mr <= 32', 'kc * mc <= 262144', 'kc * nc <= 9437184', 'kc * nr <= 8192', 'kc <= 1024', 'kr <= 
kc', 'mc <= 1024', 'mr || {1024, mc}', 'nc <= 1024', 'nr == 16 * nvr', 'nr || {1024, nc}', 'nvr * mr * kr <= 256', 'nvr * mr >= 8'] +#CHECK-NEXT: 100 diff --git a/tests/filecheck/search/test_matmul_descript_yaml_simple.py b/tests/filecheck/search/test_matmul_descript_yaml_simple.py new file mode 100644 index 00000000..e2fa5e17 --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_yaml_simple.py @@ -0,0 +1,25 @@ +# RUN: python -O %s 2>&1 | filecheck %s +""" +Test strategy simple on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = """ +    k: +    i: +    j: +    i#i1: +    j#j1: +    j#j2: +""" +strategy = Strategy(graph, spec) + +print(strategy._constraints) +print(len(list(strategy.sample(100)))) + +# CHECK: ['i1 || {21}', 'j1 || {32}', 'j2 || {32, j1}'] +# CHECK-NEXT: 84 diff --git a/tests/filecheck/search/test_matmul_descript_yaml_split.py b/tests/filecheck/search/test_matmul_descript_yaml_split.py new file mode 100644 index 00000000..d07957c5 --- /dev/null +++ b/tests/filecheck/search/test_matmul_descript_yaml_split.py @@ -0,0 +1,32 @@ +# RUN: python %s -O 2>&1 | filecheck %s +""" +Test splits on matmul +""" + +import utils +from xtc.search.strategies import Strategy_Descript as Strategy + +graph = utils.get_graph_matmul() +backend = utils.get_backend(graph) +spec = """ +    j: +    k: +    i: +    i#iL3: +    i#iL2: +    j#jDDR: +    i[:iS]: +        i#iR1: unroll +        j#jR1: vectorize +        k#SR: +    i[iS:]: +        i#iR2: unroll +        j#jR2: unroll +""" +strategy = Strategy(graph, spec) + +print(strategy._constraints) +print(len(list(strategy.sample(100)))) + +# CHECK: ['SR || {12}', 'iL2 || {21, iL3}', 'iL3 || {21}', 'iR1 || {21, iL3, iL2}', 'iR2 || {21, iL3, iL2}', 'iS <= iL2', 'i_1_ + iS == iL2', 'i_1_ <= iL2', 'jDDR || {32}', 'jR1 || {32, jDDR}', 'jR2 || {32, jDDR}'] +# CHECK-NEXT: 100