Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/xtc/backends/jir/JIRScheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import xtc.itf as itf
from xtc.itf.schd.scheduler import DEFAULT_ROOT
from xtc.schedules.loop_nest import LoopNest
import xtc.backends.jir as backend

__all__ = [
Expand Down Expand Up @@ -347,3 +348,32 @@ def distributed_buffer_at(

def get_schedule_str(self) -> str:
return str(JIRSchedule(scheduler=self))

@override
def get_loop_nest(self) -> LoopNest:
transformer = self._transformer
dims = list(transformer.dims.keys())

loop_nest = LoopNest(abstract_dims=dims)
root_node = loop_nest.build_root_node(self._backend.payload_name)

# Build tiles mapping
for axis, axis_tiles in transformer.tiles.items():
for tile_name, size in axis_tiles.items():
root_node.tiles[axis][tile_name] = size

# Build interchange
root_node.interchange = (
list(transformer.order) if transformer.order else dims[:]
)

# Build vectorization list
root_node.vectorize = list(transformer.vectorized)

# Build parallelization list
root_node.parallelize = list(transformer.parallelized)

# Build unroll mapping
root_node.unroll = dict(transformer.unrolled)

return loop_nest
50 changes: 50 additions & 0 deletions src/xtc/backends/mlir/MlirScheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing_extensions import override

from xtc.itf.schd.scheduler import DEFAULT_ROOT
from xtc.schedules.loop_nest import LoopNest, LoopNestNode, LoopInfo, SplitOrigin
import xtc.itf as itf
import xtc.backends.mlir as backend

Expand Down Expand Up @@ -203,6 +204,55 @@ def distributed_buffer_at(
axis, input_idx, memory_axes, root=root
)

@override
def get_loop_nest(self) -> LoopNest:
node_sched = self._current_scheduler
dims = node_sched.dims[:]

loop_nest = LoopNest(abstract_dims=dims)
root_node = loop_nest.build_root_node(node_sched.node_name)

# Assign splits to root_node first
for axis, axis_splits in node_sched.splits.items():
root_node.splits[axis] = dict(axis_splits)

# Build mapper to get splits_info
mapper = LoopInfo.build_from_node(root_node)

def populate_node(node: LoopNestNode, perm: list[str]) -> None:
"""Populate node with data for loops in its permutation."""
node.interchange = list(perm)
perm_set = set(perm)
for axis, axis_tiles in node_sched.tiles.items():
for tile_name, size in axis_tiles.items():
if tile_name in perm_set:
if axis not in node.tiles:
node.tiles[axis] = {}
node.tiles[axis][tile_name] = size
node.vectorize = [v for v in node_sched.vectorization if v in perm_set]
node.parallelize = [p for p in node_sched.parallelization if p in perm_set]
node.unroll = {
k: v for k, v in node_sched.unrolling.items() if k in perm_set
}

# Process each root in permutation
for root, perm in node_sched.permutation.items():
if root in mapper.splits_info:
# This root is a split - create child node
axis, start, end = mapper.splits_info[root]
child = LoopNestNode(
root=root,
tiles={d: {} for d in dims},
split_origin=SplitOrigin(axis=axis, start=start, end=end),
)
populate_node(child, perm)
root_node.add_child(child)
else:
# This is the main root
populate_node(root_node, perm)

return loop_nest


class MlirSchedule(itf.schd.Schedule):
def __init__(
Expand Down
34 changes: 34 additions & 0 deletions src/xtc/backends/tvm/TVMScheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from xtc.utils.math import pow2divisor
from xtc.itf.schd.scheduler import DEFAULT_ROOT
from xtc.schedules.loop_nest import LoopNest
import xtc.backends.tvm as backend
import xtc.itf as itf

Expand Down Expand Up @@ -511,6 +512,39 @@ def _get_plain_schedule(self) -> TVMPlainSchedule:
def __str__(self) -> str:
return str(self._get_plain_schedule())

@override
def get_loop_nest(self) -> LoopNest:
loop_nest = LoopNest(abstract_dims=self.dims[:])
root_node = loop_nest.build_root_node(self._op.name or "op")

# Build tiles mapping
for axis, axis_tiles in self.tiles.items():
for tile_name, size in axis_tiles.items():
if tile_name != axis:
root_node.tiles[axis][tile_name] = size

# Build interchange
root_node.interchange = list(self.permutation)

# Build vectorization list
root_node.vectorize = list(self.vectorization)

# Build parallelization list
root_node.parallelize = list(self.parallelization)

# Build unroll mapping
root_node.unroll = dict(self.unrolling)

# Build buffer_at mapping
root_node.buffer_at = {axis: None for axis in self.write_caches}

# Build pack_at mapping
root_node.pack_at = {
axis: (input_idx, None, pad) for axis, input_idx, pad in self.read_buffers
}

return loop_nest


class TVMSchedule(itf.schd.Schedule):
def __init__(self, scheduler: "TVMScheduler", schedule_impl: ScheduleImpl) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/xtc/cli/mlir_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def build_node_scheduler(
descript_scheduler(
scheduler=scheduler,
node_name=node_name,
abstract_axis=scheduler.backend.dims,
abstract_dims=scheduler.backend.dims,
spec=normal_schedule,
)
op.attributes.pop("loop.schedule", None)
Expand Down
17 changes: 17 additions & 0 deletions src/xtc/itf/schd/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from .schedule import Schedule
import xtc.itf
from xtc.schedules.loop_nest import LoopNest

DEFAULT_ROOT = "."
ROOT_SEP = "/"
Expand Down Expand Up @@ -291,3 +292,19 @@ def distributed_buffer_at(
root: the parent split (or the operator's absolute root)
"""
...

@abstractmethod
def get_loop_nest(self) -> LoopNest:
"""Return a LoopNest representation of the current schedule.

This method constructs a LoopNest object that describes the loop
structure resulting from the scheduling transformations applied
so far. The LoopNest can be used for visualization (via pretty_print)
or further analysis.

Returns:
LoopNest: A tree structure representing the scheduled loop nest,
including tiles, splits, interchange order, and annotations
(vectorization, parallelization, unrolling).
"""
...
Loading