Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/xtc/backends/mlir/MlirCompilerPasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
sdist_transform = None
pass

from xtc.itf.schd.scheduler import ROOT_SEP
from xtc.utils.ext_tools import transform_opts

from .MlirProgram import RawMlirProgram
Expand Down Expand Up @@ -293,7 +294,11 @@ def _generate_node_scheduling(
schedule=schedule, root=loop_name, sched_state=sched_state
)
continue

axis_split = split_state.loop_dim_by_split.get(root)
if axis_split is not None and not (
schedule.is_base(loop_name) or schedule.is_tile(loop_name)
):
loop_name = root + ROOT_SEP + axis_split
# Bufferization
if loop_name in schedule.distributed_buffers.keys():
self._distribute_buffer(
Expand Down Expand Up @@ -340,6 +345,7 @@ def _generate_tiling_insns(
state_of_tiling: dict[str, int] = {dim: 1 for dim in schedule.dims}
candidate_state_of_tiling = state_of_tiling.copy()
previous_root = ""
split_state = SplitState(schedule.splits, previous_root)
for loc_root, permutation in reversed(schedule.permutation.items()):
if len(loc_root) == len(previous_root):
# Reset the view on the state of tiling (we are jumping into
Expand All @@ -348,11 +354,14 @@ def _generate_tiling_insns(
else:
# Update the state of tiling
state_of_tiling = candidate_state_of_tiling.copy()

for loop in reversed(permutation):
# The loop needs to be base or tile
if not (schedule.is_tile(loop) or schedule.is_base(loop)):
continue
axis_split = split_state.loop_dim_by_split.get(loc_root)
if axis_split is not None:
loop = loc_root + ROOT_SEP + axis_split
else:
continue

# Fetch the dimension knowledge
dim_of_loop = schedule.dim_of_tile(loop)
Expand Down Expand Up @@ -472,7 +481,7 @@ def _unroll(
for dim_name in reversed(permutation):
if (
dim_name in schedule.unrolling
and not dim_name in schedule.vectorization
and dim_name not in schedule.vectorization
):
assert self._named_sequence is not None
loop_unroll(
Expand Down
4 changes: 1 addition & 3 deletions src/xtc/backends/mlir/MlirNodeScheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
from typing_extensions import override
from dataclasses import dataclass, asdict
from pprint import pformat
from xtc.itf.schd.scheduler import DEFAULT_ROOT
from xtc.itf.schd.scheduler import DEFAULT_ROOT, ROOT_SEP

__all__ = [
"MlirNodeScheduler",
"MlirNodeSchedule",
]

ROOT_SEP = "/"


def basename(loop_name: str) -> str:
return loop_name.split(ROOT_SEP)[-1]
Expand Down
3 changes: 3 additions & 0 deletions src/xtc/itf/schd/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import xtc.itf

DEFAULT_ROOT = "."
ROOT_SEP = "/"
SPLIT_LEFT_SEP = "["
SPLIT_RIGHT_SEP = "]"


class Scheduler(ABC):
Expand Down
64 changes: 36 additions & 28 deletions src/xtc/schedules/descript.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataclasses import dataclass, field
import re
from typing_extensions import override
from xtc.itf.schd.scheduler import Scheduler
from xtc.itf.schd.scheduler import Scheduler, ROOT_SEP, SPLIT_LEFT_SEP, SPLIT_RIGHT_SEP


class ScheduleParseError(RuntimeError):
Expand Down Expand Up @@ -60,7 +60,7 @@ class SplitDecl:
def __str__(self) -> str:
start_str = "" if self.start is None else str(self.start)
end_str = "" if self.end is None else str(self.end)
decl = f"{self.axis}[{start_str}:{end_str}]"
decl = f"{self.axis}{SPLIT_LEFT_SEP}{start_str}:{end_str}{SPLIT_RIGHT_SEP}"
return decl


Expand Down Expand Up @@ -221,22 +221,25 @@ class ScheduleInterpreter:

def __init__(self, abstract_axis: list[str]):
self.abstract_axis = abstract_axis
self.root_to_dim: dict[str, str] = {}
self.dim_to_axis: dict[str, str] = {}

def interpret(self, spec: ScheduleSpec, root: str) -> LoopNest:
"""Interpret a schedule specification into a LoopNest."""
return self._interpret_spec(spec, root, head=[])
return self._interpret_spec(spec, root)

def _interpret_spec(
self, spec: ScheduleSpec, root: str, head: list[str]
) -> LoopNest:
def _interpret_spec(self, spec: ScheduleSpec, root: str) -> LoopNest:
"""Interpret a schedule spec recursively."""
loop_nest = LoopNest(abstract_dims=self.abstract_axis)
slice = loop_nest.build_slice(root)

# Track state during interpretation
sizes: dict[str, int] = {}
previous_cut: dict[str, int | None] = {a: 0 for a in self.abstract_axis}
interchange: list[str] = list(head)
interchange: list[str] = []
# Only the first root is not in root_to_dim
if root in self.root_to_dim:
interchange.append(self.root_to_dim[root])

for item in spec.items:
if isinstance(item, SplitDecl):
Expand All @@ -249,7 +252,6 @@ def _interpret_spec(
elif isinstance(item, AxisDecl):
loop_name = self._interpret_axis(item, interchange)
self._apply_annotations(item.annotations, loop_name, sizes, slice)

# Check that all splits are complete
for axis, cut in previous_cut.items():
if cut is not None and cut != 0:
Expand Down Expand Up @@ -294,13 +296,14 @@ def _interpret_split(
if axis_name not in slice.splits:
slice.splits[axis_name] = {}
new_dim_index = len(slice.splits[axis_name])
new_dim_name = f"{axis_name}[{new_dim_index}]"
new_root_name = f"{root}/{new_dim_name}"
new_dim_name = f"{axis_name}{SPLIT_LEFT_SEP}{new_dim_index}{SPLIT_RIGHT_SEP}"
new_root_name = f"{root}{ROOT_SEP}{new_dim_name}"
slice.splits[axis_name][new_dim_name] = x
interchange.append(new_dim_name)

self.dim_to_axis[new_dim_name] = axis_name
self.root_to_dim[new_root_name] = new_dim_name
# Recursively interpret the nested schedule
inner_nest = self._interpret_spec(item.body, new_root_name, head=[axis_name])
inner_nest = self._interpret_spec(item.body, new_root_name)
loop_nest.slices += inner_nest.slices

def _interpret_tile(
Expand All @@ -321,7 +324,6 @@ def _interpret_tile(
slice.tiles[item.axis][loop_name] = item.size
sizes[loop_name] = item.size
interchange.append(loop_name)

return loop_name

def _interpret_axis(
Expand All @@ -332,13 +334,13 @@ def _interpret_axis(
"""Interpret a direct axis reference. Returns the loop name."""
axis_name = item.axis
self._check_axis_existence(axis_name)

# Unreachable when built from a Python dict (because keys
# can't be duplicated).
if axis_name in interchange:
raise ScheduleInterpretError(
f"Axis {axis_name} is scheduled twice (or more)."
)
for loop_name in interchange:
if self.dim_to_axis.get(loop_name, loop_name) == axis_name:
raise ScheduleInterpretError(
f"Axis {axis_name} is scheduled twice (or more)."
)

interchange.append(axis_name)
return axis_name
Expand Down Expand Up @@ -478,21 +480,23 @@ class LoopNestSlice:

root: str
tiles: dict[str, dict[str, int]]
splits: dict[str, dict[str, int]] = field(default_factory=dict)
splits: dict[str, dict[str, int | None]] = field(default_factory=dict)
interchange: list[str] = field(default_factory=list)
vectorize: list[str] = field(default_factory=list)
parallelize: list[str] = field(default_factory=list)
unroll: dict[str, int] = field(default_factory=dict)

@property
def splits_to_sizes(self) -> dict[str, int]:
splits_to_sizes: dict[str, int] = {}
def splits_to_sizes(self) -> dict[str, int | None]:
splits_to_sizes: dict[str, int | None] = {}
for axis in self.splits:
last_start = None
for loop_name, start in reversed(self.splits[axis].items()):
if last_start is not None:
if last_start is not None and start is not None:
size_of_split = last_start - start
splits_to_sizes[loop_name] = size_of_split
else:
splits_to_sizes[loop_name] = None
last_start = start
return splits_to_sizes

Expand Down Expand Up @@ -557,6 +561,8 @@ def _check_tiling_consistency(self) -> None:
seen_axes: dict[str, int | None] = {}
for sched in self.slices:
for loop_name in sched.interchange:
loop_name = mapper.splits_to_axis.get(loop_name, loop_name)

if loop_name in mapper.dims:
seen_axes[loop_name] = None
elif loop_name in mapper.tiles_to_axis:
Expand All @@ -575,7 +581,6 @@ def _check_sizes(self):
current_size_of_split: dict[str, int | None] = {}
for sched in self.slices:
current_size_of_tile: dict[str, int] = {}

for loop_name in sched.interchange:
axis = mapper.loops_to_axis[loop_name]
current_sizes = (
Expand Down Expand Up @@ -607,7 +612,9 @@ def _check_sizes(self):
loop_name=loop_name,
axis=axis,
)
current_size_of_split[axis] = loop_size
current_size_of_split[loop_name] = loop_size
elif loop_name in current_size_of_split:
current_size_of_split[axis] = current_size_of_split[loop_name]

if loop_name in sched.unroll:
unroll_factor = sched.unroll[loop_name]
Expand All @@ -618,10 +625,13 @@ def _check_sizes(self):

@staticmethod
def _must_be_smaller_routine(
new_size: int, current_sizes: dict[str, int | None], loop_name: str, axis: str
new_size: int | None,
current_sizes: dict[str, int | None],
loop_name: str,
axis: str,
):
old_size = current_sizes[axis]
if old_size is not None and new_size > old_size:
if old_size is not None and new_size is not None and new_size > old_size:
raise ScheduleValidationError(
f"""
Inner loop {loop_name} on axis {axis} must be smaller than outer loop.
Expand Down Expand Up @@ -683,10 +693,8 @@ def apply(self, node_name: str, spec: dict[str, dict[str, Any]]) -> None:
# Interpret the AST into a LoopNest
interpreter = ScheduleInterpreter(self.abstract_axis)
loop_nest = interpreter.interpret(ast, root=node_name)

# Validate the loop nest
loop_nest.check()

# Apply the schedule to the scheduler
self._apply_loop_nest(loop_nest)

Expand Down
Loading