From b5214d26db7253874e38b5189df064bd5964fc03 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Sun, 1 Feb 2026 20:51:01 +0100 Subject: [PATCH 1/8] descript: use a tree structure to represent a loop nest --- src/xtc/schedules/descript.py | 608 +++++++++++++++++++--------------- 1 file changed, 348 insertions(+), 260 deletions(-) diff --git a/src/xtc/schedules/descript.py b/src/xtc/schedules/descript.py index 63ee7052..1d860cd0 100644 --- a/src/xtc/schedules/descript.py +++ b/src/xtc/schedules/descript.py @@ -4,7 +4,7 @@ # from __future__ import annotations -from typing import Any +from typing import Any, Generic, TypeVar from dataclasses import dataclass, field import re from typing_extensions import override @@ -100,9 +100,6 @@ class ScheduleParser: _SPLIT_PATTERN = re.compile(r"^(.*)\[(-\d+|\d*)?:(-\d+|\d*)?\]$") - def __init__(self, abstract_axis: list[str]): - self.abstract_axis = abstract_axis - def parse(self, spec: dict[str, Any]) -> ScheduleSpec: """Parse a schedule specification dict into an AST.""" items: list[ScheduleItem] = [] @@ -216,195 +213,119 @@ def _parse_split_syntax( return prefix, x, y -class ScheduleInterpreter: - """Interprets a parsed ScheduleSpec AST into a LoopNest.""" - - def __init__(self, abstract_axis: list[str]): - self.abstract_axis = abstract_axis - self.root_to_dim: dict[str, str] = {} - self.dim_to_axis: dict[str, str] = {} - - def interpret(self, spec: ScheduleSpec, root: str) -> LoopNest: - """Interpret a schedule specification into a LoopNest.""" - return self._interpret_spec(spec, root) - - def _interpret_spec(self, spec: ScheduleSpec, root: str) -> LoopNest: - """Interpret a schedule spec recursively.""" - loop_nest = LoopNest(abstract_dims=self.abstract_axis) - slice = loop_nest.build_slice(root) - - # Track state during interpretation - sizes: dict[str, int] = {} - previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis} - interchange: list[str] = [] - # Only the first root is not in root_to_dim - if root in self.root_to_dim: - interchange.append(self.root_to_dim[root]) - - for item in spec.items: - if isinstance(item, SplitDecl): - self._interpret_split( - item, slice, loop_nest, root, interchange, previous_cut - ) - elif isinstance(item, TileDecl): - loop_name = self._interpret_tile(item, slice, interchange, sizes) - self._apply_annotations(item.annotations, loop_name, sizes, slice) - elif isinstance(item, AxisDecl): - loop_name = self._interpret_axis(item, interchange) - self._apply_annotations(item.annotations, loop_name, sizes, slice) - # Check that all splits are complete - for axis, cut in previous_cut.items(): - if cut is not None and cut != 0: - raise ScheduleInterpretError( - f"Splitting of {axis} unachieved (stops at {cut})." - ) - - slice.interchange = interchange - return loop_nest +@dataclass +class SplitOrigin: + """Describes how a node was created via a split from its parent. - def _interpret_split( - self, - item: SplitDecl, - slice: LoopNestSlice, - loop_nest: LoopNest, - root: str, - interchange: list[str], - previous_cut: dict[str, int | None], - ) -> None: - """Interpret a split declaration.""" - axis_name = item.axis - self._check_axis_existence(axis_name) - x = item.start - y = item.end + Attributes: + axis: The axis that was split to create this node. + start: The starting position of the split (inclusive), or None if unbounded. + end: The ending position of the split (exclusive), or None if unbounded. + """ - # The only declaration where y (the cut) is None is the - # last one, so it cannot be the previous one. - cut = previous_cut[axis_name] + axis: str + start: int | None + end: int | None - # When x (the starting point of the slice) is not specified, - # it is the previous cut - if x is None: - x = cut - assert x is not None - self._check_splitting_intervals(item, cut, x) +NodeT = TypeVar("NodeT", bound="Node") - # Update the previous cut - previous_cut[axis_name] = y - # Save the cutting points of the new dimensions - if axis_name not in slice.splits: - slice.splits[axis_name] = {} - new_dim_index = len(slice.splits[axis_name]) - new_dim_name = f"{axis_name}{SPLIT_LEFT_SEP}{new_dim_index}{SPLIT_RIGHT_SEP}" - new_root_name = f"{root}{ROOT_SEP}{new_dim_name}" - slice.splits[axis_name][new_dim_name] = x - interchange.append(new_dim_name) - self.dim_to_axis[new_dim_name] = axis_name - self.root_to_dim[new_root_name] = new_dim_name - # Recursively interpret the nested schedule - inner_nest = self._interpret_spec(item.body, new_root_name) - loop_nest.slices += inner_nest.slices +@dataclass(kw_only=True) +class Node(Generic[NodeT]): + """Base class for tree nodes with parent/child relationships. - def _interpret_tile( - self, - item: TileDecl, - slice: LoopNestSlice, - interchange: list[str], - sizes: dict[str, int], - ) -> str: - """Interpret a tile declaration. Returns the loop name.""" - self._check_axis_existence(item.axis) - tile_num = len(slice.tiles[item.axis]) - loop_name = f"{item.axis}{tile_num}" - if item.size <= 0: - raise ScheduleInterpretError( - f"`{item}`: tile sizes should be strictly positive." - ) - slice.tiles[item.axis][loop_name] = item.size - sizes[loop_name] = item.size - interchange.append(loop_name) - return loop_name + Provides tree structure and traversal operations. Subclasses add + domain-specific data. - def _interpret_axis( - self, - item: AxisDecl, - interchange: list[str], - ) -> str: - """Interpret a direct axis reference. Returns the loop name.""" - axis_name = item.axis - self._check_axis_existence(axis_name) - # Unreachable when built from a Python dict (because keys - # can't be duplicated). - for loop_name in interchange: - if self.dim_to_axis.get(loop_name, loop_name) == axis_name: - raise ScheduleInterpretError( - f"Axis {axis_name} is scheduled twice (or more)." - ) + Attributes: + parent: Reference to the parent node, or None for the root. + split_origin: Metadata describing how this node was created from + its parent via a split. None for the root node. + children: List of child nodes. + """ - interchange.append(axis_name) - return axis_name + parent: NodeT | None = None + split_origin: SplitOrigin | None = None + children: list[NodeT] = field(default_factory=list) - def _check_axis_existence(self, axis: str) -> None: - """Check that an axis is defined.""" - if axis not in self.abstract_axis: - raise ScheduleInterpretError( - f"Axis {axis} is not a defined axis (defined axis: {self.abstract_axis})." - ) + @property + def is_root(self) -> bool: + """Returns True if this node is the root (has no parent).""" + return self.parent is None + + def add_child(self, child: NodeT) -> None: + """Add a child node and set its parent to this node.""" + child.parent = self # type: ignore[assignment] + self.children.append(child) + + def ancestors(self) -> list[NodeT]: + """Return list of ancestors from parent to root.""" + result: list[NodeT] = [] + current = self.parent + while current is not None: + result.append(current) + current = current.parent + return result + + def descendants_dfs(self) -> list[NodeT]: + """Return all descendants in depth-first order.""" + result: list[NodeT] = [] + for child in self.children: + result.append(child) + result.extend(child.descendants_dfs()) + return result - def _apply_annotations( - self, - annotations: Annotations, - loop_name: str, - sizes: dict[str, int], - slice: LoopNestSlice, - ) -> None: - """Apply annotations to a loop in the slice.""" - if annotations.unroll_specified: - unroll_factor = annotations.unroll_factor - if unroll_factor is None: - # None means "unroll fully" - use the loop size - if loop_name not in sizes: - raise ScheduleInterpretError( - f"{loop_name}'s size being unknown, an unroll factor is needed." - ) - unroll_factor = sizes[loop_name] - elif unroll_factor <= 0: - raise ScheduleInterpretError( - f'`{{"unroll" = {unroll_factor}}}`: unroll parameter should be strictly positive.' - ) - slice.unroll[loop_name] = unroll_factor - if annotations.vectorize: - slice.vectorize.append(loop_name) +@dataclass +class LoopNestNode(Node["LoopNestNode"]): + """Represents a node in the loop nest tree with its transformations. - if annotations.parallelize: - slice.parallelize.append(loop_name) + Describes the loops attached to a single root and + contains all the scheduling transformations applied to these loops. - def _check_splitting_intervals( - self, - item: SplitDecl, - cut: int | None, - x: int, - ) -> None: - """Check that split intervals are valid and contiguous.""" + Attributes: + root: Identifier of the node (either the base operation or + the content of a split). + tiles: Tiling configuration per axis. Maps axis names to dicts of + tile loop names and their sizes. + splits: Split configuration per axis. Maps axis names to dicts of + split loop names and their starting positions. + interchange: Ordered list of loop names defining the loop order. + vectorize: List of loops to vectorize. + parallelize: List of loops to parallelize. + unroll: Maps loop names to their unroll factors. + """ - if cut is None: - raise ScheduleInterpretError(f"{item}: {item.axis} already covered.") + root: str + tiles: dict[str, dict[str, int]] + splits: dict[str, dict[str, int | None]] = field(default_factory=dict) + interchange: list[str] = field(default_factory=list) + vectorize: list[str] = field(default_factory=list) + parallelize: list[str] = field(default_factory=list) + unroll: dict[str, int] = field(default_factory=dict) - if x > cut: - raise ScheduleInterpretError( - f"{item}: splitting doesn't fully cover {item.axis} (jumps from {cut} to {x})." - ) - elif x < cut: - raise ScheduleInterpretError( - f"{item}: the segment begins at {x} but the previous one ends at {cut}." - ) + @property + def splits_to_sizes(self) -> dict[str, int | None]: + splits_to_sizes: dict[str, int | None] = {} + for axis in self.splits: + last_start = None + for loop_name, start in reversed(self.splits[axis].items()): + if last_start is not None and start is not None: + size_of_split = last_start - start + splits_to_sizes[loop_name] = size_of_split + else: + splits_to_sizes[loop_name] = None + last_start = start + return splits_to_sizes - if item.end is not None and x >= item.end: - raise ScheduleInterpretError( - f"{item}: the ending point should be greater than the starting point." - ) + @property + def tiles_to_sizes(self) -> dict[str, int]: + tiles_to_sizes: dict[str, int] = {} + for tiles in self.tiles.values(): + for loop, size in tiles.items(): + tiles_to_sizes[loop] = size + return tiles_to_sizes @dataclass @@ -433,17 +354,17 @@ def loops_to_axis(self) -> dict[str, str]: return loops_to_axis @staticmethod - def build_from_slices(slices: list["LoopNestSlice"]) -> "LoopsDimsMapper": + def build_from_nodes(nodes: list[LoopNestNode]) -> LoopsDimsMapper: tiles_to_axis = {} splits_to_axis = {} dims = set() - for slice in slices: - tiles_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(slice.tiles)) - splits_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(slice.splits)) + for node in nodes: + tiles_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(node.tiles)) + splits_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(node.splits)) refined_loops = list(tiles_to_axis) + list(splits_to_axis) - for slice in slices: + for node in nodes: dims.update( - [loop for loop in slice.interchange if loop not in refined_loops] + [loop for loop in node.interchange if loop not in refined_loops] ) dims.update(tiles_to_axis.values()) dims.update(splits_to_axis.values()) @@ -458,80 +379,41 @@ def _get_subloops_to_axis(subloops: dict[str, dict[str, Any]]) -> dict[str, str] return loop_to_axis -@dataclass -class LoopNestSlice: - """Represents a single slice of a loop nest with its transformations. - - A slice describes the loops attached to a single root and - contains all the scheduling transformations applied to these loops. - - Attributes: - root: Identifier of the the slice (either the base operation or - the content of a split). - tiles: Tiling configuration per axis. Maps axis names to dicts of - tile loop names and their sizes. - splits: Split configuration per axis. Maps axis names to dicts of - split loop names and their starting positions. - interchange: Ordered list of loop names defining the loop order. - vectorize: List of loops to vectorize. - parallelize: List of loops to parallelize. - unroll: Maps loop names to their unroll factors. - """ - - root: str - tiles: dict[str, dict[str, int]] - splits: dict[str, dict[str, int | None]] = field(default_factory=dict) - interchange: list[str] = field(default_factory=list) - vectorize: list[str] = field(default_factory=list) - parallelize: list[str] = field(default_factory=list) - unroll: dict[str, int] = field(default_factory=dict) - - @property - def splits_to_sizes(self) -> dict[str, int | None]: - splits_to_sizes: dict[str, int | None] = {} - for axis in self.splits: - last_start = None - for loop_name, start in reversed(self.splits[axis].items()): - if last_start is not None and start is not None: - size_of_split = last_start - start - splits_to_sizes[loop_name] = size_of_split - else: - splits_to_sizes[loop_name] = None - last_start = start - return splits_to_sizes - - @property - def tiles_to_sizes(self) -> dict[str, int]: - tiles_to_sizes: dict[str, int] = {} - for tiles in self.tiles.values(): - for loop, size in tiles.items(): - tiles_to_sizes[loop] = size - return tiles_to_sizes - - @dataclass class LoopNest: """Represents a complete loop nest structure for scheduling. - A loop nest contains abstract dimensions and a collection of slices/ - It provides validation to ensure consistency across all slices. + A loop nest contains abstract dimensions and a tree of nodes representing + the schedule. Splits create child nodes, forming an explicit tree structure. Attributes: abstract_dims: List of abstract dimension names for the loop nest. - slices: List of LoopNestSlice objects, one per scheduled operation. + root_node: The root node of the loop nest tree, or None if empty. """ abstract_dims: list[str] - slices: list[LoopNestSlice] = field(default_factory=list) + root_node: LoopNestNode | None = None + + @property + def empty(self) -> bool: + return self.root_node is None @property - def empty(self): - return not self.slices + def nodes(self) -> list[LoopNestNode]: + """Flatten the tree into a list of nodes. - def build_slice(self, root: str) -> LoopNestSlice: - slice = LoopNestSlice(root=root, tiles={a: {} for a in self.abstract_dims}) - self.slices = [slice] + self.slices - return slice + Returns nodes in depth-first order, with the root node first, + followed by children in the order they were created. + """ + if self.root_node is None: + return [] + return [self.root_node] + self.root_node.descendants_dfs() + + def build_root_node(self, root: str) -> LoopNestNode: + """Build and set the root node of the loop nest tree.""" + node = LoopNestNode(root=root, tiles={a: {} for a in self.abstract_dims}) + self.root_node = node + return node def check(self): self._check_use_defined_dims() @@ -540,13 +422,13 @@ def check(self): self._check_sizes() def _check_use_defined_dims(self): - mapper = LoopsDimsMapper.build_from_slices(self.slices) + mapper = LoopsDimsMapper.build_from_nodes(self.nodes) for dim in self.abstract_dims: if dim not in mapper.dims: raise ScheduleValidationError(f"{dim} defined but never used") def _check_vectorization_consistency(self): - for sched in self.slices: + for sched in self.nodes: vect_above = False for loop_name in sched.interchange: if loop_name in sched.vectorize: @@ -557,9 +439,9 @@ def _check_vectorization_consistency(self): ) def _check_tiling_consistency(self) -> None: - mapper = LoopsDimsMapper.build_from_slices(self.slices) + mapper = LoopsDimsMapper.build_from_nodes(self.nodes) seen_axes: dict[str, int | None] = {} - for sched in self.slices: + for sched in self.nodes: for loop_name in sched.interchange: loop_name = mapper.splits_to_axis.get(loop_name, loop_name) @@ -577,9 +459,9 @@ def _check_tiling_consistency(self) -> None: seen_axes[axis] = sched.tiles[axis][loop_name] def _check_sizes(self): - mapper = LoopsDimsMapper.build_from_slices(self.slices) + mapper = LoopsDimsMapper.build_from_nodes(self.nodes) current_size_of_split: dict[str, int | None] = {} - for sched in self.slices: + for sched in self.nodes: current_size_of_tile: dict[str, int] = {} for loop_name in sched.interchange: axis = mapper.loops_to_axis[loop_name] @@ -659,6 +541,204 @@ def descript_scheduler( descript.apply(node_name=node_name, spec=spec) +class ScheduleInterpreter: + """Interprets a parsed ScheduleSpec AST into a LoopNest.""" + + def __init__(self, abstract_axis: list[str]): + self.abstract_axis = abstract_axis + + def interpret(self, spec: ScheduleSpec, root: str) -> LoopNest: + """Interpret a schedule specification into a LoopNest.""" + loop_nest = LoopNest(abstract_dims=self.abstract_axis) + root_node = loop_nest.build_root_node(root) + self._interpret_spec_into_node(spec, root_node, root, head=[]) + return loop_nest + + def _interpret_spec_into_node( + self, + spec: ScheduleSpec, + node: LoopNestNode, + root: str, + head: list[str], + ) -> None: + """Interpret a schedule spec into an existing node (mutates node).""" + # Track state during interpretation + sizes: dict[str, int] = {} + previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis} + interchange: list[str] = list(head) + + for item in spec.items: + if isinstance(item, SplitDecl): + self._interpret_split(item, node, root, interchange, previous_cut) + elif isinstance(item, TileDecl): + loop_name = self._interpret_tile(item, node, interchange, sizes) + self._apply_annotations(item.annotations, loop_name, sizes, node) + elif isinstance(item, AxisDecl): + loop_name = self._interpret_axis(item, interchange) + self._apply_annotations(item.annotations, loop_name, sizes, node) + + # Check that all splits are complete + for axis, cut in previous_cut.items(): + if cut is not None and cut != 0: + raise ScheduleInterpretError( + f"Splitting of {axis} unachieved (stops at {cut})." + ) + + node.interchange = interchange + + def _interpret_split( + self, + item: SplitDecl, + node: LoopNestNode, + root: str, + interchange: list[str], + previous_cut: dict[str, int | None], + ) -> None: + """Interpret a split declaration.""" + axis_name = item.axis + self._check_axis_existence(axis_name) + x = item.start + y = item.end + + # The only declaration where y (the cut) is None is the + # last one, so it cannot be the previous one. + cut = previous_cut[axis_name] + + # When x (the starting point of the split) is not specified, + # it is the previous cut + if x is None: + x = cut + assert x is not None + + self._check_splitting_intervals(item, cut, x) + + # Update the previous cut + previous_cut[axis_name] = y + + # Save the cutting points of the new dimensions + if axis_name not in node.splits: + node.splits[axis_name] = {} + new_dim_index = len(node.splits[axis_name]) + new_dim_name = f"{axis_name}{SPLIT_LEFT_SEP}{new_dim_index}{SPLIT_RIGHT_SEP}" + new_root_name = f"{root}{ROOT_SEP}{new_dim_name}" + node.splits[axis_name][new_dim_name] = x + interchange.append(new_dim_name) + + # Create a child node for the nested schedule + child_node = LoopNestNode( + root=new_root_name, + tiles={a: {} for a in self.abstract_axis}, + split_origin=SplitOrigin(axis=axis_name, start=x, end=y), + ) + node.add_child(child_node) + + # Recursively interpret the nested schedule into the child node + self._interpret_spec_into_node( + item.body, child_node, new_root_name, head=[axis_name] + ) + + def _interpret_tile( + self, + item: TileDecl, + node: LoopNestNode, + interchange: list[str], + sizes: dict[str, int], + ) -> str: + """Interpret a tile declaration. Returns the loop name.""" + self._check_axis_existence(item.axis) + tile_num = len(node.tiles[item.axis]) + loop_name = f"{item.axis}{tile_num}" + if item.size <= 0: + raise ScheduleInterpretError( + f"`{item}`: tile sizes should be strictly positive." + ) + node.tiles[item.axis][loop_name] = item.size + sizes[loop_name] = item.size + interchange.append(loop_name) + + return loop_name + + def _interpret_axis( + self, + item: AxisDecl, + interchange: list[str], + ) -> str: + """Interpret a direct axis reference. Returns the loop name.""" + axis_name = item.axis + self._check_axis_existence(axis_name) + + # Unreachable when built from a Python dict (because keys + # can't be duplicated). + if axis_name in interchange: + raise ScheduleInterpretError( + f"Axis {axis_name} is scheduled twice (or more)." + ) + + interchange.append(axis_name) + return axis_name + + def _check_axis_existence(self, axis: str) -> None: + """Check that an axis is defined.""" + if axis not in self.abstract_axis: + raise ScheduleInterpretError( + f"Axis {axis} is not a defined axis (defined axis: {self.abstract_axis})." + ) + + def _apply_annotations( + self, + annotations: Annotations, + loop_name: str, + sizes: dict[str, int], + node: LoopNestNode, + ) -> None: + """Apply annotations to a loop in the node.""" + if annotations.unroll_specified: + unroll_factor = annotations.unroll_factor + if unroll_factor is None: + # None means "unroll fully" - use the loop size + if loop_name not in sizes: + raise ScheduleInterpretError( + f"{loop_name}'s size being unknown, an unroll factor is needed." + ) + unroll_factor = sizes[loop_name] + elif unroll_factor <= 0: + raise ScheduleInterpretError( + f'`{{"unroll" = {unroll_factor}}}`: unroll parameter should be strictly positive.' + ) + node.unroll[loop_name] = unroll_factor + + if annotations.vectorize: + node.vectorize.append(loop_name) + + if annotations.parallelize: + node.parallelize.append(loop_name) + + def _check_splitting_intervals( + self, + item: SplitDecl, + cut: int | None, + x: int, + ) -> None: + """Check that split intervals are valid and contiguous.""" + + if cut is None: + raise ScheduleInterpretError(f"{item}: {item.axis} already covered.") + + if x > cut: + raise ScheduleInterpretError( + f"{item}: splitting doesn't fully cover {item.axis} (jumps from {cut} to {x})." + ) + elif x < cut: + raise ScheduleInterpretError( + f"{item}: the segment begins at {x} but the previous one ends at {cut}." + ) + + if item.end is not None and x >= item.end: + raise ScheduleInterpretError( + f"{item}: the ending point should be greater than the starting point." + ) + + @dataclass(frozen=True) class Descript: """Applies a parsed and interpreted schedule to a Scheduler. @@ -687,7 +767,7 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: ScheduleValidationError: If the resulting schedule is invalid. """ # Parse the specification into an AST - parser = ScheduleParser(self.abstract_axis) + parser = ScheduleParser() ast = parser.parse(spec) # Interpret the AST into a LoopNest @@ -702,16 +782,24 @@ def _apply_loop_nest(self, loop_nest: LoopNest) -> None: """Apply a LoopNest to the scheduler.""" self.scheduler.set_dims(self.abstract_axis) - for slice in loop_nest.slices: - root = slice.root + if loop_nest.root_node is not None: + self._apply_node(loop_nest.root_node) + + def _apply_node(self, node: LoopNestNode) -> None: + """Recursively apply a LoopNestNode and its children to the scheduler.""" + root = node.root + + for d, s in node.splits.items(): + self.scheduler.split(d, s, root=root) - for d, s in slice.splits.items(): - self.scheduler.split(d, s, root=root) + for d, s in node.tiles.items(): + self.scheduler.tile(d, s, root=root) - for d, s in slice.tiles.items(): - self.scheduler.tile(d, s, root=root) + self.scheduler.interchange(node.interchange, root=root) + self.scheduler.vectorize(node.vectorize, root=root) + self.scheduler.parallelize(node.parallelize, root=root) + self.scheduler.unroll(node.unroll, root=root) - self.scheduler.interchange(slice.interchange, root=root) - self.scheduler.vectorize(slice.vectorize, root=root) - self.scheduler.parallelize(slice.parallelize, root=root) - self.scheduler.unroll(slice.unroll, root=root) + # Recursively apply children + for child in node.children: + self._apply_node(child) From 1c4ca8c90576d5e68671bca71470c312fa5173c0 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Sun, 1 Feb 2026 20:57:37 +0100 Subject: [PATCH 2/8] descript: split the exceptions, the parsing, the interpretation and the data structure --- src/xtc/schedules/descript.py | 527 +------------------------------- src/xtc/schedules/exceptions.py | 23 ++ src/xtc/schedules/loop_nest.py | 310 +++++++++++++++++++ src/xtc/schedules/parsing.py | 196 ++++++++++++ 4 files changed, 541 insertions(+), 515 deletions(-) create mode 100644 src/xtc/schedules/exceptions.py create mode 100644 src/xtc/schedules/loop_nest.py create mode 100644 src/xtc/schedules/parsing.py diff --git a/src/xtc/schedules/descript.py b/src/xtc/schedules/descript.py index 1d860cd0..026891fc 100644 --- a/src/xtc/schedules/descript.py +++ b/src/xtc/schedules/descript.py @@ -2,523 +2,20 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024-2026 The XTC Project Authors # -from __future__ import annotations +from typing import Any +from dataclasses import dataclass -from typing import Any, Generic, TypeVar -from dataclasses import dataclass, field -import re -from typing_extensions import override from xtc.itf.schd.scheduler import Scheduler, ROOT_SEP, SPLIT_LEFT_SEP, SPLIT_RIGHT_SEP - - -class ScheduleParseError(RuntimeError): - """Raised when schedule parsing fails.""" - - pass - - -class ScheduleInterpretError(RuntimeError): - """Raised when schedule interpretation fails.""" - - pass - - -class ScheduleValidationError(RuntimeError): - """Raised when schedule validation fails.""" - - pass - - -@dataclass(frozen=True) -class Annotations: - """AST Type : annotations that can be applied to a loop. - - Attributes: - unroll_factor: The unroll factor. None means "unroll fully" (use loop size). - Only meaningful when unroll_specified is True. - unroll_specified: True if unroll was explicitly requested. - vectorize: True if vectorization was requested. - parallelize: True if parallelization was requested. - """ - - unroll_factor: int | None = None - unroll_specified: bool = False - vectorize: bool = False - parallelize: bool = False - - -@dataclass(frozen=True) -class SplitDecl: - """AST Type: a split declaration like 'axis[start:end]'.""" - - axis: str - start: int | None - end: int | None - body: ScheduleSpec - - @override - def __str__(self) -> str: - start_str = "" if self.start is None else str(self.start) - end_str = "" if self.end is None else str(self.end) - decl = f"{self.axis}{SPLIT_LEFT_SEP}{start_str}:{end_str}{SPLIT_RIGHT_SEP}" - return decl - - -@dataclass(frozen=True) -class TileDecl: - """AST Type: a tile declaration like 'axis#size'.""" - - axis: str - size: int - annotations: Annotations - - @override - def __str__(self) -> str: - return f"{self.axis}#{self.size}" - - -@dataclass(frozen=True) -class AxisDecl: - """AST Type: a direct axis reference.""" - - axis: str - annotations: Annotations - - -ScheduleItem = SplitDecl | TileDecl | AxisDecl - - -@dataclass(frozen=True) -class ScheduleSpec: - """AST Type: the complete parsed schedule specification.""" - - items: tuple[ScheduleItem, ...] - - -class ScheduleParser: - """Parses a dict-based schedule specification into an AST.""" - - _SPLIT_PATTERN = re.compile(r"^(.*)\[(-\d+|\d*)?:(-\d+|\d*)?\]$") - - def parse(self, spec: dict[str, Any]) -> ScheduleSpec: - """Parse a schedule specification dict into an AST.""" - items: list[ScheduleItem] = [] - - for declaration, value in spec.items(): - item = self._parse_declaration(declaration, value) - items.append(item) - - return ScheduleSpec(items=tuple(items)) - - def _parse_declaration(self, declaration: str, value: Any) -> ScheduleItem: - """Parse a single declaration into a ScheduleItem.""" - assert isinstance(value, dict) - # Try split declaration first (e.g., "axis[0:10]") - if ":" in declaration: - return self._parse_split(declaration, value) - - # Try tile declaration (e.g., "axis#32") - if "#" in declaration: - return self._parse_tile(declaration, value) - - # Must be a direct axis reference - return self._parse_axis_ref(declaration, value) - - def _parse_split(self, declaration: str, value: dict) -> SplitDecl: - """Parse a split declaration like 'axis[start:end]'.""" - axis_name, start, end = self._parse_split_syntax(declaration) - - body = self.parse(value) - return SplitDecl(axis=axis_name, start=start, end=end, body=body) - - def _parse_tile(self, declaration: str, value: dict) -> TileDecl: - """Parse a tile declaration like 'axis#size'.""" - parts = declaration.split("#") - if len(parts) != 2: - raise ScheduleParseError( - f"`{declaration}`: invalid tile syntax, expected 'axis#size'" - ) - - axis_name, size_str = parts - - try: - size = int(size_str) - except ValueError: - raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.") - - annotations = self._parse_annotations(value, declaration) - return TileDecl(axis=axis_name, size=size, annotations=annotations) - - def _parse_axis_ref(self, declaration: str, value: dict) -> AxisDecl: - """Parse a direct axis reference.""" - - annotations = self._parse_annotations(value, declaration) - return AxisDecl(axis=declaration, annotations=annotations) - - def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations: - """Parse annotation dict into Annotations object.""" - - unroll_factor: int | None = None - unroll_specified = False - vectorize = False - parallelize = False - - for key, param in value.items(): - if key == "unroll": - if param is True: - unroll_factor = None - unroll_specified = True - elif param is False: - pass - elif isinstance(param, int): - unroll_factor = param - unroll_specified = True - else: - raise ScheduleParseError( - f'`{{"unroll" = {param}}}`: unroll parameter should be True, False, or an integer.' - ) - elif key == "vectorize": - if not isinstance(param, bool): - raise ScheduleParseError( - f'`{{"vectorize" = {param}}}`: parameterized vectorization not implemented.' - ) - vectorize = param - elif key == "parallelize": - if not isinstance(param, bool): - raise ScheduleParseError( - f'`{{"parallelize" = {param}}}`: parameterized parallelization not implemented.' - ) - parallelize = param - else: - raise ScheduleParseError(f"Unknown annotation on {context}: {key}") - - return Annotations( - unroll_factor=unroll_factor, - unroll_specified=unroll_specified, - vectorize=vectorize, - parallelize=parallelize, - ) - - def _parse_split_syntax( - self, declaration: str - ) -> tuple[str, int | None, int | None]: - """Parse the syntax of a split declaration.""" - match = self._SPLIT_PATTERN.match(declaration) - if not match: - raise ScheduleParseError(f"Wrong format {declaration}") - - prefix, x_str, y_str = match.groups() - x = int(x_str) if x_str else None - y = int(y_str) if y_str else None - return prefix, x, y - - -@dataclass -class SplitOrigin: - """Describes how a node was created via a split from its parent. - - Attributes: - axis: The axis that was split to create this node. - start: The starting position of the split (inclusive), or None if unbounded. - end: The ending position of the split (exclusive), or None if unbounded. - """ - - axis: str - start: int | None - end: int | None - - -NodeT = TypeVar("NodeT", bound="Node") - - -@dataclass(kw_only=True) -class Node(Generic[NodeT]): - """Base class for tree nodes with parent/child relationships. - - Provides tree structure and traversal operations. Subclasses add - domain-specific data. - - Attributes: - parent: Reference to the parent node, or None for the root. - split_origin: Metadata describing how this node was created from - its parent via a split. None for the root node. - children: List of child nodes. - """ - - parent: NodeT | None = None - split_origin: SplitOrigin | None = None - children: list[NodeT] = field(default_factory=list) - - @property - def is_root(self) -> bool: - """Returns True if this node is the root (has no parent).""" - return self.parent is None - - def add_child(self, child: NodeT) -> None: - """Add a child node and set its parent to this node.""" - child.parent = self # type: ignore[assignment] - self.children.append(child) - - def ancestors(self) -> list[NodeT]: - """Return list of ancestors from parent to root.""" - result: list[NodeT] = [] - current = self.parent - while current is not None: - result.append(current) - current = current.parent - return result - - def descendants_dfs(self) -> list[NodeT]: - """Return all descendants in depth-first order.""" - result: list[NodeT] = [] - for child in self.children: - result.append(child) - result.extend(child.descendants_dfs()) - return result - - -@dataclass -class LoopNestNode(Node["LoopNestNode"]): - """Represents a node in the loop nest tree with its transformations. - - Describes the loops attached to a single root and - contains all the scheduling transformations applied to these loops. - - Attributes: - root: Identifier of the node (either the base operation or - the content of a split). - tiles: Tiling configuration per axis. Maps axis names to dicts of - tile loop names and their sizes. - splits: Split configuration per axis. Maps axis names to dicts of - split loop names and their starting positions. - interchange: Ordered list of loop names defining the loop order. - vectorize: List of loops to vectorize. - parallelize: List of loops to parallelize. - unroll: Maps loop names to their unroll factors. - """ - - root: str - tiles: dict[str, dict[str, int]] - splits: dict[str, dict[str, int | None]] = field(default_factory=dict) - interchange: list[str] = field(default_factory=list) - vectorize: list[str] = field(default_factory=list) - parallelize: list[str] = field(default_factory=list) - unroll: dict[str, int] = field(default_factory=dict) - - @property - def splits_to_sizes(self) -> dict[str, int | None]: - splits_to_sizes: dict[str, int | None] = {} - for axis in self.splits: - last_start = None - for loop_name, start in reversed(self.splits[axis].items()): - if last_start is not None and start is not None: - size_of_split = last_start - start - splits_to_sizes[loop_name] = size_of_split - else: - splits_to_sizes[loop_name] = None - last_start = start - return splits_to_sizes - - @property - def tiles_to_sizes(self) -> dict[str, int]: - tiles_to_sizes: dict[str, int] = {} - for tiles in self.tiles.values(): - for loop, size in tiles.items(): - tiles_to_sizes[loop] = size - return tiles_to_sizes - - -@dataclass -class LoopsDimsMapper: - """Maps loop names to their corresponding axis names. - - This class tracks the relationship between loop identifiers (from tiling - and splitting transformations) and the original dimension axes they - derive from. - - Attributes: - tiles_to_axis: Maps tile loop names to their parent axis. - splits_to_axis: Maps split loop names to their parent axis. - dims: List of original dimension names. - """ - - tiles_to_axis: dict[str, str] - splits_to_axis: dict[str, str] - dims: list[str] - - @property - def loops_to_axis(self) -> dict[str, str]: - loops_to_axis = ( - self.tiles_to_axis | self.splits_to_axis | dict(zip(self.dims, self.dims)) - ) - return loops_to_axis - - @staticmethod - def build_from_nodes(nodes: list[LoopNestNode]) -> LoopsDimsMapper: - tiles_to_axis = {} - splits_to_axis = {} - dims = set() - for node in nodes: - tiles_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(node.tiles)) - splits_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(node.splits)) - refined_loops = list(tiles_to_axis) + list(splits_to_axis) - for node in nodes: - dims.update( - [loop for loop in node.interchange if loop not in refined_loops] - ) - dims.update(tiles_to_axis.values()) - dims.update(splits_to_axis.values()) - return LoopsDimsMapper(tiles_to_axis, splits_to_axis, list(dims)) - - @staticmethod - def _get_subloops_to_axis(subloops: dict[str, dict[str, Any]]) -> dict[str, str]: - loop_to_axis: dict[str, str] = {} - for axis_name, subloops in subloops.items(): - for loop_name in subloops: - loop_to_axis[loop_name] = axis_name - return loop_to_axis - - -@dataclass -class LoopNest: - """Represents a complete loop nest structure for scheduling. - - A loop nest contains abstract dimensions and a tree of nodes representing - the schedule. Splits create child nodes, forming an explicit tree structure. - - Attributes: - abstract_dims: List of abstract dimension names for the loop nest. - root_node: The root node of the loop nest tree, or None if empty. - """ - - abstract_dims: list[str] - root_node: LoopNestNode | None = None - - @property - def empty(self) -> bool: - return self.root_node is None - - @property - def nodes(self) -> list[LoopNestNode]: - """Flatten the tree into a list of nodes. - - Returns nodes in depth-first order, with the root node first, - followed by children in the order they were created. - """ - if self.root_node is None: - return [] - return [self.root_node] + self.root_node.descendants_dfs() - - def build_root_node(self, root: str) -> LoopNestNode: - """Build and set the root node of the loop nest tree.""" - node = LoopNestNode(root=root, tiles={a: {} for a in self.abstract_dims}) - self.root_node = node - return node - - def check(self): - self._check_use_defined_dims() - self._check_vectorization_consistency() - self._check_tiling_consistency() - self._check_sizes() - - def _check_use_defined_dims(self): - mapper = LoopsDimsMapper.build_from_nodes(self.nodes) - for dim in self.abstract_dims: - if dim not in mapper.dims: - raise ScheduleValidationError(f"{dim} defined but never used") - - def _check_vectorization_consistency(self): - for sched in self.nodes: - vect_above = False - for loop_name in sched.interchange: - if loop_name in sched.vectorize: - vect_above = True - elif vect_above: - raise ScheduleValidationError( - f"Inner loop {loop_name} isn't vectorized but an outer one is." - ) - - def _check_tiling_consistency(self) -> None: - mapper = LoopsDimsMapper.build_from_nodes(self.nodes) - seen_axes: dict[str, int | None] = {} - for sched in self.nodes: - for loop_name in sched.interchange: - loop_name = mapper.splits_to_axis.get(loop_name, loop_name) - - if loop_name in mapper.dims: - seen_axes[loop_name] = None - elif loop_name in mapper.tiles_to_axis: - axis = mapper.tiles_to_axis[loop_name] - size = sched.tiles_to_sizes[loop_name] - if axis not in seen_axes: - raise ScheduleValidationError( - f""" - `{axis}#{size}`: {axis} has not been materialized yet. - """ - ) - seen_axes[axis] = sched.tiles[axis][loop_name] - - def _check_sizes(self): - mapper = LoopsDimsMapper.build_from_nodes(self.nodes) - current_size_of_split: dict[str, int | None] = {} - for sched in self.nodes: - current_size_of_tile: dict[str, int] = {} - for loop_name in sched.interchange: - axis = mapper.loops_to_axis[loop_name] - current_sizes = ( - {d: None for d in mapper.dims} - | current_size_of_split - | current_size_of_tile - ) - loop_size = None - if loop_name in mapper.dims: - if loop_name not in current_size_of_split: - current_size_of_split[loop_name] = None - elif loop_name in mapper.tiles_to_axis: - loop_size = sched.tiles[axis][loop_name] - LoopNest._must_be_smaller_routine( - new_size=loop_size, - current_sizes=current_sizes, - loop_name=loop_name, - axis=axis, - ) - current_size_of_tile[axis] = loop_size - elif ( - loop_name in mapper.splits_to_axis - and loop_name in sched.splits_to_sizes - ): - loop_size = sched.splits_to_sizes[loop_name] - LoopNest._must_be_smaller_routine( - new_size=loop_size, - current_sizes=current_sizes, - loop_name=loop_name, - axis=axis, - ) - current_size_of_split[loop_name] = loop_size - elif loop_name in current_size_of_split: - current_size_of_split[axis] = current_size_of_split[loop_name] - - if loop_name in sched.unroll: - unroll_factor = sched.unroll[loop_name] - if loop_size and loop_size < unroll_factor: - raise ScheduleValidationError( - f'`{{"unroll" = {unroll_factor}}}`: unroll factor should be smaller than {loop_size}.' - ) - - @staticmethod - def _must_be_smaller_routine( - new_size: int | None, - current_sizes: dict[str, int | None], - loop_name: str, - axis: str, - ): - old_size = current_sizes[axis] - if old_size is not None and new_size is not None and new_size > old_size: - raise ScheduleValidationError( - f""" - Inner loop {loop_name} on axis {axis} must be smaller than outer loop. - """ - ) +from .exceptions import ScheduleInterpretError +from .parsing import ( + ScheduleParser, + ScheduleSpec, + SplitDecl, + TileDecl, + AxisDecl, + Annotations, +) +from .loop_nest import LoopNestNode, LoopNest, SplitOrigin def descript_scheduler( diff --git a/src/xtc/schedules/exceptions.py b/src/xtc/schedules/exceptions.py new file mode 100644 index 00000000..84d51912 --- /dev/null +++ b/src/xtc/schedules/exceptions.py @@ -0,0 +1,23 @@ +# +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024-2026 The XTC Project Authors +# +"""Schedule-related exceptions.""" + + +class ScheduleParseError(RuntimeError): + """Raised when schedule parsing fails.""" + + pass + + +class ScheduleInterpretError(RuntimeError): + """Raised when schedule interpretation fails.""" + + pass + + +class ScheduleValidationError(RuntimeError): + """Raised when schedule validation fails.""" + + pass diff --git a/src/xtc/schedules/loop_nest.py b/src/xtc/schedules/loop_nest.py new file mode 100644 index 00000000..4826d9c8 --- /dev/null +++ b/src/xtc/schedules/loop_nest.py @@ -0,0 +1,310 @@ +# +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024-2026 The XTC Project Authors +# +from __future__ import annotations + +from typing import Any, Generic, TypeVar +from dataclasses import dataclass, field + +from .exceptions import ScheduleValidationError + + +@dataclass +class SplitOrigin: + """Describes how a node was created via a split from its parent. + + Attributes: + axis: The axis that was split to create this node. + start: The starting position of the split (inclusive), or None if unbounded. + end: The ending position of the split (exclusive), or None if unbounded. + """ + + axis: str + start: int | None + end: int | None + + +NodeT = TypeVar("NodeT", bound="Node") + + +@dataclass(kw_only=True) +class Node(Generic[NodeT]): + """Base class for tree nodes with parent/child relationships. + + Provides tree structure and traversal operations. Subclasses add + domain-specific data. + + Attributes: + parent: Reference to the parent node, or None for the root. + split_origin: Metadata describing how this node was created from + its parent via a split. None for the root node. + children: List of child nodes. + """ + + parent: NodeT | None = None + split_origin: SplitOrigin | None = None + children: list[NodeT] = field(default_factory=list) + + @property + def is_root(self) -> bool: + """Returns True if this node is the root (has no parent).""" + return self.parent is None + + def add_child(self, child: NodeT) -> None: + """Add a child node and set its parent to this node.""" + child.parent = self # type: ignore[assignment] + self.children.append(child) + + def ancestors(self) -> list[NodeT]: + """Return list of ancestors from parent to root.""" + result: list[NodeT] = [] + current = self.parent + while current is not None: + result.append(current) + current = current.parent + return result + + def descendants_dfs(self) -> list[NodeT]: + """Return all descendants in depth-first order.""" + result: list[NodeT] = [] + for child in self.children: + result.append(child) + result.extend(child.descendants_dfs()) + return result + + +@dataclass +class LoopNestNode(Node["LoopNestNode"]): + """Represents a node in the loop nest tree with its transformations. + + Describes the loops attached to a single root and + contains all the scheduling transformations applied to these loops. + + Attributes: + root: Identifier of the node (either the base operation or + the content of a split). + tiles: Tiling configuration per axis. Maps axis names to dicts of + tile loop names and their sizes. + splits: Split configuration per axis. Maps axis names to dicts of + split loop names and their starting positions. + interchange: Ordered list of loop names defining the loop order. + vectorize: List of loops to vectorize. + parallelize: List of loops to parallelize. + unroll: Maps loop names to their unroll factors. + """ + + root: str + tiles: dict[str, dict[str, int]] + splits: dict[str, dict[str, int]] = field(default_factory=dict) + interchange: list[str] = field(default_factory=list) + vectorize: list[str] = field(default_factory=list) + parallelize: list[str] = field(default_factory=list) + unroll: dict[str, int] = field(default_factory=dict) + + @property + def splits_to_sizes(self) -> dict[str, int]: + splits_to_sizes: dict[str, int] = {} + for axis in self.splits: + last_start = None + for loop_name, start in reversed(self.splits[axis].items()): + if last_start is not None: + size_of_split = last_start - start + splits_to_sizes[loop_name] = size_of_split + last_start = start + return splits_to_sizes + + @property + def tiles_to_sizes(self) -> dict[str, int]: + tiles_to_sizes: dict[str, int] = {} + for tiles in self.tiles.values(): + for loop, size in tiles.items(): + tiles_to_sizes[loop] = size + return tiles_to_sizes + + +@dataclass +class LoopsDimsMapper: + """Maps loop names to their corresponding axis names. + + This class tracks the relationship between loop identifiers (from tiling + and splitting transformations) and the original dimension axes they + derive from. + + Attributes: + tiles_to_axis: Maps tile loop names to their parent axis. + splits_to_axis: Maps split loop names to their parent axis. + dims: List of original dimension names. + """ + + tiles_to_axis: dict[str, str] + splits_to_axis: dict[str, str] + dims: list[str] + + @property + def loops_to_axis(self) -> dict[str, str]: + loops_to_axis = ( + self.tiles_to_axis | self.splits_to_axis | dict(zip(self.dims, self.dims)) + ) + return loops_to_axis + + @staticmethod + def build_from_nodes(nodes: list[LoopNestNode]) -> LoopsDimsMapper: + tiles_to_axis = {} + splits_to_axis = {} + dims = set() + for node in nodes: + tiles_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(node.tiles)) + splits_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(node.splits)) + refined_loops = list(tiles_to_axis) + list(splits_to_axis) + for node in nodes: + dims.update( + [loop for loop in node.interchange if loop not in refined_loops] + ) + dims.update(tiles_to_axis.values()) + dims.update(splits_to_axis.values()) + return LoopsDimsMapper(tiles_to_axis, splits_to_axis, list(dims)) + + @staticmethod + def _get_subloops_to_axis(subloops: dict[str, dict[str, Any]]) -> dict[str, str]: + loop_to_axis: dict[str, str] = {} + for axis_name, subloops in subloops.items(): + for loop_name in subloops: + loop_to_axis[loop_name] = axis_name + return loop_to_axis + + +@dataclass +class LoopNest: + """Represents a complete loop nest structure for scheduling. + + A loop nest contains abstract dimensions and a tree of nodes representing + the schedule. Splits create child nodes, forming an explicit tree structure. + + Attributes: + abstract_dims: List of abstract dimension names for the loop nest. + root_node: The root node of the loop nest tree, or None if empty. + """ + + abstract_dims: list[str] + root_node: LoopNestNode | None = None + + @property + def empty(self) -> bool: + return self.root_node is None + + @property + def nodes(self) -> list[LoopNestNode]: + """Flatten the tree into a list of nodes. + + Returns nodes in depth-first order, with the root node first, + followed by children in the order they were created. + """ + if self.root_node is None: + return [] + return [self.root_node] + self.root_node.descendants_dfs() + + def build_root_node(self, root: str) -> LoopNestNode: + """Build and set the root node of the loop nest tree.""" + node = LoopNestNode(root=root, tiles={a: {} for a in self.abstract_dims}) + self.root_node = node + return node + + def check(self): + self._check_use_defined_dims() + self._check_vectorization_consistency() + self._check_tiling_consistency() + self._check_sizes() + + def _check_use_defined_dims(self): + mapper = LoopsDimsMapper.build_from_nodes(self.nodes) + for dim in self.abstract_dims: + if dim not in mapper.dims: + raise ScheduleValidationError(f"{dim} defined but never used") + + def _check_vectorization_consistency(self): + for sched in self.nodes: + vect_above = False + for loop_name in sched.interchange: + if loop_name in sched.vectorize: + vect_above = True + elif vect_above: + raise ScheduleValidationError( + f"Inner loop {loop_name} isn't vectorized but an outer one is." + ) + + def _check_tiling_consistency(self) -> None: + mapper = LoopsDimsMapper.build_from_nodes(self.nodes) + seen_axes: dict[str, int | None] = {} + for sched in self.nodes: + for loop_name in sched.interchange: + if loop_name in mapper.dims: + seen_axes[loop_name] = None + elif loop_name in mapper.tiles_to_axis: + axis = mapper.tiles_to_axis[loop_name] + size = sched.tiles_to_sizes[loop_name] + if axis not in seen_axes: + raise ScheduleValidationError( + f""" + `{axis}#{size}`: {axis} has not been materialized yet. + """ + ) + seen_axes[axis] = sched.tiles[axis][loop_name] + + def _check_sizes(self): + mapper = LoopsDimsMapper.build_from_nodes(self.nodes) + current_size_of_split: dict[str, int | None] = {} + for sched in self.nodes: + current_size_of_tile: dict[str, int] = {} + + for loop_name in sched.interchange: + axis = mapper.loops_to_axis[loop_name] + current_sizes = ( + {d: None for d in mapper.dims} + | current_size_of_split + | current_size_of_tile + ) + loop_size = None + if loop_name in mapper.dims: + if loop_name not in current_size_of_split: + current_size_of_split[loop_name] = None + elif loop_name in mapper.tiles_to_axis: + loop_size = sched.tiles[axis][loop_name] + LoopNest._must_be_smaller_routine( + new_size=loop_size, + current_sizes=current_sizes, + loop_name=loop_name, + axis=axis, + ) + current_size_of_tile[axis] = loop_size + elif ( + loop_name in mapper.splits_to_axis + and loop_name in sched.splits_to_sizes + ): + loop_size = sched.splits_to_sizes[loop_name] + LoopNest._must_be_smaller_routine( + new_size=loop_size, + current_sizes=current_sizes, + loop_name=loop_name, + axis=axis, + ) + current_size_of_split[axis] = loop_size + + if loop_name in sched.unroll: + unroll_factor = sched.unroll[loop_name] + if loop_size and loop_size < unroll_factor: + raise ScheduleValidationError( + f'`{{"unroll" = {unroll_factor}}}`: unroll factor should be smaller than {loop_size}.' + ) + + @staticmethod + def _must_be_smaller_routine( + new_size: int, current_sizes: dict[str, int | None], loop_name: str, axis: str + ): + old_size = current_sizes[axis] + if old_size is not None and new_size > old_size: + raise ScheduleValidationError( + f""" + Inner loop {loop_name} on axis {axis} must be smaller than outer loop. + """ + ) diff --git a/src/xtc/schedules/parsing.py b/src/xtc/schedules/parsing.py new file mode 100644 index 00000000..419feede --- /dev/null +++ b/src/xtc/schedules/parsing.py @@ -0,0 +1,196 @@ +# +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024-2026 The XTC Project Authors +# +from __future__ import annotations + +from typing import Any +from dataclasses import dataclass +import re +from typing_extensions import override + +from .exceptions import ScheduleParseError + + +@dataclass(frozen=True) +class Annotations: + """AST Type : annotations that can be applied to a loop. + + Attributes: + unroll_factor: The unroll factor. None means "unroll fully" (use loop size). + Only meaningful when unroll_specified is True. + unroll_specified: True if unroll was explicitly requested. + vectorize: True if vectorization was requested. + parallelize: True if parallelization was requested. + """ + + unroll_factor: int | None = None + unroll_specified: bool = False + vectorize: bool = False + parallelize: bool = False + + +@dataclass(frozen=True) +class SplitDecl: + """AST Type: a split declaration like 'axis[start:end]'.""" + + axis: str + start: int | None + end: int | None + body: ScheduleSpec + + @override + def __str__(self) -> str: + start_str = "" if self.start is None else str(self.start) + end_str = "" if self.end is None else str(self.end) + decl = f"{self.axis}[{start_str}:{end_str}]" + return decl + + +@dataclass(frozen=True) +class TileDecl: + """AST Type: a tile declaration like 'axis#size'.""" + + axis: str + size: int + annotations: Annotations + + @override + def __str__(self) -> str: + return f"{self.axis}#{self.size}" + + +@dataclass(frozen=True) +class AxisDecl: + """AST Type: a direct axis reference.""" + + axis: str + annotations: Annotations + + +ScheduleItem = SplitDecl | TileDecl | AxisDecl + + +@dataclass(frozen=True) +class ScheduleSpec: + """AST Type: the complete parsed schedule specification.""" + + items: tuple[ScheduleItem, ...] + + +class ScheduleParser: + """Parses a dict-based schedule specification into an AST.""" + + _SPLIT_PATTERN = re.compile(r"^(.*)\[(-\d+|\d*)?:(-\d+|\d*)?\]$") + + def parse(self, spec: dict[str, Any]) -> ScheduleSpec: + """Parse a schedule specification dict into an AST.""" + items: list[ScheduleItem] = [] + + for declaration, value in spec.items(): + item = self._parse_declaration(declaration, value) + items.append(item) + + return ScheduleSpec(items=tuple(items)) + + def _parse_declaration(self, declaration: str, value: Any) -> ScheduleItem: + """Parse a single declaration into a ScheduleItem.""" + assert isinstance(value, dict) + # Try split declaration first (e.g., "axis[0:10]") + if ":" in declaration: + return self._parse_split(declaration, value) + + # Try tile declaration (e.g., "axis#32") + if "#" in declaration: + return self._parse_tile(declaration, value) + + # Must be a direct axis reference + return self._parse_axis_ref(declaration, value) + + def _parse_split(self, declaration: str, value: dict) -> SplitDecl: + """Parse a split declaration like 'axis[start:end]'.""" + axis_name, start, end = self._parse_split_syntax(declaration) + + body = self.parse(value) + return SplitDecl(axis=axis_name, start=start, end=end, body=body) + + def _parse_tile(self, declaration: str, value: dict) -> TileDecl: + """Parse a tile declaration like 'axis#size'.""" + parts = declaration.split("#") + if len(parts) != 2: + raise ScheduleParseError( + f"`{declaration}`: invalid tile syntax, expected 'axis#size'" + ) + + axis_name, size_str = parts + + try: + size = int(size_str) + except ValueError: + raise ScheduleParseError(f"`{declaration}`: {size_str} is not an integer.") + + annotations = self._parse_annotations(value, declaration) + return TileDecl(axis=axis_name, size=size, annotations=annotations) + + def _parse_axis_ref(self, declaration: str, value: dict) -> AxisDecl: + """Parse a direct axis reference.""" + + annotations = self._parse_annotations(value, declaration) + return AxisDecl(axis=declaration, annotations=annotations) + + def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations: + """Parse annotation dict into Annotations object.""" + + unroll_factor: int | None = None + unroll_specified = False + vectorize = False + parallelize = False + + for key, param in value.items(): + if key == "unroll": + if param is True: + unroll_factor = None + unroll_specified = True + elif param is False: + pass + elif isinstance(param, int): + unroll_factor = param + unroll_specified = True + else: + raise ScheduleParseError( + f'`{{"unroll" = {param}}}`: unroll parameter should be True, False, or an integer.' + ) + elif key == "vectorize": + if not isinstance(param, bool): + raise ScheduleParseError( + f'`{{"vectorize" = {param}}}`: parameterized vectorization not implemented.' + ) + vectorize = param + elif key == "parallelize": + if not isinstance(param, bool): + raise ScheduleParseError( + f'`{{"parallelize" = {param}}}`: parameterized parallelization not implemented.' + ) + parallelize = param + else: + raise ScheduleParseError(f"Unknown annotation on {context}: {key}") + + return Annotations( + unroll_factor=unroll_factor, + unroll_specified=unroll_specified, + vectorize=vectorize, + parallelize=parallelize, + ) + + def _parse_split_syntax( + self, declaration: str + ) -> tuple[str, int | None, int | None]: + """Parse the syntax of a split declaration.""" + match = self._SPLIT_PATTERN.match(declaration) + if not match: + raise ScheduleParseError(f"Wrong format {declaration}") + + prefix, x_str, y_str = match.groups() + x = int(x_str) if x_str else None + y = int(y_str) if y_str else None + return prefix, x, y From d82cbf7c42beb99c496a64625410c40e0f9371b4 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 2 Feb 2026 01:45:14 +0100 Subject: [PATCH 3/8] descript: pretty printer --- src/xtc/schedules/loop_nest.py | 121 ++++++++++++++++++ .../schedules/test_descript_pretty_print.py | 89 +++++++++++++ 2 files changed, 210 insertions(+) create mode 100644 tests/filecheck/schedules/test_descript_pretty_print.py diff --git a/src/xtc/schedules/loop_nest.py b/src/xtc/schedules/loop_nest.py index 4826d9c8..3bba88ad 100644 --- a/src/xtc/schedules/loop_nest.py +++ b/src/xtc/schedules/loop_nest.py @@ -122,6 +122,127 @@ def tiles_to_sizes(self) -> dict[str, int]: tiles_to_sizes[loop] = size return tiles_to_sizes + def pretty_print(self, indent: int = 0) -> str: + """Return a human-readable representation of the loop nest. + + The output format uses a compact notation: + - `loop X` for a regular loop over dimension X + - `tile(X, N)` for a tile of size N on dimension X + - `split(X, start, end)` for a split segment on dimension X + - `// annotation` for vectorized, parallelized, unroll(N) + - `...` for the innermost body + + Example output: + loop i // parallelized + loop k + loop j + tile(j, 16) // vectorized + ... + + Args: + indent: The initial indentation level (number of spaces). + + Returns: + A multi-line string representing the loop nest structure. + """ + lines: list[str] = [] + + # Build mapping from tile loop name to (axis, size) + tiles_info: dict[str, tuple[str, int]] = {} + for axis, tile_loops in self.tiles.items(): + for loop_name, size in tile_loops.items(): + tiles_info[loop_name] = (axis, size) + + # Build mapping from split loop name to (axis, start, end) + splits_info: dict[str, tuple[str, int, int | None]] = {} + for axis, axis_splits in self.splits.items(): + split_starts = list(axis_splits.values()) + for i, (loop_name, start) in enumerate(axis_splits.items()): + end = split_starts[i + 1] if i + 1 < len(split_starts) else None + splits_info[loop_name] = (axis, start, end) + + # Map split loop names to their child nodes + split_to_child: dict[str, LoopNestNode] = {} + for child in self.children: + if child.split_origin is not None: + axis = child.split_origin.axis + if axis in self.splits: + for loop_name, start in self.splits[axis].items(): + if start == child.split_origin.start: + split_to_child[loop_name] = child + break + + # Group splits by axis for same-level printing + axis_to_splits: dict[str, list[str]] = {} + for loop_name, (axis, _, _) in splits_info.items(): + if loop_name in self.interchange: + if axis not in axis_to_splits: + axis_to_splits[axis] = [] + axis_to_splits[axis].append(loop_name) + + processed_splits: set[str] = set() + current_indent = indent + + for loop_name in self.interchange: + # Skip already processed splits + if loop_name in processed_splits: + continue + + # Check if this is a split + if loop_name in splits_info: + axis, _, _ = splits_info[loop_name] + axis_split_names = axis_to_splits.get(axis, [loop_name]) + processed_splits.update(axis_split_names) + + # Print all splits of this axis at the same level + for split_name in axis_split_names: + split_axis, start, end = splits_info[split_name] + end_str = str(end) if end is not None else "..." + line = f"split({split_axis}, {start}, {end_str})" + line = self._add_annotations(line, split_name) + lines.append(" " * current_indent + line) + + # Use child's pretty_print if available + if split_name in split_to_child: + child_output = split_to_child[split_name].pretty_print( + current_indent + 2 + ) + lines.append(child_output) + else: + lines.append(" " * (current_indent + 2) + "...") + else: + # Regular loop (tile or base dimension) + if loop_name in tiles_info: + axis, size = tiles_info[loop_name] + line = f"tile({axis}, {size})" + else: + # Extract basename (last part after /) + basename = loop_name.split("/")[-1] + line = f"loop {basename}" + + line = self._add_annotations(line, loop_name) + lines.append(" " * current_indent + line) + current_indent += 2 + + # Add body if no splits were encountered + if not processed_splits: + lines.append(" " * current_indent + "...") + + return "\n".join(lines) + + def _add_annotations(self, line: str, loop_name: str) -> str: + """Add annotations (parallelized, vectorized, unroll) to a loop line.""" + annotations: list[str] = [] + if loop_name in self.parallelize: + annotations.append("parallelized") + if loop_name in self.vectorize: + annotations.append("vectorized") + if loop_name in self.unroll: + annotations.append(f"unroll({self.unroll[loop_name]})") + if annotations: + line += " // " + ", ".join(annotations) + return line + @dataclass class LoopsDimsMapper: diff --git a/tests/filecheck/schedules/test_descript_pretty_print.py b/tests/filecheck/schedules/test_descript_pretty_print.py new file mode 100644 index 00000000..11c9af9c --- /dev/null +++ b/tests/filecheck/schedules/test_descript_pretty_print.py @@ -0,0 +1,89 @@ +# RUN: python %s --simple 2>&1 | filecheck %s --check-prefix=CHECK-SIMPLE +# RUN: python %s --tiled 2>&1 | filecheck %s --check-prefix=CHECK-TILED +# RUN: python %s --vectorized 2>&1 | filecheck %s --check-prefix=CHECK-VECTORIZED +# RUN: python %s --full 2>&1 | filecheck %s --check-prefix=CHECK-FULL +# RUN: python %s --split 2>&1 | filecheck %s --check-prefix=CHECK-SPLIT + +import sys +from xtc.schedules.parsing import ScheduleParser +from xtc.schedules.descript import ScheduleInterpreter + +parser = ScheduleParser() +abstract_axis = ["i", "j", "k"] +interpreter = ScheduleInterpreter(abstract_axis) + +if "--simple" in sys.argv: + spec = {"i": {}, "k": {}, "j": {}} + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +elif "--tiled" in sys.argv: + spec = {"i": {}, "k": {}, "j": {}, "j#16": {}} + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +elif "--vectorized" in sys.argv: + spec = {"i": {}, "k": {}, "j": {}, "j#16": {"vectorize": True}} + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +elif "--full" in sys.argv: + spec = { + "i": {"parallelize": True}, + "k": {}, + "j": {}, + "j#32": {}, + "j#16": {"vectorize": True, "unroll": 4}, + } + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +elif "--split" in sys.argv: + spec = { + "i": {}, + "j[:128]": {"k": {}, "k#32": {}}, + "j[128:]": {"k": {}, "k#16": {"vectorize": True}}, + } + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + +# CHECK-SIMPLE: loop i +# CHECK-SIMPLE-NEXT: loop k +# CHECK-SIMPLE-NEXT: loop j +# CHECK-SIMPLE-NEXT: ... + +# CHECK-TILED: loop i +# CHECK-TILED-NEXT: loop k +# CHECK-TILED-NEXT: loop j +# CHECK-TILED-NEXT: tile(j, 16) +# CHECK-TILED-NEXT: ... + +# CHECK-VECTORIZED: loop i +# CHECK-VECTORIZED-NEXT: loop k +# CHECK-VECTORIZED-NEXT: loop j +# CHECK-VECTORIZED-NEXT: tile(j, 16) // vectorized +# CHECK-VECTORIZED-NEXT: ... + +# CHECK-FULL: loop i // parallelized +# CHECK-FULL-NEXT: loop k +# CHECK-FULL-NEXT: loop j +# CHECK-FULL-NEXT: tile(j, 32) +# CHECK-FULL-NEXT: tile(j, 16) // vectorized, unroll(4) +# CHECK-FULL-NEXT: ... + +# CHECK-SPLIT: loop i +# CHECK-SPLIT-NEXT: split(j, 0, 128) +# CHECK-SPLIT-NEXT: loop j +# CHECK-SPLIT-NEXT: loop k +# CHECK-SPLIT-NEXT: tile(k, 32) +# CHECK-SPLIT-NEXT: ... +# CHECK-SPLIT-NEXT: split(j, 128, ...) +# CHECK-SPLIT-NEXT: loop j +# CHECK-SPLIT-NEXT: loop k +# CHECK-SPLIT-NEXT: tile(k, 16) // vectorized +# CHECK-SPLIT-NEXT: ... From 1601ae7d7a3da9dc820d7fa091c2c0a589503013 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 2 Feb 2026 01:53:27 +0100 Subject: [PATCH 4/8] descript: build LoopNest from Scheduler --- src/xtc/backends/jir/JIRScheduler.py | 30 +++++++ src/xtc/backends/mlir/MlirScheduler.py | 53 +++++++++++++ src/xtc/backends/tvm/TVMScheduler.py | 26 +++++++ src/xtc/itf/schd/scheduler.py | 17 ++++ .../filecheck/schedules/test_get_descript.py | 78 +++++++++++++++++++ 5 files changed, 204 insertions(+) create mode 100644 tests/filecheck/schedules/test_get_descript.py diff --git a/src/xtc/backends/jir/JIRScheduler.py b/src/xtc/backends/jir/JIRScheduler.py index ac40e930..dde8312c 100644 --- a/src/xtc/backends/jir/JIRScheduler.py +++ b/src/xtc/backends/jir/JIRScheduler.py @@ -7,6 +7,7 @@ import xtc.itf as itf from xtc.itf.schd.scheduler import DEFAULT_ROOT +from xtc.schedules.loop_nest import LoopNest import xtc.backends.jir as backend __all__ = [ @@ -347,3 +348,32 @@ def distributed_buffer_at( def get_schedule_str(self) -> str: return str(JIRSchedule(scheduler=self)) + + @override + def get_loop_nest(self) -> LoopNest: + transformer = self._transformer + dims = list(transformer.dims.keys()) + + loop_nest = LoopNest(abstract_dims=dims) + root_node = loop_nest.build_root_node(self._backend.payload_name) + + # Build tiles mapping + for axis, axis_tiles in transformer.tiles.items(): + for tile_name, size in axis_tiles.items(): + root_node.tiles[axis][tile_name] = size + + # Build interchange + root_node.interchange = ( + list(transformer.order) if transformer.order else dims[:] + ) + + # Build vectorization list + root_node.vectorize = list(transformer.vectorized) + + # Build parallelization list + root_node.parallelize = list(transformer.parallelized) + + # Build unroll mapping + root_node.unroll = dict(transformer.unrolled) + + return loop_nest diff --git a/src/xtc/backends/mlir/MlirScheduler.py b/src/xtc/backends/mlir/MlirScheduler.py index 605477b3..e8e2b14b 100644 --- a/src/xtc/backends/mlir/MlirScheduler.py +++ b/src/xtc/backends/mlir/MlirScheduler.py @@ -5,6 +5,7 @@ from typing_extensions import override from xtc.itf.schd.scheduler import DEFAULT_ROOT +from xtc.schedules.loop_nest import LoopNest, LoopNestNode, SplitOrigin import xtc.itf as itf import xtc.backends.mlir as backend @@ -203,6 +204,58 @@ def distributed_buffer_at( axis, input_idx, memory_axes, root=root ) + @override + def get_loop_nest(self) -> LoopNest: + node_sched = self._current_scheduler + dims = node_sched.dims[:] + + loop_nest = LoopNest(abstract_dims=dims) + root_node = loop_nest.build_root_node(node_sched.node_name) + + # Collect all split names and their info (axis, start, end) + split_info: dict[str, tuple[str, int, int | None]] = {} + for axis, axis_splits in node_sched.splits.items(): + split_starts = list(axis_splits.values()) + for i, (split_name, start) in enumerate(axis_splits.items()): + end = split_starts[i + 1] if i + 1 < len(split_starts) else None + split_info[split_name] = (axis, start, end) + + def populate_node(node: LoopNestNode, perm: list[str]) -> None: + """Populate node with data for loops in its permutation.""" + node.interchange = list(perm) + perm_set = set(perm) + for axis, axis_tiles in node_sched.tiles.items(): + for tile_name, size in axis_tiles.items(): + if tile_name in perm_set: + if axis not in node.tiles: + node.tiles[axis] = {} + node.tiles[axis][tile_name] = size + node.vectorize = [v for v in node_sched.vectorization if v in perm_set] + node.parallelize = [p for p in node_sched.parallelization if p in perm_set] + node.unroll = { + k: v for k, v in node_sched.unrolling.items() if k in perm_set + } + + # Process each root in permutation + for root, perm in node_sched.permutation.items(): + if root in split_info: + # This root is a split - create child node + axis, start, end = split_info[root] + child = LoopNestNode( + root=root, + tiles={d: {} for d in dims}, + split_origin=SplitOrigin(axis=axis, start=start, end=end), + ) + populate_node(child, perm) + root_node.add_child(child) + else: + # This is the main root + populate_node(root_node, perm) + for axis, axis_splits in node_sched.splits.items(): + root_node.splits[axis] = dict(axis_splits) + + return loop_nest + class MlirSchedule(itf.schd.Schedule): def __init__( diff --git a/src/xtc/backends/tvm/TVMScheduler.py b/src/xtc/backends/tvm/TVMScheduler.py index 8e51249b..e439e850 100644 --- a/src/xtc/backends/tvm/TVMScheduler.py +++ b/src/xtc/backends/tvm/TVMScheduler.py @@ -12,6 +12,7 @@ from xtc.utils.math import pow2divisor from xtc.itf.schd.scheduler import DEFAULT_ROOT +from xtc.schedules.loop_nest import LoopNest import xtc.backends.tvm as backend import xtc.itf as itf @@ -511,6 +512,31 @@ def _get_plain_schedule(self) -> TVMPlainSchedule: def __str__(self) -> str: return str(self._get_plain_schedule()) + @override + def get_loop_nest(self) -> LoopNest: + loop_nest = LoopNest(abstract_dims=self.dims[:]) + root_node = loop_nest.build_root_node(self._op.name or "op") + + # Build tiles mapping + for axis, axis_tiles in self.tiles.items(): + for tile_name, size in axis_tiles.items(): + if tile_name != axis: + root_node.tiles[axis][tile_name] = size + + # Build interchange + root_node.interchange = list(self.permutation) + + # Build vectorization list + root_node.vectorize = list(self.vectorization) + + # Build parallelization list + root_node.parallelize = list(self.parallelization) + + # Build unroll mapping + root_node.unroll = dict(self.unrolling) + + return loop_nest + class TVMSchedule(itf.schd.Schedule): def __init__(self, scheduler: "TVMScheduler", schedule_impl: ScheduleImpl) -> None: diff --git a/src/xtc/itf/schd/scheduler.py b/src/xtc/itf/schd/scheduler.py index 3ea1febb..f2abc465 100644 --- a/src/xtc/itf/schd/scheduler.py +++ b/src/xtc/itf/schd/scheduler.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from .schedule import Schedule import xtc.itf +from xtc.schedules.loop_nest import LoopNest DEFAULT_ROOT = "." ROOT_SEP = "/" @@ -291,3 +292,19 @@ def distributed_buffer_at( root: the parent split (or the operator's absolute root) """ ... + + @abstractmethod + def get_loop_nest(self) -> LoopNest: + """Return a LoopNest representation of the current schedule. + + This method constructs a LoopNest object that describes the loop + structure resulting from the scheduling transformations applied + so far. The LoopNest can be used for visualization (via pretty_print) + or further analysis. + + Returns: + LoopNest: A tree structure representing the scheduled loop nest, + including tiles, splits, interchange order, and annotations + (vectorization, parallelization, unrolling). + """ + ... diff --git a/tests/filecheck/schedules/test_get_descript.py b/tests/filecheck/schedules/test_get_descript.py new file mode 100644 index 00000000..50d1e32c --- /dev/null +++ b/tests/filecheck/schedules/test_get_descript.py @@ -0,0 +1,78 @@ +# RUN: python %s --mlir 2>&1 | filecheck %s --check-prefix=CHECK-MLIR +# RUN: python %s --tvm 2>&1 | filecheck %s --check-prefix=CHECK-TVM +# REQUIRES: module_tvm + +import sys +import xtc.graphs.xtc.op as O + +I, J, K, dtype = 4, 32, 512, "float32" +a = O.tensor((I, K), dtype, name="A") +b = O.tensor((K, J), dtype, name="B") + +with O.graph(name="matmul") as gb: + O.matmul(a, b, name="C") + +graph = gb.graph + +if "--mlir" in sys.argv: + from xtc.backends.mlir import Backend + +elif "--tvm" in sys.argv: + from xtc.backends.tvm import Backend + +else: + assert False + +impl = Backend(graph) +sch = impl.get_scheduler() +sch.set_dims(["I", "J", "K"]) +sch.tile("I", {"I0": 2}) +sch.tile("J", {"J0": 16}) +sch.interchange(["K", "I", "J", "I0", "J0"]) +sch.unroll({"I0": 2}) +sch.vectorize(["J0"]) + +loop_nest = sch.get_loop_nest() +print(loop_nest.root_node.pretty_print()) + +# CHECK-MLIR: loop K +# CHECK-MLIR-NEXT: loop I +# CHECK-MLIR-NEXT: loop J +# CHECK-MLIR-NEXT: tile(I, 2) // unroll(2) +# CHECK-MLIR-NEXT: tile(J, 16) // vectorized +# CHECK-MLIR-NEXT: ... + +# CHECK-TVM: loop K +# CHECK-TVM-NEXT: loop I +# CHECK-TVM-NEXT: loop J +# CHECK-TVM-NEXT: tile(I, 2) // unroll(2) +# CHECK-TVM-NEXT: tile(J, 16) // vectorized +# CHECK-TVM-NEXT: ... + +# Test with split (MLIR only - TVM does not support split) +if "--mlir" in sys.argv: + print("---") + + impl2 = Backend(graph) + sch2 = impl2.get_scheduler() + sch2.set_dims(["I", "J", "K"]) + sch2.split("I", {"I_lo": 0, "I_hi": 2}) + sch2.tile("J", {"J0": 16}, root="./I_lo") + sch2.tile("J", {"J0": 16}, root="./I_hi") + sch2.interchange(["K", "I_lo", "I_hi"]) + sch2.interchange(["J", "J0"], root="./I_lo") + sch2.interchange(["J", "J0"], root="./I_hi") + + loop_nest2 = sch2.get_loop_nest() + print(loop_nest2.root_node.pretty_print()) + +# CHECK-MLIR: --- +# CHECK-MLIR-NEXT: loop K +# CHECK-MLIR-NEXT: split(I, 0, 2) +# CHECK-MLIR-NEXT: loop J +# CHECK-MLIR-NEXT: tile(J, 16) +# CHECK-MLIR-NEXT: ... +# CHECK-MLIR-NEXT: split(I, 2, ...) +# CHECK-MLIR-NEXT: loop J +# CHECK-MLIR-NEXT: tile(J, 16) +# CHECK-MLIR-NEXT: ... From 697e397b4ad5f3fe3a3ba5cc06c44691523e192d Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 2 Feb 2026 06:37:54 +0100 Subject: [PATCH 5/8] descript: encapsulate info on loops --- src/xtc/backends/mlir/MlirScheduler.py | 19 ++- src/xtc/schedules/loop_nest.py | 185 +++++++++++++------------ 2 files changed, 106 insertions(+), 98 deletions(-) diff --git a/src/xtc/backends/mlir/MlirScheduler.py b/src/xtc/backends/mlir/MlirScheduler.py index e8e2b14b..cdf382b5 100644 --- a/src/xtc/backends/mlir/MlirScheduler.py +++ b/src/xtc/backends/mlir/MlirScheduler.py @@ -5,7 +5,7 @@ from typing_extensions import override from xtc.itf.schd.scheduler import DEFAULT_ROOT -from xtc.schedules.loop_nest import LoopNest, LoopNestNode, SplitOrigin +from xtc.schedules.loop_nest import LoopNest, LoopNestNode, LoopInfo, SplitOrigin import xtc.itf as itf import xtc.backends.mlir as backend @@ -212,13 +212,12 @@ def get_loop_nest(self) -> LoopNest: loop_nest = LoopNest(abstract_dims=dims) root_node = loop_nest.build_root_node(node_sched.node_name) - # Collect all split names and their info (axis, start, end) - split_info: dict[str, tuple[str, int, int | None]] = {} + # Assign splits to root_node first for axis, axis_splits in node_sched.splits.items(): - split_starts = list(axis_splits.values()) - for i, (split_name, start) in enumerate(axis_splits.items()): - end = split_starts[i + 1] if i + 1 < len(split_starts) else None - split_info[split_name] = (axis, start, end) + root_node.splits[axis] = dict(axis_splits) + + # Build mapper to get splits_info + mapper = LoopInfo.build_from_node(root_node) def populate_node(node: LoopNestNode, perm: list[str]) -> None: """Populate node with data for loops in its permutation.""" @@ -238,9 +237,9 @@ def populate_node(node: LoopNestNode, perm: list[str]) -> None: # Process each root in permutation for root, perm in node_sched.permutation.items(): - if root in split_info: + if root in mapper.splits_info: # This root is a split - create child node - axis, start, end = split_info[root] + axis, start, end = mapper.splits_info[root] child = LoopNestNode( root=root, tiles={d: {} for d in dims}, @@ -251,8 +250,6 @@ def populate_node(node: LoopNestNode, perm: list[str]) -> None: else: # This is the main root populate_node(root_node, perm) - for axis, axis_splits in node_sched.splits.items(): - root_node.splits[axis] = dict(axis_splits) return loop_nest diff --git a/src/xtc/schedules/loop_nest.py b/src/xtc/schedules/loop_nest.py index 3bba88ad..d8a05e58 100644 --- a/src/xtc/schedules/loop_nest.py +++ b/src/xtc/schedules/loop_nest.py @@ -4,7 +4,7 @@ # from __future__ import annotations -from typing import Any, Generic, TypeVar +from typing import Generic, TypeVar from dataclasses import dataclass, field from .exceptions import ScheduleValidationError @@ -102,26 +102,6 @@ class LoopNestNode(Node["LoopNestNode"]): parallelize: list[str] = field(default_factory=list) unroll: dict[str, int] = field(default_factory=dict) - @property - def splits_to_sizes(self) -> dict[str, int]: - splits_to_sizes: dict[str, int] = {} - for axis in self.splits: - last_start = None - for loop_name, start in reversed(self.splits[axis].items()): - if last_start is not None: - size_of_split = last_start - start - splits_to_sizes[loop_name] = size_of_split - last_start = start - return splits_to_sizes - - @property - def tiles_to_sizes(self) -> dict[str, int]: - tiles_to_sizes: dict[str, int] = {} - for tiles in self.tiles.values(): - for loop, size in tiles.items(): - tiles_to_sizes[loop] = size - return tiles_to_sizes - def pretty_print(self, indent: int = 0) -> str: """Return a human-readable representation of the loop nest. @@ -147,19 +127,9 @@ def pretty_print(self, indent: int = 0) -> str: """ lines: list[str] = [] - # Build mapping from tile loop name to (axis, size) - tiles_info: dict[str, tuple[str, int]] = {} - for axis, tile_loops in self.tiles.items(): - for loop_name, size in tile_loops.items(): - tiles_info[loop_name] = (axis, size) - - # Build mapping from split loop name to (axis, start, end) - splits_info: dict[str, tuple[str, int, int | None]] = {} - for axis, axis_splits in self.splits.items(): - split_starts = list(axis_splits.values()) - for i, (loop_name, start) in enumerate(axis_splits.items()): - end = split_starts[i + 1] if i + 1 < len(split_starts) else None - splits_info[loop_name] = (axis, start, end) + mapper = LoopInfo.build_from_node(self) + tiles_info = mapper.tiles_info + splits_info = mapper.splits_info # Map split loop names to their child nodes split_to_child: dict[str, LoopNestNode] = {} @@ -245,54 +215,85 @@ def _add_annotations(self, line: str, loop_name: str) -> str: @dataclass -class LoopsDimsMapper: - """Maps loop names to their corresponding axis names. +class LoopInfo: + """Maps loop names to their corresponding axis names and metadata. This class tracks the relationship between loop identifiers (from tiling and splitting transformations) and the original dimension axes they - derive from. + derive from, along with their sizes and positions. Attributes: - tiles_to_axis: Maps tile loop names to their parent axis. - splits_to_axis: Maps split loop names to their parent axis. dims: List of original dimension names. + tiles_info: Maps tile loop names to (axis, size) tuples. + splits_info: Maps split loop names to (axis, start, end) tuples. """ - tiles_to_axis: dict[str, str] - splits_to_axis: dict[str, str] dims: list[str] + tiles_info: dict[str, tuple[str, int]] = field(default_factory=dict) + splits_info: dict[str, tuple[str, int, int | None]] = field(default_factory=dict) + + @property + def tiles_to_axis(self) -> dict[str, str]: + return {name: axis for name, (axis, _) in self.tiles_info.items()} + + @property + def splits_to_axis(self) -> dict[str, str]: + return {name: axis for name, (axis, _, _) in self.splits_info.items()} @property def loops_to_axis(self) -> dict[str, str]: - loops_to_axis = ( + return ( self.tiles_to_axis | self.splits_to_axis | dict(zip(self.dims, self.dims)) ) - return loops_to_axis - @staticmethod - def build_from_nodes(nodes: list[LoopNestNode]) -> LoopsDimsMapper: - tiles_to_axis = {} - splits_to_axis = {} - dims = set() - for node in nodes: - tiles_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(node.tiles)) - splits_to_axis.update(LoopsDimsMapper._get_subloops_to_axis(node.splits)) - refined_loops = list(tiles_to_axis) + list(splits_to_axis) - for node in nodes: - dims.update( - [loop for loop in node.interchange if loop not in refined_loops] - ) - dims.update(tiles_to_axis.values()) - dims.update(splits_to_axis.values()) - return LoopsDimsMapper(tiles_to_axis, splits_to_axis, list(dims)) + @property + def splits_to_sizes(self) -> dict[str, int]: + return { + name: end - start + for name, (_, start, end) in self.splits_info.items() + if end is not None + } @staticmethod - def _get_subloops_to_axis(subloops: dict[str, dict[str, Any]]) -> dict[str, str]: - loop_to_axis: dict[str, str] = {} - for axis_name, subloops in subloops.items(): - for loop_name in subloops: - loop_to_axis[loop_name] = axis_name - return loop_to_axis + def build_from_node(node: LoopNestNode) -> LoopInfo: + tiles_info: dict[str, tuple[str, int]] = {} + splits_info: dict[str, tuple[str, int, int | None]] = {} + dims: dict[ + str, None + ] = {} # ordered set: insertion order preserved, no duplicates + + def collect(n: LoopNestNode) -> None: + # Build tiles_info: tile_name -> (axis, size) + for axis, tile_loops in n.tiles.items(): + for loop_name, size in tile_loops.items(): + tiles_info[loop_name] = (axis, size) + + # Build splits_info: split_name -> (axis, start, end) + for axis, axis_splits in n.splits.items(): + sorted_splits = sorted(axis_splits.items(), key=lambda kv: kv[1]) + for i, (loop_name, start) in enumerate(sorted_splits): + end = ( + sorted_splits[i + 1][1] if i + 1 < len(sorted_splits) else None + ) + splits_info[loop_name] = (axis, start, end) + + # Collect dims in stable order + refined_loops = set(tiles_info) | set(splits_info) + for loop in n.interchange: + if loop not in refined_loops: + dims[loop] = None + for axis, _ in tiles_info.values(): + dims[axis] = None + for axis, _, _ in splits_info.values(): + dims[axis] = None + + # Recurse on children + for child in n.children: + collect(child) + + collect(node) + + return LoopInfo(list(dims), tiles_info, splits_info) @dataclass @@ -332,15 +333,16 @@ def build_root_node(self, root: str) -> LoopNestNode: return node def check(self): - self._check_use_defined_dims() + assert self.root_node is not None + info = LoopInfo.build_from_node(self.root_node) + self._check_use_defined_dims(info) self._check_vectorization_consistency() - self._check_tiling_consistency() - self._check_sizes() + self._check_tiling_consistency(info) + self._check_sizes(info) - def _check_use_defined_dims(self): - mapper = LoopsDimsMapper.build_from_nodes(self.nodes) + def _check_use_defined_dims(self, info: LoopInfo): for dim in self.abstract_dims: - if dim not in mapper.dims: + if dim not in info.dims: raise ScheduleValidationError(f"{dim} defined but never used") def _check_vectorization_consistency(self): @@ -354,42 +356,51 @@ def _check_vectorization_consistency(self): f"Inner loop {loop_name} isn't vectorized but an outer one is." ) - def _check_tiling_consistency(self) -> None: - mapper = LoopsDimsMapper.build_from_nodes(self.nodes) + def _check_tiling_consistency(self, info: LoopInfo) -> None: seen_axes: dict[str, int | None] = {} for sched in self.nodes: for loop_name in sched.interchange: - if loop_name in mapper.dims: + if loop_name in info.dims: seen_axes[loop_name] = None - elif loop_name in mapper.tiles_to_axis: - axis = mapper.tiles_to_axis[loop_name] - size = sched.tiles_to_sizes[loop_name] + elif loop_name in info.splits_to_axis: + axis = info.splits_to_axis[loop_name] + seen_axes[axis] = sched.splits[axis][loop_name] + elif loop_name in info.tiles_to_axis: + axis = info.tiles_to_axis[loop_name] + size = sched.tiles[axis][loop_name] if axis not in seen_axes: raise ScheduleValidationError( f""" `{axis}#{size}`: {axis} has not been materialized yet. """ ) - seen_axes[axis] = sched.tiles[axis][loop_name] + seen_axes[axis] = size - def _check_sizes(self): - mapper = LoopsDimsMapper.build_from_nodes(self.nodes) + def _check_sizes(self, info: LoopInfo): current_size_of_split: dict[str, int | None] = {} for sched in self.nodes: current_size_of_tile: dict[str, int] = {} + if sched.split_origin is not None: + axis = sched.split_origin.axis + start = sched.split_origin.start + end = sched.split_origin.end + if end is not None and start is not None: + current_size_of_split[axis] = end - start + else: + current_size_of_split[axis] = None for loop_name in sched.interchange: - axis = mapper.loops_to_axis[loop_name] + axis = info.loops_to_axis[loop_name] current_sizes = ( - {d: None for d in mapper.dims} + {d: None for d in info.dims} | current_size_of_split | current_size_of_tile ) loop_size = None - if loop_name in mapper.dims: + if loop_name in info.dims: if loop_name not in current_size_of_split: current_size_of_split[loop_name] = None - elif loop_name in mapper.tiles_to_axis: + elif loop_name in info.tiles_to_axis: loop_size = sched.tiles[axis][loop_name] LoopNest._must_be_smaller_routine( new_size=loop_size, @@ -399,10 +410,10 @@ def _check_sizes(self): ) current_size_of_tile[axis] = loop_size elif ( - loop_name in mapper.splits_to_axis - and loop_name in sched.splits_to_sizes + loop_name in info.splits_to_axis + and loop_name in info.splits_to_sizes ): - loop_size = sched.splits_to_sizes[loop_name] + loop_size = info.splits_to_sizes[loop_name] LoopNest._must_be_smaller_routine( new_size=loop_size, current_sizes=current_sizes, From c27bb7f9477b3d62cb00845bb55561ac69a2cca2 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 2 Feb 2026 07:07:56 +0100 Subject: [PATCH 6/8] descript: code cleaning --- src/xtc/cli/mlir_loop.py | 2 +- src/xtc/schedules/descript.py | 26 +++++++++---------- src/xtc/schedules/loop_nest.py | 4 --- src/xtc/schedules/ttile/scheme_to_xtc.py | 2 +- .../schedules/test_descript_parsing_errors.py | 6 ++--- .../schedules/test_descript_slice_bigger.py | 2 +- .../schedules/test_descript_slice_smaller.py | 2 +- .../schedules/test_matmul_descript_mlir.py | 2 +- .../schedules/test_matmul_descript_tvm.py | 2 +- 9 files changed, 22 insertions(+), 26 deletions(-) diff --git a/src/xtc/cli/mlir_loop.py b/src/xtc/cli/mlir_loop.py index 09f9f8fb..2bd93991 100644 --- a/src/xtc/cli/mlir_loop.py +++ b/src/xtc/cli/mlir_loop.py @@ -136,7 +136,7 @@ def build_node_scheduler( descript_scheduler( scheduler=scheduler, node_name=node_name, - abstract_axis=scheduler.backend.dims, + abstract_dims=scheduler.backend.dims, spec=normal_schedule, ) op.attributes.pop("loop.schedule", None) diff --git a/src/xtc/schedules/descript.py b/src/xtc/schedules/descript.py index 026891fc..d1c2576c 100644 --- a/src/xtc/schedules/descript.py +++ b/src/xtc/schedules/descript.py @@ -21,7 +21,7 @@ def descript_scheduler( scheduler: Scheduler, node_name: str, - abstract_axis: list[str], + abstract_dims: list[str], spec: dict[str, dict[str, Any]], ) -> None: """Apply a schedule specification to a scheduler. @@ -31,22 +31,22 @@ def descript_scheduler( Args: scheduler: The scheduler to apply the schedule to. node_name: The name of the root node to schedule. - abstract_axis: The list of abstract axis names (e.g., ["m", "n", "k"]). + abstract_dims: The list of abstract axis names (e.g., ["m", "n", "k"]). spec: The schedule specification as a nested dict. """ - descript = Descript(scheduler=scheduler, abstract_axis=abstract_axis) + descript = Descript(scheduler=scheduler, abstract_dims=abstract_dims) descript.apply(node_name=node_name, spec=spec) class ScheduleInterpreter: """Interprets a parsed ScheduleSpec AST into a LoopNest.""" - def __init__(self, abstract_axis: list[str]): - self.abstract_axis = abstract_axis + def __init__(self, abstract_dims: list[str]): + self.abstract_dims = abstract_dims def interpret(self, spec: ScheduleSpec, root: str) -> LoopNest: """Interpret a schedule specification into a LoopNest.""" - loop_nest = LoopNest(abstract_dims=self.abstract_axis) + loop_nest = LoopNest(abstract_dims=self.abstract_dims) root_node = loop_nest.build_root_node(root) self._interpret_spec_into_node(spec, root_node, root, head=[]) return loop_nest @@ -61,7 +61,7 @@ def _interpret_spec_into_node( """Interpret a schedule spec into an existing node (mutates node).""" # Track state during interpretation sizes: dict[str, int] = {} - previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis} + previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_dims} interchange: list[str] = list(head) for item in spec.items: @@ -124,7 +124,7 @@ def _interpret_split( # Create a child node for the nested schedule child_node = LoopNestNode( root=new_root_name, - tiles={a: {} for a in self.abstract_axis}, + tiles={a: {} for a in self.abstract_dims}, split_origin=SplitOrigin(axis=axis_name, start=x, end=y), ) node.add_child(child_node) @@ -176,9 +176,9 @@ def _interpret_axis( def _check_axis_existence(self, axis: str) -> None: """Check that an axis is defined.""" - if axis not in self.abstract_axis: + if axis not in self.abstract_dims: raise ScheduleInterpretError( - f"Axis {axis} is not a defined axis (defined axis: {self.abstract_axis})." + f"Axis {axis} is not a defined axis (defined axis: {self.abstract_dims})." ) def _apply_annotations( @@ -249,7 +249,7 @@ class Descript: """ scheduler: Scheduler - abstract_axis: list[str] + abstract_dims: list[str] def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: """Parse, interpret, validate, and apply a schedule specification. @@ -268,7 +268,7 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: ast = parser.parse(spec) # Interpret the AST into a LoopNest - interpreter = ScheduleInterpreter(self.abstract_axis) + interpreter = ScheduleInterpreter(self.abstract_dims) loop_nest = interpreter.interpret(ast, root=node_name) # Validate the loop nest loop_nest.check() @@ -277,7 +277,7 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None: def _apply_loop_nest(self, loop_nest: LoopNest) -> None: """Apply a LoopNest to the scheduler.""" - self.scheduler.set_dims(self.abstract_axis) + self.scheduler.set_dims(self.abstract_dims) if loop_nest.root_node is not None: self._apply_node(loop_nest.root_node) diff --git a/src/xtc/schedules/loop_nest.py b/src/xtc/schedules/loop_nest.py index d8a05e58..078ff7f6 100644 --- a/src/xtc/schedules/loop_nest.py +++ b/src/xtc/schedules/loop_nest.py @@ -311,10 +311,6 @@ class LoopNest: abstract_dims: list[str] root_node: LoopNestNode | None = None - @property - def empty(self) -> bool: - return self.root_node is None - @property def nodes(self) -> list[LoopNestNode]: """Flatten the tree into a list of nodes. diff --git a/src/xtc/schedules/ttile/scheme_to_xtc.py b/src/xtc/schedules/ttile/scheme_to_xtc.py index f2c83364..99603bfb 100644 --- a/src/xtc/schedules/ttile/scheme_to_xtc.py +++ b/src/xtc/schedules/ttile/scheme_to_xtc.py @@ -723,7 +723,7 @@ def build_schedule_from_ttile( ) descript_scheduler( - scheduler=sch, node_name=name_op, abstract_axis=ldims, spec=spec_schedule + scheduler=sch, node_name=name_op, abstract_dims=ldims, spec=spec_schedule ) # And run it! diff --git a/tests/filecheck/schedules/test_descript_parsing_errors.py b/tests/filecheck/schedules/test_descript_parsing_errors.py index 49f0322b..676d5b81 100644 --- a/tests/filecheck/schedules/test_descript_parsing_errors.py +++ b/tests/filecheck/schedules/test_descript_parsing_errors.py @@ -24,7 +24,7 @@ descript_scheduler( scheduler = sch, node_name = "C", - abstract_axis = ["I","J","K"], + abstract_dims = ["I","J","K"], spec = { "I": {}, "K": {"unroll" : "hello"}, @@ -35,7 +35,7 @@ descript_scheduler( scheduler = sch, node_name = "C", - abstract_axis = ["I","J","K"], + abstract_dims = ["I","J","K"], spec = { "I": {"parallelize" : "hello"}, "K": {}, @@ -46,7 +46,7 @@ descript_scheduler( scheduler = sch, node_name = "C", - abstract_axis = ["I","J","K"], + abstract_dims = ["I","J","K"], spec = { "I": {}, "K": {}, diff --git a/tests/filecheck/schedules/test_descript_slice_bigger.py b/tests/filecheck/schedules/test_descript_slice_bigger.py index 61f2090e..f295916a 100644 --- a/tests/filecheck/schedules/test_descript_slice_bigger.py +++ b/tests/filecheck/schedules/test_descript_slice_bigger.py @@ -20,7 +20,7 @@ descript_scheduler( scheduler=sch, node_name="C", - abstract_axis=["i", "j", "k"], + abstract_dims=["i", "j", "k"], spec={ 'k': {}, 'j': {}, diff --git a/tests/filecheck/schedules/test_descript_slice_smaller.py b/tests/filecheck/schedules/test_descript_slice_smaller.py index bdf11093..551f001b 100644 --- a/tests/filecheck/schedules/test_descript_slice_smaller.py +++ b/tests/filecheck/schedules/test_descript_slice_smaller.py @@ -20,7 +20,7 @@ descript_scheduler( scheduler=sch, node_name="C", - abstract_axis=["i", "j", "k"], + abstract_dims=["i", "j", "k"], spec={ 'k': {}, 'j': {}, diff --git a/tests/filecheck/schedules/test_matmul_descript_mlir.py b/tests/filecheck/schedules/test_matmul_descript_mlir.py index f3f517a5..814ae0e6 100644 --- a/tests/filecheck/schedules/test_matmul_descript_mlir.py +++ b/tests/filecheck/schedules/test_matmul_descript_mlir.py @@ -20,7 +20,7 @@ descript_scheduler( scheduler = sch, node_name = "C", - abstract_axis = ["I","J","K"], + abstract_dims = ["I","J","K"], spec = { "K": {}, "I": {}, diff --git a/tests/filecheck/schedules/test_matmul_descript_tvm.py b/tests/filecheck/schedules/test_matmul_descript_tvm.py index fab85fc7..c484775a 100644 --- a/tests/filecheck/schedules/test_matmul_descript_tvm.py +++ b/tests/filecheck/schedules/test_matmul_descript_tvm.py @@ -21,7 +21,7 @@ descript_scheduler( scheduler = sch, node_name = "C", - abstract_axis = ["I","J","K"], + abstract_dims = ["I","J","K"], spec = { "K": {}, "I": {}, From bcdad844a5c18ba82fdd91afde16ff0802901125 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Wed, 4 Feb 2026 01:38:43 +0100 Subject: [PATCH 7/8] descript: expose buffer_at --- src/xtc/backends/tvm/TVMScheduler.py | 3 +++ src/xtc/schedules/descript.py | 6 ++++++ src/xtc/schedules/loop_nest.py | 11 ++++++++++- src/xtc/schedules/parsing.py | 16 ++++++++++++++++ .../schedules/test_descript_pretty_print.py | 18 ++++++++++++++++++ tests/filecheck/schedules/test_get_descript.py | 4 +++- 6 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/xtc/backends/tvm/TVMScheduler.py b/src/xtc/backends/tvm/TVMScheduler.py index e439e850..d0fe5609 100644 --- a/src/xtc/backends/tvm/TVMScheduler.py +++ b/src/xtc/backends/tvm/TVMScheduler.py @@ -535,6 +535,9 @@ def get_loop_nest(self) -> LoopNest: # Build unroll mapping root_node.unroll = dict(self.unrolling) + # Build buffer_at mapping + root_node.buffer_at = {axis: None for axis in self.write_caches} + return loop_nest diff --git a/src/xtc/schedules/descript.py b/src/xtc/schedules/descript.py index d1c2576c..66ce04ad 100644 --- a/src/xtc/schedules/descript.py +++ b/src/xtc/schedules/descript.py @@ -210,6 +210,9 @@ def _apply_annotations( if annotations.parallelize: node.parallelize.append(loop_name) + if annotations.buffer_specified: + node.buffer_at[loop_name] = annotations.buffer + def _check_splitting_intervals( self, item: SplitDecl, @@ -297,6 +300,9 @@ def _apply_node(self, node: LoopNestNode) -> None: self.scheduler.parallelize(node.parallelize, root=root) self.scheduler.unroll(node.unroll, root=root) + for axis, mtype in node.buffer_at.items(): + self.scheduler.buffer_at(axis, mtype=mtype, root=root) + # Recursively apply children for child in node.children: self._apply_node(child) diff --git a/src/xtc/schedules/loop_nest.py b/src/xtc/schedules/loop_nest.py index 078ff7f6..91f6faef 100644 --- a/src/xtc/schedules/loop_nest.py +++ b/src/xtc/schedules/loop_nest.py @@ -92,6 +92,8 @@ class LoopNestNode(Node["LoopNestNode"]): vectorize: List of loops to vectorize. parallelize: List of loops to parallelize. unroll: Maps loop names to their unroll factors. + buffer_at: Buffer configuration per axis. Maps axis names to optional + memory types (mtype). None means default memory type. """ root: str @@ -101,6 +103,7 @@ class LoopNestNode(Node["LoopNestNode"]): vectorize: list[str] = field(default_factory=list) parallelize: list[str] = field(default_factory=list) unroll: dict[str, int] = field(default_factory=dict) + buffer_at: dict[str, str | None] = field(default_factory=dict) def pretty_print(self, indent: int = 0) -> str: """Return a human-readable representation of the loop nest. @@ -201,7 +204,7 @@ def pretty_print(self, indent: int = 0) -> str: return "\n".join(lines) def _add_annotations(self, line: str, loop_name: str) -> str: - """Add annotations (parallelized, vectorized, unroll) to a loop line.""" + """Add annotations (parallelized, vectorized, unroll, buffer) to a loop line.""" annotations: list[str] = [] if loop_name in self.parallelize: annotations.append("parallelized") @@ -209,6 +212,12 @@ def _add_annotations(self, line: str, loop_name: str) -> str: annotations.append("vectorized") if loop_name in self.unroll: annotations.append(f"unroll({self.unroll[loop_name]})") + if loop_name in self.buffer_at: + mtype = self.buffer_at[loop_name] + if mtype is not None: + annotations.append(f"buffer({mtype})") + else: + annotations.append("buffer") if annotations: line += " // " + ", ".join(annotations) return line diff --git a/src/xtc/schedules/parsing.py b/src/xtc/schedules/parsing.py index 419feede..03cc8c90 100644 --- a/src/xtc/schedules/parsing.py +++ b/src/xtc/schedules/parsing.py @@ -22,12 +22,17 @@ class Annotations: unroll_specified: True if unroll was explicitly requested. vectorize: True if vectorization was requested. parallelize: True if parallelization was requested. + buffer: The memory type for the buffer. None means default memory type. + Only meaningful when buffer_specified is True. + buffer_specified: True if buffer was explicitly requested. """ unroll_factor: int | None = None unroll_specified: bool = False vectorize: bool = False parallelize: bool = False + buffer: str | None = None + buffer_specified: bool = False @dataclass(frozen=True) @@ -145,6 +150,8 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations unroll_specified = False vectorize = False parallelize = False + buffer: str | None = None + buffer_specified = False for key, param in value.items(): if key == "unroll": @@ -172,6 +179,13 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations f'`{{"parallelize" = {param}}}`: parameterized parallelization not implemented.' ) parallelize = param + elif key == "buffer": + if not isinstance(param, str): + raise ScheduleParseError( + f'`{{"buffer" = {param}}}`: buffer parameter should be a string (mtype).' + ) + buffer = None if param == "default" else param + buffer_specified = True else: raise ScheduleParseError(f"Unknown annotation on {context}: {key}") @@ -180,6 +194,8 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations unroll_specified=unroll_specified, vectorize=vectorize, parallelize=parallelize, + buffer=buffer, + buffer_specified=buffer_specified, ) def _parse_split_syntax( diff --git a/tests/filecheck/schedules/test_descript_pretty_print.py b/tests/filecheck/schedules/test_descript_pretty_print.py index 11c9af9c..bfd5074b 100644 --- a/tests/filecheck/schedules/test_descript_pretty_print.py +++ b/tests/filecheck/schedules/test_descript_pretty_print.py @@ -3,6 +3,7 @@ # RUN: python %s --vectorized 2>&1 | filecheck %s --check-prefix=CHECK-VECTORIZED # RUN: python %s --full 2>&1 | filecheck %s --check-prefix=CHECK-FULL # RUN: python %s --split 2>&1 | filecheck %s --check-prefix=CHECK-SPLIT +# RUN: python %s --buffer 2>&1 | filecheck %s --check-prefix=CHECK-BUFFER import sys from xtc.schedules.parsing import ScheduleParser @@ -52,6 +53,17 @@ loop_nest = interpreter.interpret(ast, root="C") print(loop_nest.root_node.pretty_print()) +elif "--buffer" in sys.argv: + spec = { + "i": {"parallelize": True}, + "k": {"buffer": "default"}, + "j": {"buffer": "shared"}, + "j#16": {"vectorize": True}, + } + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + # CHECK-SIMPLE: loop i # CHECK-SIMPLE-NEXT: loop k # CHECK-SIMPLE-NEXT: loop j @@ -87,3 +99,9 @@ # CHECK-SPLIT-NEXT: loop k # CHECK-SPLIT-NEXT: tile(k, 16) // vectorized # CHECK-SPLIT-NEXT: ... + +# CHECK-BUFFER: loop i // parallelized +# CHECK-BUFFER-NEXT: loop k // buffer +# CHECK-BUFFER-NEXT: loop j // buffer(shared) +# CHECK-BUFFER-NEXT: tile(j, 16) // vectorized +# CHECK-BUFFER-NEXT: ... diff --git a/tests/filecheck/schedules/test_get_descript.py b/tests/filecheck/schedules/test_get_descript.py index 50d1e32c..f4606da0 100644 --- a/tests/filecheck/schedules/test_get_descript.py +++ b/tests/filecheck/schedules/test_get_descript.py @@ -31,6 +31,8 @@ sch.interchange(["K", "I", "J", "I0", "J0"]) sch.unroll({"I0": 2}) sch.vectorize(["J0"]) +if "--tvm" in sys.argv: + sch.buffer_at("J") loop_nest = sch.get_loop_nest() print(loop_nest.root_node.pretty_print()) @@ -44,7 +46,7 @@ # CHECK-TVM: loop K # CHECK-TVM-NEXT: loop I -# CHECK-TVM-NEXT: loop J +# CHECK-TVM-NEXT: loop J // buffer # CHECK-TVM-NEXT: tile(I, 2) // unroll(2) # CHECK-TVM-NEXT: tile(J, 16) // vectorized # CHECK-TVM-NEXT: ... From 7eac95a4e45ae633b13a4e6fdeefead3c37aee0c Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Wed, 4 Feb 2026 01:51:11 +0100 Subject: [PATCH 8/8] descript: expose pack_at --- src/xtc/backends/tvm/TVMScheduler.py | 5 +++ src/xtc/schedules/descript.py | 6 +++ src/xtc/schedules/loop_nest.py | 14 +++++- src/xtc/schedules/parsing.py | 44 +++++++++++++++++++ .../schedules/test_descript_pretty_print.py | 18 ++++++++ .../filecheck/schedules/test_get_descript.py | 3 +- 6 files changed, 88 insertions(+), 2 deletions(-) diff --git a/src/xtc/backends/tvm/TVMScheduler.py b/src/xtc/backends/tvm/TVMScheduler.py index d0fe5609..59c58ff0 100644 --- a/src/xtc/backends/tvm/TVMScheduler.py +++ b/src/xtc/backends/tvm/TVMScheduler.py @@ -538,6 +538,11 @@ def get_loop_nest(self) -> LoopNest: # Build buffer_at mapping root_node.buffer_at = {axis: None for axis in self.write_caches} + # Build pack_at mapping + root_node.pack_at = { + axis: (input_idx, None, pad) for axis, input_idx, pad in self.read_buffers + } + return loop_nest diff --git a/src/xtc/schedules/descript.py b/src/xtc/schedules/descript.py index 66ce04ad..50e4de5b 100644 --- a/src/xtc/schedules/descript.py +++ b/src/xtc/schedules/descript.py @@ -213,6 +213,9 @@ def _apply_annotations( if annotations.buffer_specified: node.buffer_at[loop_name] = annotations.buffer + if annotations.pack_specified and annotations.pack is not None: + node.pack_at[loop_name] = annotations.pack + def _check_splitting_intervals( self, item: SplitDecl, @@ -303,6 +306,9 @@ def _apply_node(self, node: LoopNestNode) -> None: for axis, mtype in node.buffer_at.items(): self.scheduler.buffer_at(axis, mtype=mtype, root=root) + for axis, (input_idx, mtype, pad) in node.pack_at.items(): + self.scheduler.pack_at(axis, input_idx, mtype=mtype, pad=pad, root=root) + # Recursively apply children for child in node.children: self._apply_node(child) diff --git a/src/xtc/schedules/loop_nest.py b/src/xtc/schedules/loop_nest.py index 91f6faef..306c0c91 100644 --- a/src/xtc/schedules/loop_nest.py +++ b/src/xtc/schedules/loop_nest.py @@ -94,6 +94,9 @@ class LoopNestNode(Node["LoopNestNode"]): unroll: Maps loop names to their unroll factors. buffer_at: Buffer configuration per axis. Maps axis names to optional memory types (mtype). None means default memory type. + pack_at: Pack configuration per axis. Maps axis names to tuples of + (input_idx, mtype, pad). input_idx is the input buffer index, + mtype is the memory type (None for default), pad enables padding. """ root: str @@ -104,6 +107,7 @@ class LoopNestNode(Node["LoopNestNode"]): parallelize: list[str] = field(default_factory=list) unroll: dict[str, int] = field(default_factory=dict) buffer_at: dict[str, str | None] = field(default_factory=dict) + pack_at: dict[str, tuple[int, str | None, bool]] = field(default_factory=dict) def pretty_print(self, indent: int = 0) -> str: """Return a human-readable representation of the loop nest. @@ -204,7 +208,7 @@ def pretty_print(self, indent: int = 0) -> str: return "\n".join(lines) def _add_annotations(self, line: str, loop_name: str) -> str: - """Add annotations (parallelized, vectorized, unroll, buffer) to a loop line.""" + """Add annotations (parallelized, vectorized, unroll, buffer, pack) to a loop line.""" annotations: list[str] = [] if loop_name in self.parallelize: annotations.append("parallelized") @@ -218,6 +222,14 @@ def _add_annotations(self, line: str, loop_name: str) -> str: annotations.append(f"buffer({mtype})") else: annotations.append("buffer") + if loop_name in self.pack_at: + input_idx, mtype, pad = self.pack_at[loop_name] + parts = [str(input_idx)] + if mtype is not None: + parts.append(mtype) + if pad: + parts.append("pad") + annotations.append(f"pack({', '.join(parts)})") if annotations: line += " // " + ", ".join(annotations) return line diff --git a/src/xtc/schedules/parsing.py b/src/xtc/schedules/parsing.py index 03cc8c90..d1e17a53 100644 --- a/src/xtc/schedules/parsing.py +++ b/src/xtc/schedules/parsing.py @@ -25,6 +25,9 @@ class Annotations: buffer: The memory type for the buffer. None means default memory type. Only meaningful when buffer_specified is True. buffer_specified: True if buffer was explicitly requested. + pack: Pack configuration as (input_idx, mtype, pad). mtype is None for default. + Only meaningful when pack_specified is True. + pack_specified: True if pack was explicitly requested. """ unroll_factor: int | None = None @@ -33,6 +36,8 @@ class Annotations: parallelize: bool = False buffer: str | None = None buffer_specified: bool = False + pack: tuple[int, str | None, bool] | None = None + pack_specified: bool = False @dataclass(frozen=True) @@ -152,6 +157,8 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations parallelize = False buffer: str | None = None buffer_specified = False + pack: tuple[int, str | None, bool] | None = None + pack_specified = False for key, param in value.items(): if key == "unroll": @@ -186,6 +193,9 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations ) buffer = None if param == "default" else param buffer_specified = True + elif key == "pack": + pack = self._parse_pack_param(param, context) + pack_specified = True else: raise ScheduleParseError(f"Unknown annotation on {context}: {key}") @@ -196,8 +206,42 @@ def _parse_annotations(self, value: dict[str, Any], context: str) -> Annotations parallelize=parallelize, buffer=buffer, buffer_specified=buffer_specified, + pack=pack, + pack_specified=pack_specified, ) + def _parse_pack_param( + self, param: Any, context: str + ) -> tuple[int, str | None, bool]: + """Parse pack parameter into (input_idx, mtype, pad) tuple.""" + if not isinstance(param, (list, tuple)) or len(param) != 3: + raise ScheduleParseError( + f'`{{"pack" = {param}}}` on {context}: pack parameter should be a tuple (input_idx, mtype, pad).' + ) + + input_idx, mtype, pad = param + + if not isinstance(input_idx, int): + raise ScheduleParseError( + f'`{{"pack" = {param}}}` on {context}: input_idx should be an integer.' + ) + + if mtype is not None and not isinstance(mtype, str): + raise ScheduleParseError( + f'`{{"pack" = {param}}}` on {context}: mtype should be a string or None.' + ) + + if not isinstance(pad, bool): + raise ScheduleParseError( + f'`{{"pack" = {param}}}` on {context}: pad should be a boolean.' + ) + + # Convert "default" to None for mtype + if mtype == "default": + mtype = None + + return (input_idx, mtype, pad) + def _parse_split_syntax( self, declaration: str ) -> tuple[str, int | None, int | None]: diff --git a/tests/filecheck/schedules/test_descript_pretty_print.py b/tests/filecheck/schedules/test_descript_pretty_print.py index bfd5074b..54ed443d 100644 --- a/tests/filecheck/schedules/test_descript_pretty_print.py +++ b/tests/filecheck/schedules/test_descript_pretty_print.py @@ -4,6 +4,7 @@ # RUN: python %s --full 2>&1 | filecheck %s --check-prefix=CHECK-FULL # RUN: python %s --split 2>&1 | filecheck %s --check-prefix=CHECK-SPLIT # RUN: python %s --buffer 2>&1 | filecheck %s --check-prefix=CHECK-BUFFER +# RUN: python %s --pack 2>&1 | filecheck %s --check-prefix=CHECK-PACK import sys from xtc.schedules.parsing import ScheduleParser @@ -64,6 +65,17 @@ loop_nest = interpreter.interpret(ast, root="C") print(loop_nest.root_node.pretty_print()) +elif "--pack" in sys.argv: + spec = { + "i": {"parallelize": True}, + "k": {"pack": (0, "default", False)}, + "j": {"pack": (1, "shared", True)}, + "j#16": {"vectorize": True}, + } + ast = parser.parse(spec) + loop_nest = interpreter.interpret(ast, root="C") + print(loop_nest.root_node.pretty_print()) + # CHECK-SIMPLE: loop i # CHECK-SIMPLE-NEXT: loop k # CHECK-SIMPLE-NEXT: loop j @@ -105,3 +117,9 @@ # CHECK-BUFFER-NEXT: loop j // buffer(shared) # CHECK-BUFFER-NEXT: tile(j, 16) // vectorized # CHECK-BUFFER-NEXT: ... + +# CHECK-PACK: loop i // parallelized +# CHECK-PACK-NEXT: loop k // pack(0) +# CHECK-PACK-NEXT: loop j // pack(1, shared, pad) +# CHECK-PACK-NEXT: tile(j, 16) // vectorized +# CHECK-PACK-NEXT: ... diff --git a/tests/filecheck/schedules/test_get_descript.py b/tests/filecheck/schedules/test_get_descript.py index f4606da0..47888bb4 100644 --- a/tests/filecheck/schedules/test_get_descript.py +++ b/tests/filecheck/schedules/test_get_descript.py @@ -33,6 +33,7 @@ sch.vectorize(["J0"]) if "--tvm" in sys.argv: sch.buffer_at("J") + sch.pack_at("I", 0, pad=True) loop_nest = sch.get_loop_nest() print(loop_nest.root_node.pretty_print()) @@ -45,7 +46,7 @@ # CHECK-MLIR-NEXT: ... # CHECK-TVM: loop K -# CHECK-TVM-NEXT: loop I +# CHECK-TVM-NEXT: loop I // pack(0, pad) # CHECK-TVM-NEXT: loop J // buffer # CHECK-TVM-NEXT: tile(I, 2) // unroll(2) # CHECK-TVM-NEXT: tile(J, 16) // vectorized