From 3082a73ae98cb20c6b6f901ba311abfa5f8e827f Mon Sep 17 00:00:00 2001 From: Sylvain Noiry Date: Thu, 5 Mar 2026 12:58:08 +0100 Subject: [PATCH] [Mlir] Fix loop distribution --- sdist_requirements.txt | 2 +- src/xtc/backends/mlir/MlirCompilerPasses.py | 35 ++++++++----------- .../backends/test_matmul_mlir_distributed.py | 6 ++-- 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/sdist_requirements.txt b/sdist_requirements.txt index 0cf8abdf..98692e67 100644 --- a/sdist_requirements.txt +++ b/sdist_requirements.txt @@ -1,4 +1,4 @@ --index-url https://gitlab.inria.fr/api/v4/groups/corse/-/packages/pypi/simple -mlir-sdist==21.1.2.2026012001 +mlir-sdist==21.1.2.2026030601 mlir==21.1.2.2025091603 xtc-mlir==21.1.2.2 diff --git a/src/xtc/backends/mlir/MlirCompilerPasses.py b/src/xtc/backends/mlir/MlirCompilerPasses.py index fe6120de..a8781820 100644 --- a/src/xtc/backends/mlir/MlirCompilerPasses.py +++ b/src/xtc/backends/mlir/MlirCompilerPasses.py @@ -324,6 +324,8 @@ def _generate_node_scheduling( schedule=schedule, sched_state=sched_state, ) + if loop_name in schedule.distribution: + self._distribute_loop(loop_name, schedule, sched_state) # For now on, the focus is on the outermost loop if sched_state.all_loops: @@ -333,9 +335,6 @@ def _generate_node_scheduling( if schedule.unrolling: self._unroll(permutation, schedule, sched_state) - # Distribute loops - self._distribute_loops(permutation, schedule, sched_state) - return sched_state def _generate_tiling_insns( @@ -488,28 +487,24 @@ def _unroll( sched_state.all_loops[dim_name], schedule.unrolling[dim_name] ) - def _distribute_loops( + def _distribute_loop( self, - permutation: list[str], + loop_name: str, schedule: MlirNodeSchedule, sched_state: SchedulingState, ): - if len(schedule.distribution) == 0: - return - assert self._named_sequence is not None assert sdist_transform is not None - for loop_name in permutation: - if loop_name in schedule.distribution: - distribute_command = sdist_transform.SDistDistributeLoopOp( - target=sched_state.all_loops[loop_name], - mesh="processor_mesh", - axis=schedule.distribution[loop_name], - ) - assert len(distribute_command.results) == 1 - new_loop = distribute_command.results[0] - sched_state.all_loops[loop_name] = new_loop - # Annotate the resulting loop if successfully generated - transform.AnnotateOp(new_loop, loop_name) + distribute_command = sdist_transform.SDistDistributeLoopOp( + target=sched_state.all_loops[loop_name], + mesh="processor_mesh", + axis=schedule.distribution[loop_name], + ) + assert len(distribute_command.results) == 2 + new_loop = distribute_command.results[0] + sched_state.all_loops[loop_name] = new_loop + sched_state.handle = distribute_command.results[1] + # Annotate the resulting loop if successfully generated + transform.AnnotateOp(new_loop, loop_name) def _distribute_buffer( self, diff --git a/tests/filecheck/backends/test_matmul_mlir_distributed.py b/tests/filecheck/backends/test_matmul_mlir_distributed.py index ad535fbe..c4f722cc 100644 --- a/tests/filecheck/backends/test_matmul_mlir_distributed.py +++ b/tests/filecheck/backends/test_matmul_mlir_distributed.py @@ -72,15 +72,15 @@ # CHECK-NEXT: %5 = transform.sdist.local_buffer_at %tiled_linalg_op_2 tensor 1 : !transform.any_op -> !transform.any_op # CHECK-NEXT: %tiled_op, %forall_op = transform.structured.tile_using_forall %tiled_linalg_op_2 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) # CHECK-NEXT: transform.annotate %forall_op "./i" : !transform.any_op -# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_op tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: %transformed, %tiledOp = transform.sdist.distribute_loop %forall_op {axis = "px", mesh = "processor_mesh"} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +# CHECK-NEXT: transform.annotate %transformed "./i" : !transform.any_op +# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiledOp tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) # CHECK-NEXT: transform.annotate %loops_5 "./j" : !transform.any_op # CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) # CHECK-NEXT: transform.annotate %loops_7 "./i1" : !transform.any_op # CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [0, 1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) # CHECK-NEXT: transform.annotate %loops_9 "./j1" : !transform.any_op # CHECK-NEXT: transform.loop.unroll %loops_7 {factor = 2 : i64} : !transform.any_op -# CHECK-NEXT: %6 = transform.sdist.distribute_loop %forall_op {axis = "px", mesh = "processor_mesh"} : (!transform.any_op) -> !transform.any_op -# CHECK-NEXT: transform.annotate %6 "./i" : !transform.any_op # CHECK-NEXT: transform.yield # CHECK-NEXT: } # CHECK-NEXT: }