diff --git a/src/xtc/backends/jir/JIRScheduler.py b/src/xtc/backends/jir/JIRScheduler.py index ac40e930..dde8312c 100644 --- a/src/xtc/backends/jir/JIRScheduler.py +++ b/src/xtc/backends/jir/JIRScheduler.py @@ -7,6 +7,7 @@ import xtc.itf as itf from xtc.itf.schd.scheduler import DEFAULT_ROOT +from xtc.schedules.loop_nest import LoopNest import xtc.backends.jir as backend __all__ = [ @@ -347,3 +348,32 @@ def distributed_buffer_at( def get_schedule_str(self) -> str: return str(JIRSchedule(scheduler=self)) + + @override + def get_loop_nest(self) -> LoopNest: + transformer = self._transformer + dims = list(transformer.dims.keys()) + + loop_nest = LoopNest(abstract_dims=dims) + root_node = loop_nest.build_root_node(self._backend.payload_name) + + # Build tiles mapping + for axis, axis_tiles in transformer.tiles.items(): + for tile_name, size in axis_tiles.items(): + root_node.tiles[axis][tile_name] = size + + # Build interchange + root_node.interchange = ( + list(transformer.order) if transformer.order else dims[:] + ) + + # Build vectorization list + root_node.vectorize = list(transformer.vectorized) + + # Build parallelization list + root_node.parallelize = list(transformer.parallelized) + + # Build unroll mapping + root_node.unroll = dict(transformer.unrolled) + + return loop_nest diff --git a/src/xtc/backends/mlir/MlirScheduler.py b/src/xtc/backends/mlir/MlirScheduler.py index 605477b3..cdf382b5 100644 --- a/src/xtc/backends/mlir/MlirScheduler.py +++ b/src/xtc/backends/mlir/MlirScheduler.py @@ -5,6 +5,7 @@ from typing_extensions import override from xtc.itf.schd.scheduler import DEFAULT_ROOT +from xtc.schedules.loop_nest import LoopNest, LoopNestNode, LoopInfo, SplitOrigin import xtc.itf as itf import xtc.backends.mlir as backend @@ -203,6 +204,55 @@ def distributed_buffer_at( axis, input_idx, memory_axes, root=root ) + @override + def get_loop_nest(self) -> LoopNest: + node_sched = self._current_scheduler + dims = node_sched.dims[:] + + loop_nest = LoopNest(abstract_dims=dims) + root_node = loop_nest.build_root_node(node_sched.node_name) + + # Assign splits to root_node first + for axis, axis_splits in node_sched.splits.items(): + root_node.splits[axis] = dict(axis_splits) + + # Build mapper to get splits_info + mapper = LoopInfo.build_from_node(root_node) + + def populate_node(node: LoopNestNode, perm: list[str]) -> None: + """Populate node with data for loops in its permutation.""" + node.interchange = list(perm) + perm_set = set(perm) + for axis, axis_tiles in node_sched.tiles.items(): + for tile_name, size in axis_tiles.items(): + if tile_name in perm_set: + if axis not in node.tiles: + node.tiles[axis] = {} + node.tiles[axis][tile_name] = size + node.vectorize = [v for v in node_sched.vectorization if v in perm_set] + node.parallelize = [p for p in node_sched.parallelization if p in perm_set] + node.unroll = { + k: v for k, v in node_sched.unrolling.items() if k in perm_set + } + + # Process each root in permutation + for root, perm in node_sched.permutation.items(): + if root in mapper.splits_info: + # This root is a split - create child node + axis, start, end = mapper.splits_info[root] + child = LoopNestNode( + root=root, + tiles={d: {} for d in dims}, + split_origin=SplitOrigin(axis=axis, start=start, end=end), + ) + populate_node(child, perm) + root_node.add_child(child) + else: + # This is the main root + populate_node(root_node, perm) + + return loop_nest + class MlirSchedule(itf.schd.Schedule): def __init__( diff --git a/src/xtc/backends/tvm/TVMScheduler.py b/src/xtc/backends/tvm/TVMScheduler.py index 8e51249b..59c58ff0 100644 --- a/src/xtc/backends/tvm/TVMScheduler.py +++ b/src/xtc/backends/tvm/TVMScheduler.py @@ -12,6 +12,7 @@ from xtc.utils.math import pow2divisor from xtc.itf.schd.scheduler import DEFAULT_ROOT +from xtc.schedules.loop_nest import LoopNest import xtc.backends.tvm as backend import xtc.itf as itf @@ -511,6 +512,39 @@ def _get_plain_schedule(self) -> TVMPlainSchedule: def __str__(self) -> str: return str(self._get_plain_schedule()) + @override + def get_loop_nest(self) -> LoopNest: + loop_nest = LoopNest(abstract_dims=self.dims[:]) + root_node = loop_nest.build_root_node(self._op.name or "op") + + # Build tiles mapping + for axis, axis_tiles in self.tiles.items(): + for tile_name, size in axis_tiles.items(): + if tile_name != axis: + root_node.tiles[axis][tile_name] = size + + # Build interchange + root_node.interchange = list(self.permutation) + + # Build vectorization list + root_node.vectorize = list(self.vectorization) + + # Build parallelization list + root_node.parallelize = list(self.parallelization) + + # Build unroll mapping + root_node.unroll = dict(self.unrolling) + + # Build buffer_at mapping + root_node.buffer_at = {axis: None for axis in self.write_caches} + + # Build pack_at mapping + root_node.pack_at = { + axis: (input_idx, None, pad) for axis, input_idx, pad in self.read_buffers + } + + return loop_nest + class TVMSchedule(itf.schd.Schedule): def __init__(self, scheduler: "TVMScheduler", schedule_impl: ScheduleImpl) -> None: diff --git a/src/xtc/cli/mlir_loop.py b/src/xtc/cli/mlir_loop.py index 09f9f8fb..2bd93991 100644 --- a/src/xtc/cli/mlir_loop.py +++ b/src/xtc/cli/mlir_loop.py @@ -136,7 +136,7 @@ def build_node_scheduler( descript_scheduler( scheduler=scheduler, node_name=node_name, - abstract_axis=scheduler.backend.dims, + abstract_dims=scheduler.backend.dims, spec=normal_schedule, ) op.attributes.pop("loop.schedule", None) diff --git a/src/xtc/itf/schd/scheduler.py b/src/xtc/itf/schd/scheduler.py index 3ea1febb..f2abc465 100644 --- a/src/xtc/itf/schd/scheduler.py +++ b/src/xtc/itf/schd/scheduler.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from .schedule import Schedule import xtc.itf +from xtc.schedules.loop_nest import LoopNest DEFAULT_ROOT = "." ROOT_SEP = "/" @@ -291,3 +292,19 @@ def distributed_buffer_at( root: the parent split (or the operator's absolute root) """ ... + + @abstractmethod + def get_loop_nest(self) -> LoopNest: + """Return a LoopNest representation of the current schedule. + + This method constructs a LoopNest object that describes the loop + structure resulting from the scheduling transformations applied + so far. The LoopNest can be used for visualization (via pretty_print) + or further analysis. + + Returns: + LoopNest: A tree structure representing the scheduled loop nest, + including tiles, splits, interchange order, and annotations + (vectorization, parallelization, unrolling). + """ + ... diff --git a/src/xtc/schedules/descript.py b/src/xtc/schedules/descript.py index 63ee7052..50e4de5b 100644 --- a/src/xtc/schedules/descript.py +++ b/src/xtc/schedules/descript.py @@ -2,256 +2,78 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024-2026 The XTC Project Authors # -from __future__ import annotations - from typing import Any -from dataclasses import dataclass, field -import re -from typing_extensions import override -from xtc.itf.schd.scheduler import Scheduler, ROOT_SEP, SPLIT_LEFT_SEP, SPLIT_RIGHT_SEP - - -class ScheduleParseError(RuntimeError): - """Raised when schedule parsing fails.""" - - pass - - -class ScheduleInterpretError(RuntimeError): - """Raised when schedule interpretation fails.""" - - pass +from dataclasses import dataclass +from xtc.itf.schd.scheduler import Scheduler, ROOT_SEP, SPLIT_LEFT_SEP, SPLIT_RIGHT_SEP +from .exceptions import ScheduleInterpretError +from .parsing import ( + ScheduleParser, + ScheduleSpec, + SplitDecl, + TileDecl, + AxisDecl, + Annotations, +) +from .loop_nest import LoopNestNode, LoopNest, SplitOrigin -class ScheduleValidationError(RuntimeError): - """Raised when schedule validation fails.""" - pass +def descript_scheduler( + scheduler: Scheduler, + node_name: str, + abstract_dims: list[str], + spec: dict[str, dict[str, Any]], +) -> None: + """Apply a schedule specification to a scheduler. + This is the main entry point for using the descript scheduling DSL. -@dataclass(frozen=True) -class Annotations: - """AST Type : annotations that can be applied to a loop. - - Attributes: - unroll_factor: The unroll factor. None means "unroll fully" (use loop size). - Only meaningful when unroll_specified is True. - unroll_specified: True if unroll was explicitly requested. - vectorize: True if vectorization was requested. - parallelize: True if parallelization was requested. + Args: + scheduler: The scheduler to apply the schedule to. + node_name: The name of the root node to schedule. + abstract_dims: The list of abstract axis names (e.g., ["m", "n", "k"]). + spec: The schedule specification as a nested dict. """ - - unroll_factor: int | None = None - unroll_specified: bool = False - vectorize: bool = False - parallelize: bool = False - - -@dataclass(frozen=True) -class SplitDecl: - """AST Type: a split declaration like 'axis[start:end]'.""" - - axis: str - start: int | None - end: int | None - body: ScheduleSpec - - @override - def __str__(self) -> str: - start_str = "" if self.start is None else str(self.start) - end_str = "" if self.end is None else str(self.end) - decl = f"{self.axis}{SPLIT_LEFT_SEP}{start_str}:{end_str}{SPLIT_RIGHT_SEP}" - return decl - - -@dataclass(frozen=True) -class TileDecl: - """AST Type: a tile declaration like 'axis#size'.""" - - axis: str - size: int - annotations: Annotations - - @override - def __str__(self) -> str: - return f"{self.axis}#{self.size}" - - -@dataclass(frozen=True) -class AxisDecl: - """AST Type: a direct axis reference.""" - - axis: str - annotations: Annotations - - -ScheduleItem = SplitDecl | TileDecl | AxisDecl - - -@dataclass(frozen=True) -class ScheduleSpec: - """AST Type: the complete parsed schedule specification.""" - - items: tuple[ScheduleItem, ...] - - -class ScheduleParser: - """Parses a dict-based schedule specification into an AST.""" - - _SPLIT_PATTERN = re.compile(r"^(.*)\[(-\d+|\d*)?:(-\d+|\d*)?\]$") - - def __init__(self, abstract_axis: list[str]): - self.abstract_axis = abstract_axis - - def parse(self, spec: dict[str, Any]) -> ScheduleSpec: - """Parse a schedule specification dict into an AST.""" - items: list[ScheduleItem] = [] - - for declaration, value in spec.items(): - item = self._parse_declaration(declaration, value) - items.append(item) - - return ScheduleSpec(items=tuple(items)) - - def _parse_declaration(self, declaration: str, value: Any) -> ScheduleItem: - """Parse a single declaration into a ScheduleItem.""" - assert isinstance(value, dict) - # Try split declaration first (e.g., "axis[0:10]") - if ":" in declaration: - return self._parse_split(declaration, value) - - # Try tile declaration (e.g., "axis#32") - if "#" in declaration: - return self._parse_tile(declaration, value) - - # Must be a direct axis reference - return self._parse_axis_ref(declaration, value) - - def _parse_split(self, declaration: str, value: dict) -> SplitDecl: - """Parse a split declaration like 'axis[start:end]'.""" - axis_name, start, end = self._parse_split_syntax(declaration) - - body = self.parse(value) - return SplitDecl(axis=axis_name, start=start, end=end, body=body) - - def _parse_tile(self, declaration: str, value: dict) -> TileDecl: - """Parse a tile declaration like 'axis#size'.""" - parts = declaration.split("#") - if len(parts) != 2: - raise ScheduleParseError( - f"`{declaration}`: invalid tile syntax, expected 'axis#size'" - ) - - axis_name, size_str = parts - - try: - size = int(size_str) - except ValueError: - raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.") - - annotations = self._parse_annotations(value, declaration) - return TileDecl(axis=axis_name, size=size, annotations=annotations) - - def _parse_axis_ref(self, declaration: str, value: dict) -> AxisDecl: - """Parse a direct axis reference.""" - - annotations = self._parse_annotations(value, declaration) - return AxisDecl(axis=declaration, annotations=annotations) - - def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations: - """Parse annotation dict into Annotations object.""" - - unroll_factor: int | None = None - unroll_specified = False - vectorize = False - parallelize = False - - for key, param in value.items(): - if key == "unroll": - if param is True: - unroll_factor = None - unroll_specified = True - elif param is False: - pass - elif isinstance(param, int): - unroll_factor = param - unroll_specified = True - else: - raise ScheduleParseError( - f'`{{"unroll" = {param}}}`: unroll parameter should be True, False, or an integer.' - ) - elif key == "vectorize": - if not isinstance(param, bool): - raise ScheduleParseError( - f'`{{"vectorize" = {param}}}`: parameterized vectorization not implemented.' - ) - vectorize = param - elif key == "parallelize": - if not isinstance(param, bool): - raise ScheduleParseError( - f'`{{"parallelize" = {param}}}`: parameterized parallelization not implemented.' - ) - parallelize = param - else: - raise ScheduleParseError(f"Unknown annotation on {context}: {key}") - - return Annotations( - unroll_factor=unroll_factor, - unroll_specified=unroll_specified, - vectorize=vectorize, - parallelize=parallelize, - ) - - def _parse_split_syntax( - self, declaration: str - ) -> tuple[str, int | None, int | None]: - """Parse the syntax of a split declaration.""" - match = self._SPLIT_PATTERN.match(declaration) - if not match: - raise ScheduleParseError(f"Wrong format {declaration}") - - prefix, x_str, y_str = match.groups() - x = int(x_str) if x_str else None - y = int(y_str) if y_str else None - return prefix, x, y + descript = Descript(scheduler=scheduler, abstract_dims=abstract_dims) + descript.apply(node_name=node_name, spec=spec) class ScheduleInterpreter: """Interprets a parsed ScheduleSpec AST into a LoopNest.""" - def __init__(self, abstract_axis: list[str]): - self.abstract_axis = abstract_axis - self.root_to_dim: dict[str, str] = {} - self.dim_to_axis: dict[str, str] = {} + def __init__(self, abstract_dims: list[str]): + self.abstract_dims = abstract_dims def interpret(self, spec: ScheduleSpec, root: str) -> LoopNest: """Interpret a schedule specification into a LoopNest.""" - return self._interpret_spec(spec, root) - - def _interpret_spec(self, spec: ScheduleSpec, root: str) -> LoopNest: - """Interpret a schedule spec recursively.""" - loop_nest = LoopNest(abstract_dims=self.abstract_axis) - slice = loop_nest.build_slice(root) + loop_nest = LoopNest(abstract_dims=self.abstract_dims) + root_node = loop_nest.build_root_node(root) + self._interpret_spec_into_node(spec, root_node, root, head=[]) + return loop_nest + def _interpret_spec_into_node( + self, + spec: ScheduleSpec, + node: LoopNestNode, + root: str, + head: list[str], + ) -> None: + """Interpret a schedule spec into an existing node (mutates node).""" # Track state during interpretation sizes: dict[str, int] = {} - previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis} - interchange: list[str] = [] - # Only the first root is not in root_to_dim - if root in self.root_to_dim: - interchange.append(self.root_to_dim[root]) + previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_dims} + interchange: list[str] = list(head) for item in spec.items: if isinstance(item, SplitDecl): - self._interpret_split( - item, slice, loop_nest, root, interchange, previous_cut - ) + self._interpret_split(item, node, root, interchange, previous_cut) elif isinstance(item, TileDecl): - loop_name = self._interpret_tile(item, slice, interchange, sizes) - self._apply_annotations(item.annotations, loop_name, sizes, slice) + loop_name = self._interpret_tile(item, node, interchange, sizes) + self._apply_annotations(item.annotations, loop_name, sizes, node) elif isinstance(item, AxisDecl): loop_name = self._interpret_axis(item, interchange) - self._apply_annotations(item.annotations, loop_name, sizes, slice) + self._apply_annotations(item.annotations, loop_name, sizes, node) + # Check that all splits are complete for axis, cut in previous_cut.items(): if cut is not None and cut != 0: @@ -259,14 +81,12 @@ def _interpret_spec(self, spec: ScheduleSpec, root: str) -> LoopNest: f"Splitting of {axis} unachieved (stops at {cut})." ) - slice.interchange = interchange - return loop_nest + node.interchange = interchange def _interpret_split( self, item: SplitDecl, - slice: LoopNestSlice, - loop_nest: LoopNest, + node: LoopNestNode, root: str, interchange: list[str], previous_cut: dict[str, int | None], @@ -281,7 +101,7 @@ def _interpret_split( # last one, so it cannot be the previous one. cut = previous_cut[axis_name] - # When x (the starting point of the slice) is not specified, + # When x (the starting point of the split) is not specified, # it is the previous cut if x is None: x = cut @@ -293,37 +113,46 @@ def _interpret_split( previous_cut[axis_name] = y # Save the cutting points of the new dimensions - if axis_name not in slice.splits: - slice.splits[axis_name] = {} - new_dim_index = len(slice.splits[axis_name]) + if axis_name not in node.splits: + node.splits[axis_name] = {} + new_dim_index = len(node.splits[axis_name]) new_dim_name = f"{axis_name}{SPLIT_LEFT_SEP}{new_dim_index}{SPLIT_RIGHT_SEP}" new_root_name = f"{root}{ROOT_SEP}{new_dim_name}" - slice.splits[axis_name][new_dim_name] = x + node.splits[axis_name][new_dim_name] = x interchange.append(new_dim_name) - self.dim_to_axis[new_dim_name] = axis_name - self.root_to_dim[new_root_name] = new_dim_name - # Recursively interpret the nested schedule - inner_nest = self._interpret_spec(item.body, new_root_name) - loop_nest.slices += inner_nest.slices + + # Create a child node for the nested schedule + child_node = LoopNestNode( + root=new_root_name, + tiles={a: {} for a in self.abstract_dims}, + split_origin=SplitOrigin(axis=axis_name, start=x, end=y), + ) + node.add_child(child_node) + + # Recursively interpret the nested schedule into the child node + self._interpret_spec_into_node( + item.body, child_node, new_root_name, head=[axis_name] + ) def _interpret_tile( self, item: TileDecl, - slice: LoopNestSlice, + node: LoopNestNode, interchange: list[str], sizes: dict[str, int], ) -> str: """Interpret a tile declaration. Returns the loop name.""" self._check_axis_existence(item.axis) - tile_num = len(slice.tiles[item.axis]) + tile_num = len(node.tiles[item.axis]) loop_name = f"{item.axis}{tile_num}" if item.size <= 0: raise ScheduleInterpretError( f"`{item}`: tile sizes should be strictly positive." ) - slice.tiles[item.axis][loop_name] = item.size + node.tiles[item.axis][loop_name] = item.size sizes[loop_name] = item.size interchange.append(loop_name) + return loop_name def _interpret_axis( @@ -334,22 +163,22 @@ def _interpret_axis( """Interpret a direct axis reference. Returns the loop name.""" axis_name = item.axis self._check_axis_existence(axis_name) + # Unreachable when built from a Python dict (because keys # can't be duplicated). - for loop_name in interchange: - if self.dim_to_axis.get(loop_name, loop_name) == axis_name: - raise ScheduleInterpretError( - f"Axis {axis_name} is scheduled twice (or more)." - ) + if axis_name in interchange: + raise ScheduleInterpretError( + f"Axis {axis_name} is scheduled twice (or more)." + ) interchange.append(axis_name) return axis_name def _check_axis_existence(self, axis: str) -> None: """Check that an axis is defined.""" - if axis not in self.abstract_axis: + if axis not in self.abstract_dims: raise ScheduleInterpretError( - f"Axis {axis} is not a defined axis (defined axis: {self.abstract_axis})." + f"Axis {axis} is not a defined axis (defined axis: {self.abstract_dims})." ) def _apply_annotations( @@ -357,9 +186,9 @@ def _apply_annotations( annotations: Annotations, loop_name: str, sizes: dict[str, int], - slice: LoopNestSlice, + node: LoopNestNode, ) -> None: - """Apply annotations to a loop in the slice.""" + """Apply annotations to a loop in the node.""" if annotations.unroll_specified: unroll_factor = annotations.unroll_factor if unroll_factor is None: @@ -373,13 +202,19 @@ def _apply_annotations( raise ScheduleInterpretError( f'`{{"unroll" = {unroll_factor}}}`: unroll parameter should be strictly positive.' ) - slice.unroll[loop_name] = unroll_factor + node.unroll[loop_name] = unroll_factor if annotations.vectorize: - slice.vectorize.append(loop_name) + node.vectorize.append(loop_name) if annotations.parallelize: - slice.parallelize.append(loop_name) + node.parallelize.append(loop_name) + + if annotations.buffer_specified: + node.buffer_at[loop_name] = annotations.buffer + + if annotations.pack_specified and annotations.pack is not None: + node.pack_at[loop_name] = annotations.pack def _check_splitting_intervals( self, @@ -407,258 +242,6 @@ def _check_splitting_intervals( ) -@dataclass -class LoopsDimsMapper: - """Maps loop names to their corresponding axis names. - - This class tracks the relationship between loop identifiers (from tiling - and splitting transformations) and the original dimension axes they - derive from. - - Attributes: - tiles_to_axis: Maps tile loop names to their parent axis. - splits_to_axis: Maps split loop names to their parent axis. - dims: List of original dimension names. - """ - - tiles_to_axis: dict[str, str] - splits_to_axis: dict[str, str] - dims: list[str] - - @property - def loops_to_axis(self) -> dict[str, str]: - loops_to_axis = ( - self.tiles_to_axis | self.splits_to_axis | dict(zip(self.dims, self.dims)) - ) - return loops_to_axis - - @staticmethod - def build_from_slices(slices: list["LoopNestSlice"]) -> "LoopsDimsMapper": - tiles_to_axis = {} - splits_to_axis = {} - dims = set() - for slice in slices: - tiles_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(slice.tiles)) - splits_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(slice.splits)) - refined_loops = list(tiles_to_axis) + list(splits_to_axis) - for slice in slices: - dims.update( - [loop for loop in slice.interchange if loop not in refined_loops] - ) - dims.update(tiles_to_axis.values()) - dims.update(splits_to_axis.values()) - return LoopsDimsMapper(tiles_to_axis, splits_to_axis, list(dims)) - - @staticmethod - def _get_subloops_to_axis(subloops: dict[str, dict[str, Any]]) -> dict[str, str]: - loop_to_axis: dict[str, str] = {} - for axis_name, subloops in subloops.items(): - for loop_name in subloops: - loop_to_axis[loop_name] = axis_name - return loop_to_axis - - -@dataclass -class LoopNestSlice: - """Represents a single slice of a loop nest with its transformations. - - A slice describes the loops attached to a single root and - contains all the scheduling transformations applied to these loops. - - Attributes: - root: Identifier of the the slice (either the base operation or - the content of a split). - tiles: Tiling configuration per axis. Maps axis names to dicts of - tile loop names and their sizes. - splits: Split configuration per axis. Maps axis names to dicts of - split loop names and their starting positions. - interchange: Ordered list of loop names defining the loop order. - vectorize: List of loops to vectorize. - parallelize: List of loops to parallelize. - unroll: Maps loop names to their unroll factors. - """ - - root: str - tiles: dict[str, dict[str, int]] - splits: dict[str, dict[str, int | None]] = field(default_factory=dict) - interchange: list[str] = field(default_factory=list) - vectorize: list[str] = field(default_factory=list) - parallelize: list[str] = field(default_factory=list) - unroll: dict[str, int] = field(default_factory=dict) - - @property - def splits_to_sizes(self) -> dict[str, int | None]: - splits_to_sizes: dict[str, int | None] = {} - for axis in self.splits: - last_start = None - for loop_name, start in reversed(self.splits[axis].items()): - if last_start is not None and start is not None: - size_of_split = last_start - start - splits_to_sizes[loop_name] = size_of_split - else: - splits_to_sizes[loop_name] = None - last_start = start - return splits_to_sizes - - @property - def tiles_to_sizes(self) -> dict[str, int]: - tiles_to_sizes: dict[str, int] = {} - for tiles in self.tiles.values(): - for loop, size in tiles.items(): - tiles_to_sizes[loop] = size - return tiles_to_sizes - - -@dataclass -class LoopNest: - """Represents a complete loop nest structure for scheduling. - - A loop nest contains abstract dimensions and a collection of slices/ - It provides validation to ensure consistency across all slices. - - Attributes: - abstract_dims: List of abstract dimension names for the loop nest. - slices: List of LoopNestSlice objects, one per scheduled operation. - """ - - abstract_dims: list[str] - slices: list[LoopNestSlice] = field(default_factory=list) - - @property - def empty(self): - return not self.slices - - def build_slice(self, root: str) -> LoopNestSlice: - slice = LoopNestSlice(root=root, tiles={a: {} for a in self.abstract_dims}) - self.slices = [slice] + self.slices - return slice - - def check(self): - self._check_use_defined_dims() - self._check_vectorization_consistency() - self._check_tiling_consistency() - self._check_sizes() - - def _check_use_defined_dims(self): - mapper = LoopsDimsMapper.build_from_slices(self.slices) - for dim in self.abstract_dims: - if dim not in mapper.dims: - raise ScheduleValidationError(f"{dim} defined but never used") - - def _check_vectorization_consistency(self): - for sched in self.slices: - vect_above = False - for loop_name in sched.interchange: - if loop_name in sched.vectorize: - vect_above = True - elif vect_above: - raise ScheduleValidationError( - f"Inner loop {loop_name} isn't vectorized but an outer one is." - ) - - def _check_tiling_consistency(self) -> None: - mapper = LoopsDimsMapper.build_from_slices(self.slices) - seen_axes: dict[str, int | None] = {} - for sched in self.slices: - for loop_name in sched.interchange: - loop_name = mapper.splits_to_axis.get(loop_name, loop_name) - - if loop_name in mapper.dims: - seen_axes[loop_name] = None - elif loop_name in mapper.tiles_to_axis: - axis = mapper.tiles_to_axis[loop_name] - size = sched.tiles_to_sizes[loop_name] - if axis not in seen_axes: - raise ScheduleValidationError( - f""" - `{axis}#{size}`: {axis} has not been materialized yet. - """ - ) - seen_axes[axis] = sched.tiles[axis][loop_name] - - def _check_sizes(self): - mapper = LoopsDimsMapper.build_from_slices(self.slices) - current_size_of_split: dict[str, int | None] = {} - for sched in self.slices: - current_size_of_tile: dict[str, int] = {} - for loop_name in sched.interchange: - axis = mapper.loops_to_axis[loop_name] - current_sizes = ( - {d: None for d in mapper.dims} - | current_size_of_split - | current_size_of_tile - ) - loop_size = None - if loop_name in mapper.dims: - if loop_name not in current_size_of_split: - current_size_of_split[loop_name] = None - elif loop_name in mapper.tiles_to_axis: - loop_size = sched.tiles[axis][loop_name] - LoopNest._must_be_smaller_routine( - new_size=loop_size, - current_sizes=current_sizes, - loop_name=loop_name, - axis=axis, - ) - current_size_of_tile[axis] = loop_size - elif ( - loop_name in mapper.splits_to_axis - and loop_name in sched.splits_to_sizes - ): - loop_size = sched.splits_to_sizes[loop_name] - LoopNest._must_be_smaller_routine( - new_size=loop_size, - current_sizes=current_sizes, - loop_name=loop_name, - axis=axis, - ) - current_size_of_split[loop_name] = loop_size - elif loop_name in current_size_of_split: - current_size_of_split[axis] = current_size_of_split[loop_name] - - if loop_name in sched.unroll: - unroll_factor = sched.unroll[loop_name] - if loop_size and loop_size < unroll_factor: - raise ScheduleValidationError( - f'`{{"unroll" = {unroll_factor}}}`: unroll factor should be smaller than {loop_size}.' - ) - - @staticmethod - def _must_be_smaller_routine( - new_size: int | None, - current_sizes: dict[str, int | None], - loop_name: str, - axis: str, - ): - old_size = current_sizes[axis] - if old_size is not None and new_size is not None and new_size > old_size: - raise ScheduleValidationError( - f""" - Inner loop {loop_name} on axis {axis} must be smaller than outer loop. - """ - ) - - -def descript_scheduler( - scheduler: Scheduler, - node_name: str, - abstract_axis: list[str], - spec: dict[str, dict[str, Any]], -) -> None: - """Apply a schedule specification to a scheduler. - - This is the main entry point for using the descript scheduling DSL. - - Args: - scheduler: The scheduler to apply the schedule to. - node_name: The name of the root node to schedule. - abstract_axis: The list of abstract axis names (e.g., ["m", "n", "k"]). - spec: The schedule specification as a nested dict. - """ - descript = Descript(scheduler=scheduler, abstract_axis=abstract_axis) - descript.apply(node_name=node_name, spec=spec) - - @dataclass(frozen=True) class Descript: """Applies a parsed and interpreted schedule to a Scheduler. @@ -672,7 +255,7 @@ class Descript: """ scheduler: Scheduler - abstract_axis: list[str] + abstract_dims: list[str] def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: """Parse, interpret, validate, and apply a schedule specification. @@ -687,11 +270,11 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: ScheduleValidationError: If the resulting schedule is invalid. """ # Parse the specification into an AST - parser = ScheduleParser(self.abstract_axis) + parser = ScheduleParser() ast = parser.parse(spec) # Interpret the AST into a LoopNest - interpreter = ScheduleInterpreter(self.abstract_axis) + interpreter = ScheduleInterpreter(self.abstract_dims) loop_nest = interpreter.interpret(ast, root=node_name) # Validate the loop nest loop_nest.check() @@ -700,18 +283,32 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: def _apply_loop_nest(self, loop_nest: LoopNest) -> None: """Apply a LoopNest to the scheduler.""" - self.scheduler.set_dims(self.abstract_axis) + self.scheduler.set_dims(self.abstract_dims) + + if loop_nest.root_node is not None: + self._apply_node(loop_nest.root_node) + + def _apply_node(self, node: LoopNestNode) -> None: + """Recursively apply a LoopNestNode and its children to the scheduler.""" + root = node.root + + for d, s in node.splits.items(): + self.scheduler.split(d, s, root=root) + + for d, s in node.tiles.items(): + self.scheduler.tile(d, s, root=root) - for slice in loop_nest.slices: - root = slice.root + self.scheduler.interchange(node.interchange, root=root) + self.scheduler.vectorize(node.vectorize, root=root) + self.scheduler.parallelize(node.parallelize, root=root) + self.scheduler.unroll(node.unroll, root=root) - for d, s in slice.splits.items(): - self.scheduler.split(d, s, root=root) + for axis, mtype in node.buffer_at.items(): + self.scheduler.buffer_at(axis, mtype=mtype, root=root) - for d, s in slice.tiles.items(): - self.scheduler.tile(d, s, root=root) + for axis, (input_idx, mtype, pad) in node.pack_at.items(): + self.scheduler.pack_at(axis, input_idx, mtype=mtype, pad=pad, root=root) - self.scheduler.interchange(slice.interchange, root=root) - self.scheduler.vectorize(slice.vectorize, root=root) - self.scheduler.parallelize(slice.parallelize, root=root) - self.scheduler.unroll(slice.unroll, root=root) + # Recursively apply children + for child in node.children: + self._apply_node(child) diff --git a/src/xtc/schedules/exceptions.py b/src/xtc/schedules/exceptions.py new file mode 100644 index 00000000..84d51912 --- /dev/null +++ b/src/xtc/schedules/exceptions.py @@ -0,0 +1,23 @@ +# +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024-2026 The XTC Project Authors +# +"""Schedule-related exceptions.""" + + +class ScheduleParseError(RuntimeError): + """Raised when schedule parsing fails.""" + + pass + + +class ScheduleInterpretError(RuntimeError): + """Raised when schedule interpretation fails.""" + + pass + + +class ScheduleValidationError(RuntimeError): + """Raised when schedule validation fails.""" + + pass diff --git a/src/xtc/schedules/loop_nest.py b/src/xtc/schedules/loop_nest.py new file mode 100644 index 00000000..306c0c91 --- /dev/null +++ b/src/xtc/schedules/loop_nest.py @@ -0,0 +1,459 @@ +# +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024-2026 The XTC Project Authors +# +from __future__ import annotations + +from typing import Generic, TypeVar +from dataclasses import dataclass, field + +from .exceptions import ScheduleValidationError + + +@dataclass +class SplitOrigin: + """Describes how a node was created via a split from its parent. + + Attributes: + axis: The axis that was split to create this node. + start: The starting position of the split (inclusive), or None if unbounded. + end: The ending position of the split (exclusive), or None if unbounded. + """ + + axis: str + start: int | None + end: int | None + + +NodeT = TypeVar("NodeT", bound="Node") + + +@dataclass(kw_only=True) +class Node(Generic[NodeT]): + """Base class for tree nodes with parent/child relationships. + + Provides tree structure and traversal operations. Subclasses add + domain-specific data. + + Attributes: + parent: Reference to the parent node, or None for the root. + split_origin: Metadata describing how this node was created from + its parent via a split. None for the root node. + children: List of child nodes. + """ + + parent: NodeT | None = None + split_origin: SplitOrigin | None = None + children: list[NodeT] = field(default_factory=list) + + @property + def is_root(self) -> bool: + """Returns True if this node is the root (has no parent).""" + return self.parent is None + + def add_child(self, child: NodeT) -> None: + """Add a child node and set its parent to this node.""" + child.parent = self # type: ignore[assignment] + self.children.append(child) + + def ancestors(self) -> list[NodeT]: + """Return list of ancestors from parent to root.""" + result: list[NodeT] = [] + current = self.parent + while current is not None: + result.append(current) + current = current.parent + return result + + def descendants_dfs(self) -> list[NodeT]: + """Return all descendants in depth-first order.""" + result: list[NodeT] = [] + for child in self.children: + result.append(child) + result.extend(child.descendants_dfs()) + return result + + +@dataclass +class LoopNestNode(Node["LoopNestNode"]): + """Represents a node in the loop nest tree with its transformations. + + Describes the loops attached to a single root and + contains all the scheduling transformations applied to these loops. + + Attributes: + root: Identifier of the node (either the base operation or + the content of a split). + tiles: Tiling configuration per axis. Maps axis names to dicts of + tile loop names and their sizes. + splits: Split configuration per axis. Maps axis names to dicts of + split loop names and their starting positions. + interchange: Ordered list of loop names defining the loop order. + vectorize: List of loops to vectorize. + parallelize: List of loops to parallelize. + unroll: Maps loop names to their unroll factors. + buffer_at: Buffer configuration per axis. Maps axis names to optional + memory types (mtype). None means default memory type. + pack_at: Pack configuration per axis. Maps axis names to tuples of + (input_idx, mtype, pad). input_idx is the input buffer index, + mtype is the memory type (None for default), pad enables padding. + """ + + root: str + tiles: dict[str, dict[str, int]] + splits: dict[str, dict[str, int]] = field(default_factory=dict) + interchange: list[str] = field(default_factory=list) + vectorize: list[str] = field(default_factory=list) + parallelize: list[str] = field(default_factory=list) + unroll: dict[str, int] = field(default_factory=dict) + buffer_at: dict[str, str | None] = field(default_factory=dict) + pack_at: dict[str, tuple[int, str | None, bool]] = field(default_factory=dict) + + def pretty_print(self, indent: int = 0) -> str: + """Return a human-readable representation of the loop nest. + + The output format uses a compact notation: + - `loop X` for a regular loop over dimension X + - `tile(X, N)` for a tile of size N on dimension X + - `split(X, start, end)` for a split segment on dimension X + - `// annotation` for vectorized, parallelized, unroll(N) + - `...` for the innermost body + + Example output: + loop i // parallelized + loop k + loop j + tile(j, 16) // vectorized + ... + + Args: + indent: The initial indentation level (number of spaces). + + Returns: + A multi-line string representing the loop nest structure. + """ + lines: list[str] = [] + + mapper = LoopInfo.build_from_node(self) + tiles_info = mapper.tiles_info + splits_info = mapper.splits_info + + # Map split loop names to their child nodes + split_to_child: dict[str, LoopNestNode] = {} + for child in self.children: + if child.split_origin is not None: + axis = child.split_origin.axis + if axis in self.splits: + for loop_name, start in self.splits[axis].items(): + if start == child.split_origin.start: + split_to_child[loop_name] = child + break + + # Group splits by axis for same-level printing + axis_to_splits: dict[str, list[str]] = {} + for loop_name, (axis, _, _) in splits_info.items(): + if loop_name in self.interchange: + if axis not in axis_to_splits: + axis_to_splits[axis] = [] + axis_to_splits[axis].append(loop_name) + + processed_splits: set[str] = set() + current_indent = indent + + for loop_name in self.interchange: + # Skip already processed splits + if loop_name in processed_splits: + continue + + # Check if this is a split + if loop_name in splits_info: + axis, _, _ = splits_info[loop_name] + axis_split_names = axis_to_splits.get(axis, [loop_name]) + processed_splits.update(axis_split_names) + + # Print all splits of this axis at the same level + for split_name in axis_split_names: + split_axis, start, end = splits_info[split_name] + end_str = str(end) if end is not None else "..." + line = f"split({split_axis}, {start}, {end_str})" + line = self._add_annotations(line, split_name) + lines.append(" " * current_indent + line) + + # Use child's pretty_print if available + if split_name in split_to_child: + child_output = split_to_child[split_name].pretty_print( + current_indent + 2 + ) + lines.append(child_output) + else: + lines.append(" " * (current_indent + 2) + "...") + else: + # Regular loop (tile or base dimension) + if loop_name in tiles_info: + axis, size = tiles_info[loop_name] + line = f"tile({axis}, {size})" + else: + # Extract basename (last part after /) + basename = loop_name.split("/")[-1] + line = f"loop {basename}" + + line = self._add_annotations(line, loop_name) + lines.append(" " * current_indent + line) + current_indent += 2 + + # Add body if no splits were encountered + if not processed_splits: + lines.append(" " * current_indent + "...") + + return "\n".join(lines) + + def _add_annotations(self, line: str, loop_name: str) -> str: + """Add annotations (parallelized, vectorized, unroll, buffer, pack) to a loop line.""" + annotations: list[str] = [] + if loop_name in self.parallelize: + annotations.append("parallelized") + if loop_name in self.vectorize: + annotations.append("vectorized") + if loop_name in self.unroll: + annotations.append(f"unroll({self.unroll[loop_name]})") + if loop_name in self.buffer_at: + mtype = self.buffer_at[loop_name] + if mtype is not None: + annotations.append(f"buffer({mtype})") + else: + annotations.append("buffer") + if loop_name in self.pack_at: + input_idx, mtype, pad = self.pack_at[loop_name] + parts = [str(input_idx)] + if mtype is not None: + parts.append(mtype) + if pad: + parts.append("pad") + annotations.append(f"pack({', '.join(parts)})") + if annotations: + line += " // " + ", ".join(annotations) + return line + + +@dataclass +class LoopInfo: + """Maps loop names to their corresponding axis names and metadata. + + This class tracks the relationship between loop identifiers (from tiling + and splitting transformations) and the original dimension axes they + derive from, along with their sizes and positions. + + Attributes: + dims: List of original dimension names. + tiles_info: Maps tile loop names to (axis, size) tuples. + splits_info: Maps split loop names to (axis, start, end) tuples. + """ + + dims: list[str] + tiles_info: dict[str, tuple[str, int]] = field(default_factory=dict) + splits_info: dict[str, tuple[str, int, int | None]] = field(default_factory=dict) + + @property + def tiles_to_axis(self) -> dict[str, str]: + return {name: axis for name, (axis, _) in self.tiles_info.items()} + + @property + def splits_to_axis(self) -> dict[str, str]: + return {name: axis for name, (axis, _, _) in self.splits_info.items()} + + @property + def loops_to_axis(self) -> dict[str, str]: + return ( + self.tiles_to_axis | self.splits_to_axis | dict(zip(self.dims, self.dims)) + ) + + @property + def splits_to_sizes(self) -> dict[str, int]: + return { + name: end - start + for name, (_, start, end) in self.splits_info.items() + if end is not None + } + + @staticmethod + def build_from_node(node: LoopNestNode) -> LoopInfo: + tiles_info: dict[str, tuple[str, int]] = {} + splits_info: dict[str, tuple[str, int, int | None]] = {} + dims: dict[ + str, None + ] = {} # ordered set: insertion order preserved, no duplicates + + def collect(n: LoopNestNode) -> None: + # Build tiles_info: tile_name -> (axis, size) + for axis, tile_loops in n.tiles.items(): + for loop_name, size in tile_loops.items(): + tiles_info[loop_name] = (axis, size) + + # Build splits_info: split_name -> (axis, start, end) + for axis, axis_splits in n.splits.items(): + sorted_splits = sorted(axis_splits.items(), key=lambda kv: kv[1]) + for i, (loop_name, start) in enumerate(sorted_splits): + end = ( + sorted_splits[i + 1][1] if i + 1 < len(sorted_splits) else None + ) + splits_info[loop_name] = (axis, start, end) + + # Collect dims in stable order + refined_loops = set(tiles_info) | set(splits_info) + for loop in n.interchange: + if loop not in refined_loops: + dims[loop] = None + for axis, _ in tiles_info.values(): + dims[axis] = None + for axis, _, _ in splits_info.values(): + dims[axis] = None + + # Recurse on children + for child in n.children: + collect(child) + + collect(node) + + return LoopInfo(list(dims), tiles_info, splits_info) + + +@dataclass +class LoopNest: + """Represents a complete loop nest structure for scheduling. + + A loop nest contains abstract dimensions and a tree of nodes representing + the schedule. Splits create child nodes, forming an explicit tree structure. + + Attributes: + abstract_dims: List of abstract dimension names for the loop nest. + root_node: The root node of the loop nest tree, or None if empty. + """ + + abstract_dims: list[str] + root_node: LoopNestNode | None = None + + @property + def nodes(self) -> list[LoopNestNode]: + """Flatten the tree into a list of nodes. + + Returns nodes in depth-first order, with the root node first, + followed by children in the order they were created. + """ + if self.root_node is None: + return [] + return [self.root_node] + self.root_node.descendants_dfs() + + def build_root_node(self, root: str) -> LoopNestNode: + """Build and set the root node of the loop nest tree.""" + node = LoopNestNode(root=root, tiles={a: {} for a in self.abstract_dims}) + self.root_node = node + return node + + def check(self): + assert self.root_node is not None + info = LoopInfo.build_from_node(self.root_node) + self._check_use_defined_dims(info) + self._check_vectorization_consistency() + self._check_tiling_consistency(info) + self._check_sizes(info) + + def _check_use_defined_dims(self, info: LoopInfo): + for dim in self.abstract_dims: + if dim not in info.dims: + raise ScheduleValidationError(f"{dim} defined but never used") + + def _check_vectorization_consistency(self): + for sched in self.nodes: + vect_above = False + for loop_name in sched.interchange: + if loop_name in sched.vectorize: + vect_above = True + elif vect_above: + raise ScheduleValidationError( + f"Inner loop {loop_name} isn't vectorized but an outer one is." + ) + + def _check_tiling_consistency(self, info: LoopInfo) -> None: + seen_axes: dict[str, int | None] = {} + for sched in self.nodes: + for loop_name in sched.interchange: + if loop_name in info.dims: + seen_axes[loop_name] = None + elif loop_name in info.splits_to_axis: + axis = info.splits_to_axis[loop_name] + seen_axes[axis] = sched.splits[axis][loop_name] + elif loop_name in info.tiles_to_axis: + axis = info.tiles_to_axis[loop_name] + size = sched.tiles[axis][loop_name] + if axis not in seen_axes: + raise ScheduleValidationError( + f""" + `{axis}#{size}`: {axis} has not been materialized yet. + """ + ) + seen_axes[axis] = size + + def _check_sizes(self, info: LoopInfo): + current_size_of_split: dict[str, int | None] = {} + for sched in self.nodes: + current_size_of_tile: dict[str, int] = {} + if sched.split_origin is not None: + axis = sched.split_origin.axis + start = sched.split_origin.start + end = sched.split_origin.end + if end is not None and start is not None: + current_size_of_split[axis] = end - start + else: + current_size_of_split[axis] = None + + for loop_name in sched.interchange: + axis = info.loops_to_axis[loop_name] + current_sizes = ( + {d: None for d in info.dims} + | current_size_of_split + | current_size_of_tile + ) + loop_size = None + if loop_name in info.dims: + if loop_name not in current_size_of_split: + current_size_of_split[loop_name] = None + elif loop_name in info.tiles_to_axis: + loop_size = sched.tiles[axis][loop_name] + LoopNest._must_be_smaller_routine( + new_size=loop_size, + current_sizes=current_sizes, + loop_name=loop_name, + axis=axis, + ) + current_size_of_tile[axis] = loop_size + elif ( + loop_name in info.splits_to_axis + and loop_name in info.splits_to_sizes + ): + loop_size = info.splits_to_sizes[loop_name] + LoopNest._must_be_smaller_routine( + new_size=loop_size, + current_sizes=current_sizes, + loop_name=loop_name, + axis=axis, + ) + current_size_of_split[axis] = loop_size + + if loop_name in sched.unroll: + unroll_factor = sched.unroll[loop_name] + if loop_size and loop_size < unroll_factor: + raise ScheduleValidationError( + f'`{{"unroll" = {unroll_factor}}}`: unroll factor should be smaller than {loop_size}.' + ) + + @staticmethod + def _must_be_smaller_routine( + new_size: int, current_sizes: dict[str, int | None], loop_name: str, axis: str + ): + old_size = current_sizes[axis] + if old_size is not None and new_size > old_size: + raise ScheduleValidationError( + f""" + Inner loop {loop_name} on axis {axis} must be smaller than outer loop. + """ + ) diff --git a/src/xtc/schedules/parsing.py b/src/xtc/schedules/parsing.py new file mode 100644 index 00000000..d1e17a53 --- /dev/null +++ b/src/xtc/schedules/parsing.py @@ -0,0 +1,256 @@ +# +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024-2026 The XTC Project Authors +# +from __future__ import annotations + +from typing import Any +from dataclasses import dataclass +import re +from typing_extensions import override + +from .exceptions import ScheduleParseError + + +@dataclass(frozen=True) +class Annotations: + """AST Type : annotations that can be applied to a loop. + + Attributes: + unroll_factor: The unroll factor. None means "unroll fully" (use loop size). + Only meaningful when unroll_specified is True. + unroll_specified: True if unroll was explicitly requested. + vectorize: True if vectorization was requested. + parallelize: True if parallelization was requested. + buffer: The memory type for the buffer. None means default memory type. + Only meaningful when buffer_specified is True. + buffer_specified: True if buffer was explicitly requested. + pack: Pack configuration as (input_idx, mtype, pad). mtype is None for default. + Only meaningful when pack_specified is True. + pack_specified: True if pack was explicitly requested. + """ + + unroll_factor: int | None = None + unroll_specified: bool = False + vectorize: bool = False + parallelize: bool = False + buffer: str | None = None + buffer_specified: bool = False + pack: tuple[int, str | None, bool] | None = None + pack_specified: bool = False + + +@dataclass(frozen=True) +class SplitDecl: + """AST Type: a split declaration like 'axis[start:end]'.""" + + axis: str + start: int | None + end: int | None + body: ScheduleSpec + + @override + def __str__(self) -> str: + start_str = "" if self.start is None else str(self.start) + end_str = "" if self.end is None else str(self.end) + decl = f"{self.axis}[{start_str}:{end_str}]" + return decl + + +@dataclass(frozen=True) +class TileDecl: + """AST Type: a tile declaration like 'axis#size'.""" + + axis: str + size: int + annotations: Annotations + + @override + def __str__(self) -> str: + return f"{self.axis}#{self.size}" + + +@dataclass(frozen=True) +class AxisDecl: + """AST Type: a direct axis reference.""" + + axis: str + annotations: Annotations + + +ScheduleItem = SplitDecl | TileDecl | AxisDecl + + +@dataclass(frozen=True) +class ScheduleSpec: + """AST Type: the complete parsed schedule specification.""" + + items: tuple[ScheduleItem, ...] + + +class ScheduleParser: + """Parses a dict-based schedule specification into an AST.""" + + _SPLIT_PATTERN = re.compile(r"^(.*)\[(-\d+|\d*)?:(-\d+|\d*)?\]$") + + def parse(self, spec: dict[str, Any]) -> ScheduleSpec: + """Parse a schedule specification dict into an AST.""" + items: list[ScheduleItem] = [] + + for declaration, value in spec.items(): + item = self._parse_declaration(declaration, value) + items.append(item) + + return ScheduleSpec(items=tuple(items)) + + def _parse_declaration(self, declaration: str, value: Any) -> ScheduleItem: + """Parse a single declaration into a ScheduleItem.""" + assert isinstance(value, dict) + # Try split declaration first (e.g., "axis[0:10]") + if ":" in declaration: + return self._parse_split(declaration, value) + + # Try tile declaration (e.g., "axis#32") + if "#" in declaration: + return self._parse_tile(declaration, value) + + # Must be a direct axis reference + return self._parse_axis_ref(declaration, value) + + def _parse_split(self, declaration: str, value: dict) -> SplitDecl: + """Parse a split declaration like 'axis[start:end]'.""" + axis_name, start, end = self._parse_split_syntax(declaration) + + body = self.parse(value) + return SplitDecl(axis=axis_name, start=start, end=end, body=body) + + def _parse_tile(self, declaration: str, value: dict) -> TileDecl: + """Parse a tile declaration like 'axis#size'.""" + parts = declaration.split("#") + if len(parts) != 2: + raise ScheduleParseError( + f"`{declaration}`: invalid tile syntax, expected 'axis#size'" + ) + + axis_name, size_str = parts + + try: + size = int(size_str) + except ValueError: + raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.") + + annotations = self._parse_annotations(value, declaration) + return TileDecl(axis=axis_name, size=size, annotations=annotations) + + def _parse_axis_ref(self, declaration: str, value: dict) -> AxisDecl: + """Parse a direct axis reference.""" + + annotations = self._parse_annotations(value, declaration) + return AxisDecl(axis=declaration, annotations=annotations) + + def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations: + """Parse annotation dict into Annotations object.""" + + unroll_factor: int | None = None + unroll_specified = False + vectorize = False + parallelize = False + buffer: str | None = None + buffer_specified = False + pack: tuple[int, str | None, bool] | None = None + pack_specified = False + + for key, param in value.items(): + if key == "unroll": + if param is True: + unroll_factor = None + unroll_specified = True + elif param is False: + pass + elif isinstance(param, int): + unroll_factor = param + unroll_specified = True + else: + raise ScheduleParseError( + f'`{{"unroll" = {param}}}`: unroll parameter should be True, False, or an integer.' + ) + elif key == "vectorize": + if not isinstance(param, bool): + raise ScheduleParseError( + f'`{{"vectorize" = {param}}}`: parameterized vectorization not implemented.' + ) + vectorize = param + elif key == "parallelize": + if not isinstance(param, bool): + raise ScheduleParseError( + f'`{{"parallelize" = {param}}}`: parameterized parallelization not implemented.' + ) + parallelize = param + elif key == "buffer": + if not isinstance(param, str): + raise ScheduleParseError( + f'`{{"buffer" = {param}}}`: buffer parameter should be a string (mtype).' + ) + buffer = None if param == "default" else param + buffer_specified = True + elif key == "pack": + pack = self._parse_pack_param(param, context) + pack_specified = True + else: + raise ScheduleParseError(f"Unknown annotation on {context}: {key}") + + return Annotations( + unroll_factor=unroll_factor, + unroll_specified=unroll_specified, + vectorize=vectorize, + parallelize=parallelize, + buffer=buffer, + buffer_specified=buffer_specified, + pack=pack, + pack_specified=pack_specified, + ) + + def _parse_pack_param( + self, param: Any, context: str + ) -> tuple[int, str | None, bool]: + """Parse pack parameter into (input_idx, mtype, pad) tuple.""" + if not isinstance(param, (list, tuple)) or len(param) != 3: + raise ScheduleParseError( + f'`{{"pack" = {param}}}` on {context}: pack parameter should be a tuple (input_idx, mtype, pad).' + ) + + input_idx, mtype, pad = param + + if not isinstance(input_idx, int): + raise ScheduleParseError( + f'`{{"pack" = {param}}}` on {context}: input_idx should be an integer.' + ) + + if mtype is not None and not isinstance(mtype, str): + raise ScheduleParseError( + f'`{{"pack" = {param}}}` on {context}: mtype should be a string or None.' + ) + + if not isinstance(pad, bool): + raise ScheduleParseError( + f'`{{"pack" = {param}}}` on {context}: pad should be a boolean.' + ) + + # Convert "default" to None for mtype + if mtype == "default": + mtype = None + + return (input_idx, mtype, pad) + + def _parse_split_syntax( + self, declaration: str + ) -> tuple[str, int | None, int | None]: + """Parse the syntax of a split declaration.""" + match = self._SPLIT_PATTERN.match(declaration) + if not match: + raise ScheduleParseError(f"Wrong format {declaration}") + + prefix, x_str, y_str = match.groups() + x = int(x_str) if x_str else None + y = int(y_str) if y_str else None + return prefix, x, y diff --git a/src/xtc/schedules/ttile/scheme_to_xtc.py b/src/xtc/schedules/ttile/scheme_to_xtc.py index f2c83364..99603bfb 100644 --- a/src/xtc/schedules/ttile/scheme_to_xtc.py +++ b/src/xtc/schedules/ttile/scheme_to_xtc.py @@ -723,7 +723,7 @@ def build_schedule_from_ttile( ) descript_scheduler( - scheduler=sch, node_name=name_op, abstract_axis=ldims, spec=spec_schedule + scheduler=sch, node_name=name_op, abstract_dims=ldims, spec=spec_schedule ) # And run it! diff --git a/tests/filecheck/schedules/test_descript_parsing_errors.py b/tests/filecheck/schedules/test_descript_parsing_errors.py index 49f0322b..676d5b81 100644 --- a/tests/filecheck/schedules/test_descript_parsing_errors.py +++ b/tests/filecheck/schedules/test_descript_parsing_errors.py @@ -24,7 +24,7 @@ descript_scheduler( scheduler = sch, node_name = "C", - abstract_axis = ["I","J","K"], + abstract_dims = ["I","J","K"], spec = { "I": {}, "K": {"unroll" : "hello"}, @@ -35,7 +35,7 @@ descript_scheduler( scheduler = sch, node_name = "C", - abstract_axis = ["I","J","K"], + abstract_dims = ["I","J","K"], spec = { "I": {"parallelize" : "hello"}, "K": {}, @@ -46,7 +46,7 @@ descript_scheduler( scheduler = sch, node_name = "C", - abstract_axis = ["I","J","K"], + abstract_dims = ["I","J","K"], spec = { "I": {}, "K": {}, diff --git a/tests/filecheck/schedules/test_descript_pretty_print.py b/tests/filecheck/schedules/test_descript_pretty_print.py new file mode 100644 index 00000000..54ed443d --- /dev/null +++ b/tests/filecheck/schedules/test_descript_pretty_print.py @@ -0,0 +1,125 @@ +# RUN: python %s --simple 2>&1 | filecheck %s --check-prefix=CHECK-SIMPLE +# RUN: python %s --tiled 2>&1 | filecheck %s --check-prefix=CHECK-TILED +# RUN: python %s --vectorized 2>&1 | filecheck %s --check-prefix=CHECK-VECTORIZED +# RUN: python %s --full 2>&1 | filecheck %s --check-prefix=CHECK-FULL +# RUN: python %s --split 2>&1 | filecheck %s --check-prefix=CHECK-SPLIT +# RUN: python %s --buffer 2>&1 | filecheck %s --check-prefix=CHECK-BUFFER +# RUN: python %s --pack 2>&1 | filecheck %s --check-prefix=CHECK-PACK + +import sys +from xtc.schedules.parsing import ScheduleParser +from xtc.schedules.descript import ScheduleInterpreter + +parser = ScheduleParser() +abstract_axis = ["i", "j", "k"] +interpreter = ScheduleInterpreter(abstract_axis) + +if "--simple" in sys.argv: + spec = {"i": {}, "k": {}, "j": {}} + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +elif "--tiled" in sys.argv: + spec = {"i": {}, "k": {}, "j": {}, "j#16": {}} + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +elif "--vectorized" in sys.argv: + spec = {"i": {}, "k": {}, "j": {}, "j#16": {"vectorize": True}} + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +elif "--full" in sys.argv: + spec = { + "i": {"parallelize": True}, + "k": {}, + "j": {}, + "j#32": {}, + "j#16": {"vectorize": True, "unroll": 4}, + } + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +elif "--split" in sys.argv: + spec = { + "i": {}, + "j[:128]": {"k": {}, "k#32": {}}, + "j[128:]": {"k": {}, "k#16": {"vectorize": True}}, + } + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +elif "--buffer" in sys.argv: + spec = { + "i": {"parallelize": True}, + "k": {"buffer": "default"}, + "j": {"buffer": "shared"}, + "j#16": {"vectorize": True}, + } + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +elif "--pack" in sys.argv: + spec = { + "i": {"parallelize": True}, + "k": {"pack": (0, "default", False)}, + "j": {"pack": (1, "shared", True)}, + "j#16": {"vectorize": True}, + } + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +# CHECK-SIMPLE: loop i +# CHECK-SIMPLE-NEXT: loop k +# CHECK-SIMPLE-NEXT: loop j +# CHECK-SIMPLE-NEXT: ... + +# CHECK-TILED: loop i +# CHECK-TILED-NEXT: loop k +# CHECK-TILED-NEXT: loop j +# CHECK-TILED-NEXT: tile(j, 16) +# CHECK-TILED-NEXT: ... + +# CHECK-VECTORIZED: loop i +# CHECK-VECTORIZED-NEXT: loop k +# CHECK-VECTORIZED-NEXT: loop j +# CHECK-VECTORIZED-NEXT: tile(j, 16) // vectorized +# CHECK-VECTORIZED-NEXT: ... + +# CHECK-FULL: loop i // parallelized +# CHECK-FULL-NEXT: loop k +# CHECK-FULL-NEXT: loop j +# CHECK-FULL-NEXT: tile(j, 32) +# CHECK-FULL-NEXT: tile(j, 16) // vectorized, unroll(4) +# CHECK-FULL-NEXT: ... + +# CHECK-SPLIT: loop i +# CHECK-SPLIT-NEXT: split(j, 0, 128) +# CHECK-SPLIT-NEXT: loop j +# CHECK-SPLIT-NEXT: loop k +# CHECK-SPLIT-NEXT: tile(k, 32) +# CHECK-SPLIT-NEXT: ... +# CHECK-SPLIT-NEXT: split(j, 128, ...) +# CHECK-SPLIT-NEXT: loop j +# CHECK-SPLIT-NEXT: loop k +# CHECK-SPLIT-NEXT: tile(k, 16) // vectorized +# CHECK-SPLIT-NEXT: ... + +# CHECK-BUFFER: loop i // parallelized +# CHECK-BUFFER-NEXT: loop k // buffer +# CHECK-BUFFER-NEXT: loop j // buffer(shared) +# CHECK-BUFFER-NEXT: tile(j, 16) // vectorized +# CHECK-BUFFER-NEXT: ... + +# CHECK-PACK: loop i // parallelized +# CHECK-PACK-NEXT: loop k // pack(0) +# CHECK-PACK-NEXT: loop j // pack(1, shared, pad) +# CHECK-PACK-NEXT: tile(j, 16) // vectorized +# CHECK-PACK-NEXT: ... diff --git a/tests/filecheck/schedules/test_descript_slice_bigger.py b/tests/filecheck/schedules/test_descript_slice_bigger.py index 61f2090e..f295916a 100644 --- a/tests/filecheck/schedules/test_descript_slice_bigger.py +++ b/tests/filecheck/schedules/test_descript_slice_bigger.py @@ -20,7 +20,7 @@ descript_scheduler( scheduler=sch, node_name="C", - abstract_axis=["i", "j", "k"], + abstract_dims=["i", "j", "k"], spec={ 'k': {}, 'j': {}, diff --git a/tests/filecheck/schedules/test_descript_slice_smaller.py b/tests/filecheck/schedules/test_descript_slice_smaller.py index bdf11093..551f001b 100644 --- a/tests/filecheck/schedules/test_descript_slice_smaller.py +++ b/tests/filecheck/schedules/test_descript_slice_smaller.py @@ -20,7 +20,7 @@ descript_scheduler( scheduler=sch, node_name="C", - abstract_axis=["i", "j", "k"], + abstract_dims=["i", "j", "k"], spec={ 'k': {}, 'j': {}, diff --git a/tests/filecheck/schedules/test_get_descript.py b/tests/filecheck/schedules/test_get_descript.py new file mode 100644 index 00000000..47888bb4 --- /dev/null +++ b/tests/filecheck/schedules/test_get_descript.py @@ -0,0 +1,81 @@ +# RUN: python %s --mlir 2>&1 | filecheck %s --check-prefix=CHECK-MLIR +# RUN: python %s --tvm 2>&1 | filecheck %s --check-prefix=CHECK-TVM +# REQUIRES: module_tvm + +import sys +import xtc.graphs.xtc.op as O + +I, J, K, dtype = 4, 32, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph + +if "--mlir" in sys.argv: + from xtc.backends.mlir import Backend + +elif "--tvm" in sys.argv: + from xtc.backends.tvm import Backend + +else: + assert False + +impl = Backend(graph) +sch = impl.get_scheduler() +sch.set_dims(["I", "J", "K"]) +sch.tile("I", {"I0": 2}) +sch.tile("J", {"J0": 16}) +sch.interchange(["K", "I", "J", "I0", "J0"]) +sch.unroll({"I0": 2}) +sch.vectorize(["J0"]) +if "--tvm" in sys.argv: + sch.buffer_at("J") + sch.pack_at("I", 0, pad=True) + +loop_nest = sch.get_loop_nest() +print(loop_nest.root_node.pretty_print()) + +# CHECK-MLIR: loop K +# CHECK-MLIR-NEXT: loop I +# CHECK-MLIR-NEXT: loop J +# CHECK-MLIR-NEXT: tile(I, 2) // unroll(2) +# CHECK-MLIR-NEXT: tile(J, 16) // vectorized +# CHECK-MLIR-NEXT: ... + +# CHECK-TVM: loop K +# CHECK-TVM-NEXT: loop I // pack(0, pad) +# CHECK-TVM-NEXT: loop J // buffer +# CHECK-TVM-NEXT: tile(I, 2) // unroll(2) +# CHECK-TVM-NEXT: tile(J, 16) // vectorized +# CHECK-TVM-NEXT: ... + +# Test with split (MLIR only - TVM does not support split) +if "--mlir" in sys.argv: + print("---") + + impl2 = Backend(graph) + sch2 = impl2.get_scheduler() + sch2.set_dims(["I", "J", "K"]) + sch2.split("I", {"I_lo": 0, "I_hi": 2}) + sch2.tile("J", {"J0": 16}, root="./I_lo") + sch2.tile("J", {"J0": 16}, root="./I_hi") + sch2.interchange(["K", "I_lo", "I_hi"]) + sch2.interchange(["J", "J0"], root="./I_lo") + sch2.interchange(["J", "J0"], root="./I_hi") + + loop_nest2 = sch2.get_loop_nest() + print(loop_nest2.root_node.pretty_print()) + +# CHECK-MLIR: --- +# CHECK-MLIR-NEXT: loop K +# CHECK-MLIR-NEXT: split(I, 0, 2) +# CHECK-MLIR-NEXT: loop J +# CHECK-MLIR-NEXT: tile(J, 16) +# CHECK-MLIR-NEXT: ... +# CHECK-MLIR-NEXT: split(I, 2, ...) +# CHECK-MLIR-NEXT: loop J +# CHECK-MLIR-NEXT: tile(J, 16) +# CHECK-MLIR-NEXT: ... diff --git a/tests/filecheck/schedules/test_matmul_descript_mlir.py b/tests/filecheck/schedules/test_matmul_descript_mlir.py index f3f517a5..814ae0e6 100644 --- a/tests/filecheck/schedules/test_matmul_descript_mlir.py +++ b/tests/filecheck/schedules/test_matmul_descript_mlir.py @@ -20,7 +20,7 @@ descript_scheduler( scheduler = sch, node_name = "C", - abstract_axis = ["I","J","K"], + abstract_dims = ["I","J","K"], spec = { "K": {}, "I": {}, diff --git a/tests/filecheck/schedules/test_matmul_descript_tvm.py b/tests/filecheck/schedules/test_matmul_descript_tvm.py index fab85fc7..c484775a 100644 --- a/tests/filecheck/schedules/test_matmul_descript_tvm.py +++ b/tests/filecheck/schedules/test_matmul_descript_tvm.py @@ -21,7 +21,7 @@ descript_scheduler( scheduler = sch, node_name = "C", - abstract_axis = ["I","J","K"], + abstract_dims = ["I","J","K"], spec = { "K": {}, "I": {},