Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdist_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
--index-url https://gitlab.inria.fr/api/v4/groups/corse/-/packages/pypi/simple
mlir-sdist==21.1.2.2026012001
mlir-sdist==21.1.2.2026030601
mlir==21.1.2.2025091603
xtc-mlir==21.1.2.2
35 changes: 15 additions & 20 deletions src/xtc/backends/mlir/MlirCompilerPasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,8 @@ def _generate_node_scheduling(
schedule=schedule,
sched_state=sched_state,
)
if loop_name in schedule.distribution:
self._distribute_loop(loop_name, schedule, sched_state)

# For now on, the focus is on the outermost loop
if sched_state.all_loops:
Expand All @@ -333,9 +335,6 @@ def _generate_node_scheduling(
if schedule.unrolling:
self._unroll(permutation, schedule, sched_state)

# Distribute loops
self._distribute_loops(permutation, schedule, sched_state)

return sched_state

def _generate_tiling_insns(
Expand Down Expand Up @@ -488,28 +487,24 @@ def _unroll(
sched_state.all_loops[dim_name], schedule.unrolling[dim_name]
)

def _distribute_loops(
def _distribute_loop(
self,
permutation: list[str],
loop_name: str,
schedule: MlirNodeSchedule,
sched_state: SchedulingState,
):
if len(schedule.distribution) == 0:
return
assert self._named_sequence is not None
assert sdist_transform is not None
for loop_name in permutation:
if loop_name in schedule.distribution:
distribute_command = sdist_transform.SDistDistributeLoopOp(
target=sched_state.all_loops[loop_name],
mesh="processor_mesh",
axis=schedule.distribution[loop_name],
)
assert len(distribute_command.results) == 1
new_loop = distribute_command.results[0]
sched_state.all_loops[loop_name] = new_loop
# Annotate the resulting loop if successfully generated
transform.AnnotateOp(new_loop, loop_name)
distribute_command = sdist_transform.SDistDistributeLoopOp(
target=sched_state.all_loops[loop_name],
mesh="processor_mesh",
axis=schedule.distribution[loop_name],
)
assert len(distribute_command.results) == 2
new_loop = distribute_command.results[0]
sched_state.all_loops[loop_name] = new_loop
sched_state.handle = distribute_command.results[1]
# Annotate the resulting loop if successfully generated
transform.AnnotateOp(new_loop, loop_name)

def _distribute_buffer(
self,
Expand Down
6 changes: 3 additions & 3 deletions tests/filecheck/backends/test_matmul_mlir_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@
# CHECK-NEXT: %5 = transform.sdist.local_buffer_at %tiled_linalg_op_2 tensor 1 : !transform.any_op -> !transform.any_op
# CHECK-NEXT: %tiled_op, %forall_op = transform.structured.tile_using_forall %tiled_linalg_op_2 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %forall_op "./i" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_op tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: %transformed, %tiledOp = transform.sdist.distribute_loop %forall_op {axis = "px", mesh = "processor_mesh"} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %transformed "./i" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiledOp tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_5 "./j" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_7 "./i1" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [0, 1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_9 "./j1" : !transform.any_op
# CHECK-NEXT: transform.loop.unroll %loops_7 {factor = 2 : i64} : !transform.any_op
# CHECK-NEXT: %6 = transform.sdist.distribute_loop %forall_op {axis = "px", mesh = "processor_mesh"} : (!transform.any_op) -> !transform.any_op
# CHECK-NEXT: transform.annotate %6 "./i" : !transform.any_op
# CHECK-NEXT: transform.yield
# CHECK-NEXT: }
# CHECK-NEXT: }
Expand Down