Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdist_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
--index-url https://gitlab.inria.fr/api/v4/groups/corse/-/packages/pypi/simple
mlir-sdist==21.1.2.2026012001
mlir-sdist==21.1.2.2026030601
mlir==21.1.2.2025091603
xtc-mlir==21.1.2.2
30 changes: 29 additions & 1 deletion src/xtc/backends/mlir/MlirCompilerPasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024-2026 The XTC Project Authors
#
from typing import cast
from dataclasses import dataclass
from mlir.dialects import transform
from mlir.dialects.transform import (
Expand Down Expand Up @@ -312,6 +313,12 @@ def _generate_node_scheduling(
schedule=schedule,
sched_state=sched_state,
)
if loop_name in schedule.write_buffers:
self._write_buffer(
loop_name=loop_name,
schedule=schedule,
sched_state=sched_state,
)

# Manage the strip-mining
if loop_name in schedule.vectorization:
Expand Down Expand Up @@ -505,7 +512,6 @@ def _distribute_loops(
mesh="processor_mesh",
axis=schedule.distribution[loop_name],
)
assert len(distribute_command.results) == 1
new_loop = distribute_command.results[0]
sched_state.all_loops[loop_name] = new_loop
# Annotate the resulting loop if successfully generated
Expand Down Expand Up @@ -542,6 +548,28 @@ def _pack_buffer(
input_idx=input_idx,
)

def _write_buffer(
self,
loop_name: str,
schedule: MlirNodeSchedule,
sched_state: SchedulingState,
):
from .MlirGraphBackend import MlirGraphBackend
from .MlirNodeBackend import MlirNodeBackend

assert self._mlir_schedule is not None
graph_backend = cast(MlirGraphBackend, self._mlir_schedule.scheduler.backend)
node_backend = cast(MlirNodeBackend, graph_backend.nodes[schedule.node_name])
Comment on lines +561 to +562
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it needed ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to pass the pyright type checks

Copy link
Contributor

@qaco qaco Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use assert isinstance(var,type) instead ? Imho cast is ambiguous (in terms of intent)

output_idx = len(node_backend.np_inputs_spec())
with InsertionPoint(transform.ApplyPatternsOp(sched_state.handle).patterns):
memref.ApplyFoldMemrefAliasOpsPatternsOp()
if "sdist" in self._mlir_program.mlir_extensions:
assert sdist_transform is not None
sdist_transform.SDistLocalBufferAtOp(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it do under the hood ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It uses the same SDist primitive as pack_at, which automatically creates a local buffer at a particular loop level. The transformation automatically infers which buffer is a read and/or write.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok could you test the pipeline at different levels, aka transform, then transformed ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not, but SDist is not tested in CI. Maybe we should ask @guillon how to test XTC also with SDist

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes if we use it for all targets we should test it

target=sched_state.handle,
input_idx=output_idx,
)


class MlirProgramApplyTransformPass:
def __init__(
Expand Down
7 changes: 7 additions & 0 deletions src/xtc/backends/mlir/MlirNodeScheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class MlirNodeSchedule:
parallelization: list[str]
unrolling: dict[str, int]
packed_buffers: dict[str, list[int]]
write_buffers: list[str]
memory_mesh: dict[str, int]
processor_mesh: dict[str, int]
distribution: dict[str, str]
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
self.parallelization: list[str] = []
self.unrolling: dict[str, int] = {}
self.packed_buffers: dict[str, list[int]] = {}
self.write_buffers: list[str] = []
self.memory_mesh: dict[str, int] = {}
self.processor_mesh: dict[str, int] = {}
self.distribution: dict[str, str] = {}
Expand All @@ -112,6 +114,7 @@ def mlir_node_schedule(self) -> MlirNodeSchedule:
unrolling=self.unrolling,
memory_mesh=self.memory_mesh,
packed_buffers=self.packed_buffers,
write_buffers=self.write_buffers,
processor_mesh=self.processor_mesh,
distribution=self.distribution,
distributed_buffers=self.distributed_buffers,
Expand Down Expand Up @@ -178,6 +181,10 @@ def pack_at(
else:
self.packed_buffers[axis_key].append(input_idx)

def buffer_at(
    self, axis: str, mtype: str | None = None, root: str = DEFAULT_ROOT
) -> None:
    """Record a write-buffer request for the loop *axis* under *root*.

    The recorded key (``"<root><ROOT_SEP><axis>"``) is later consumed by the
    compiler pass that materializes a local buffer at that loop level.

    The ``mtype`` argument is accepted for signature compatibility with
    ``MlirScheduler.buffer_at``; memory-type validation and extension
    requirements are handled by the caller, so it is intentionally unused
    here.
    """
    axis_key = f"{root}{ROOT_SEP}{axis}"
    self.write_buffers.append(axis_key)

def define_memory_mesh(self, axes: dict[str, int]):
assert len(self.memory_mesh) == 0, "Memory mesh has already been defined"
self.memory_mesh = axes
Expand Down
13 changes: 9 additions & 4 deletions src/xtc/backends/mlir/MlirScheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,14 @@ def interchange(self, permutation: list[str], root: str = DEFAULT_ROOT) -> None:
def buffer_at(
self, axis: str, mtype: str | None = None, root: str = DEFAULT_ROOT
) -> None:
assert mtype is None or mtype == "global"
# TODO: not implemented for now
pass
# The current implementation exclusively relies on SDist, but the upstream
# transform dialect may be used for some cases.
assert mtype is None or mtype == "global" or mtype == "local"
if mtype is None or mtype == "global":
self._require_extension("sdist", weak=True)
else:
self._require_extension("sdist")
self._current_scheduler.buffer_at(axis, mtype, root=root)

@override
def pack_at(
Expand All @@ -144,7 +149,7 @@ def pack_at(
pad: bool = False,
root: str = DEFAULT_ROOT,
) -> None:
# The current implemntation exclusively rely on SDist, but upstream
# The current implementation exclusively relies on SDist, but the upstream
# transform dialect may be used for some cases.
assert mtype is None or mtype == "global" or mtype == "local"
if pad:
Expand Down
10 changes: 5 additions & 5 deletions tests/filecheck/search/test_conv_oo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
utils.print_exhaustive_samples(backend, strategy, 100)

# CHECK: schedule O0: [1, 1, 1, 1, 1, 1, 1]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: schedule O1: [1, 1, 1, 1, 1, 1, 1]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 1}, 'f': {'./f1': 1}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 1, './w1': 1, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: schedule O2: [1, 1, 2, 16, 1, 1, 1]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 1, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: schedule O3: [1, 1, 2, 16, 1, 1, 3]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 3}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 3, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 1}, 'h': {'./h1': 1}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 3}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 1, './c1': 3, './s1': 1, './r1': 1, './b1': 1}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: sample 0: [1, 1, 1, 1, 1, 1, 1]
# CHECK-NEXT: sample 1: [1, 1, 1, 1, 1, 1, 3]
# CHECK-NEXT: sample 2: [1, 1, 1, 1, 1, 7, 1]
Expand Down Expand Up @@ -99,4 +99,4 @@
# CHECK-NEXT: sample 76: [2, 2, 2, 8, 1, 1, 1]
# CHECK-NEXT: sample 77: [2, 2, 2, 16, 1, 1, 1]
# CHECK-NEXT: stats {'filtered': 78, 'all': 384}
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 2}, 'h': {'./h1': 2}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 2, './c1': 1, './s1': 1, './r1': 1, './b1': 2}, packed_buffers={}, memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
# CHECK-NEXT: [MlirNodeSchedule(node_name='%2_0', node_ident='__xtc_id_%2_0_', dims=['b', 'h', 'w', 'f'], loop_stamps=[], splits={}, tiles={'b': {}, 'h': {}, 'w': {}, 'f': {}}, permutation={'.': ['./b', './h', './w', './f']}, vectorization=[], parallelization=[], unrolling={}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={}), MlirNodeSchedule(node_name='%2', node_ident='__xtc_id_%2_', dims=['b', 'h', 'w', 'f', 'r', 's', 'c'], loop_stamps=[], splits={}, tiles={'b': {'./b1': 2}, 'h': {'./h1': 2}, 'w': {'./w1': 2}, 'f': {'./f1': 16}, 'r': {'./r1': 1}, 's': {'./s1': 1}, 'c': {'./c1': 1}}, permutation={'.': ['./b', './r', './s', './c', './h', './w', './f', './b1', './r1', './s1', './c1', './h1', './w1', './f1']}, vectorization=['./f1'], parallelization=[], unrolling={'./f1': 16, './w1': 2, './h1': 2, './c1': 1, './s1': 1, './r1': 1, './b1': 2}, packed_buffers={}, write_buffers=[], memory_mesh={}, processor_mesh={}, distribution={}, distributed_buffers={})]
Loading