Conversation

@danielhua23 commented Sep 17, 2025

Motivation

This GEMM reduce-scatter version partitions the K dimension of A and B, and each rank then runs its local matmul. For C, after the reduce-scatter the local_output of each rank owns an (M / world_size, N) slice.
(attached diagram: K-partitioned GEMM followed by reduce-scatter of C)
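To make the data layout concrete, here is a minimal reference sketch using plain torch.distributed (the PR itself fuses the GEMM and the communication in a Triton kernel; the helper name and the standalone reduce_scatter_tensor call below are only illustrative):

```python
import torch
import torch.distributed as dist

def gemm_reduce_scatter_reference(A, B, rank, world_size):
    # Each rank works on a K-slice: A is (M, K), B is (K, N).
    M, K = A.shape
    N = B.shape[1]
    k_per_rank = K // world_size
    local_A = A[:, rank * k_per_rank:(rank + 1) * k_per_rank]   # (M, K/world_size)
    local_B = B[rank * k_per_rank:(rank + 1) * k_per_rank, :]   # (K/world_size, N)

    # Partial product over the local K-slice; every rank holds a full (M, N) partial.
    partial_C = local_A @ local_B

    # Reduce-scatter sums the partials across ranks and leaves each rank with
    # its own (M / world_size, N) row block of the final C.
    local_output = torch.empty(M // world_size, N, device=A.device, dtype=A.dtype)
    dist.reduce_scatter_tensor(local_output, partial_C)  # requires an initialized process group
    return local_output
```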

Closes #105

Test Result

Here is my log:

root@2-6-0:/workspace/iris/examples/13_gemm_reduce_scatter# python benchmark.py -b -v
M,N,K=8192,4608,9216 ; BLK_M,N,K=256,64,64
Rank 1/4 responsible for 2048 rows
total_blocks_M=32 x total_blocks_N=72 = total_tiles=2304
total_tiles_streamk=0 + total_blocking_tiles=2304 = total_tiles=2304
total_programs_streamk=288
M,N,K=8192,4608,9216 ; BLK_M,N,K=256,64,64
Rank 3/4 responsible for 2048 rows
total_blocks_M=32 x total_blocks_N=72 = total_tiles=2304
total_tiles_streamk=0 + total_blocking_tiles=2304 = total_tiles=2304
total_programs_streamk=288
M,N,K=8192,4608,9216 ; BLK_M,N,K=256,64,64
Rank 2/4 responsible for 2048 rows
total_blocks_M=32 x total_blocks_N=72 = total_tiles=2304
total_tiles_streamk=0 + total_blocking_tiles=2304 = total_tiles=2304
total_programs_streamk=288
M,N,K=8192,4608,9216 ; BLK_M,N,K=256,64,64
Rank 0/4 responsible for 2048 rows
total_blocks_M=32 x total_blocks_N=72 = total_tiles=2304
total_tiles_streamk=0 + total_blocking_tiles=2304 = total_tiles=2304
total_programs_streamk=288
252 registers used, 0 spills
252 registers used, 0 spills
252 registers used, 0 spills
252 registers used, 0 spills
[Iris] [1/4] Validating...
[Iris] [2/4] Validating...
[Iris] [3/4] Validating...
[Iris] [0/4] Validating...
[Iris] [1/4] Final C validation passed.
[Iris] [0/4] Final C validation passed.
[Iris] [2/4] Final C validation passed.
[Iris] [3/4] Final C validation passed.
[Iris] [0/4] Validation completed
[Iris] [3/4] Validation completed
[Iris] [0/4] Benchmarking...
[Iris] [1/4] Validation completed
[Iris] [2/4] Validation completed
[Iris] [3/4] Benchmarking...
[Iris] [2/4] Benchmarking...
[Iris] [1/4] Benchmarking...
[Iris] [1/4] tile matmul + reduce_scatter (grid=2304): 4.312 ms  645.426 tflops
[Iris] [3/4] tile matmul + reduce_scatter (grid=2304): 4.311 ms  645.577 tflops
[Iris] [2/4] tile matmul + reduce_scatter (grid=2304): 4.310 ms  645.679 tflops
[Iris] [0/4] tile matmul + reduce_scatter (grid=2304): 4.312 ms  645.435 tflops
{
    "world_size": 4,
    "m": 8192,
    "n": 4608,
    "k": 9216,
    "debug": false,
    "validate": true,
    "trace_tiles": false,
    "benchmark": true,
    "datatype": "fp16",
    "output_file": "log.json",
    "BLK_M": 256,
    "BLK_N": 64,
    "BLK_K": 64,
    "gsize_m": 6,
    "two_tiles": "True",
    "num_stages": 1,
    "num_warps": 8,
    "waves_per_eu": 0,
    "mfmaInstrSize": 16,
    "kpack": 2,
    "heap_size": 8589934592,
    "gemm_sms": 288,
    "total_sms": 304,
    "num_ranks": 4,
    "M": 8192,
    "N": 4608,
    "K": 36864,
    "gemm_registers": 252,
    "gemm_spills": 0,
    "success": true,
    "triton_tflops": 645.4351499618945,
    "triton_ms": 4.3120347690582275,
    "gemm_ms": 4.007320888458737,
    "gemm_experiments": 126
}

@mawad-amd (Collaborator) left a comment

Thanks, Daniel, for your contributions. On a first pass this is looking mostly good, but could you please add a test similar to
https://github.com/ROCm/iris/blob/main/tests/examples/test_all_load_bench.py?
You will need to call the _worker function directly with the test arguments, and modify _worker to return local_output so you can validate it at the call site in the test.
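Something along these lines could work as a rough sketch (the benchmark module path, the _worker signature, and its return values below are assumptions about the proposed refactor, not the current API):

```python
import importlib.util
import pytest
import torch

# Hypothetical: load the example's benchmark by file path, since the
# "13_" directory prefix prevents a regular package import.
_spec = importlib.util.spec_from_file_location(
    "gemm_rs_bench", "examples/13_gemm_reduce_scatter/benchmark.py"
)
bench = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(bench)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("m, n, k", [(64, 64, 64), (256, 256, 256)])
def test_gemm_reduce_scatter(dtype, m, n, k):
    # Assumed: _worker is modified to return (local_output, A, B, rank, world_size).
    local_output, A, B, rank, world_size = bench._worker(dtype=dtype, m=m, n=n, k=k)

    # Reference: full matmul in fp32, then this rank's row block of C.
    rows = m // world_size
    expected = (A.float() @ B.float())[rank * rows:(rank + 1) * rows, :]
    torch.testing.assert_close(local_output.float(), expected, atol=2.0, rtol=5e-2)
```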

@mawad-amd (Collaborator)

@danielhua23 I just tested this, and some tests pass while others do not. I am not sure whether all of the configs are possible and/or whether we need to increase the tolerance for some data types; the full failure log is under Details below.
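One hedged idea for the tolerance side: make the validation bounds dtype-dependent instead of the fixed atol=2 used in the example (the values below are illustrative guesses, not tuned against this kernel):

```python
import torch

# Illustrative, untuned per-dtype tolerances for comparing local_output
# against an fp32 reference.
TOLERANCES = {
    torch.float16:  dict(atol=1e-1, rtol=1e-2),
    torch.bfloat16: dict(atol=5e-1, rtol=5e-2),
    torch.float32:  dict(atol=1e-3, rtol=1e-4),
}

def validate(local_output, expected_fp32, dtype):
    # Compare the per-rank result against the fp32 reference with dtype-aware bounds.
    torch.testing.assert_close(local_output.float(), expected_fp32, **TOLERANCES[dtype])
```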

Details

Running tests/examples/test_gemm_reduce_scatter_bench.py with 1 ranks
args=['tests/examples/test_gemm_reduce_scatter_bench.py'], test_file=tests/examples/test_gemm_reduce_scatter_bench.py, pytest_args=[]
============================= test session starts ==============================
platform linux -- Python 3.10.18, pytest-7.3.2, pluggy-1.5.0
rootdir: /work1/amd/muhaawad/git/amd/pdp/iris
plugins: rerunfailures-14.0, xdoctest-1.2.0, hypothesis-5.35.1, xdist-3.3.1, cpp-2.3.0, anyio-4.10.0, flakefinder-1.1.0, typeguard-4.3.0
collected 18 items

tests/examples/test_gemm_reduce_scatter_bench.py FFFFFFFFF....F....      [100%]

=================================== FAILURES ===================================
______________ test_gemm_reduce_scatter[32-32-16-64-64-64-dtype0] ______________

dtype = torch.float16, m = 64, n = 64, k = 64, BLK_M = 32, BLK_N = 32
BLK_K = 16

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ],
    )
    @pytest.mark.parametrize(
        "m, n, k",
        [
            (64, 64, 64),  # Very small for quick testing
            (128, 128, 128),  # Small
            (256, 256, 256),  # Medium
        ],
    )
    @pytest.mark.parametrize(
        "BLK_M, BLK_N, BLK_K",
        [
            (32, 32, 16),  # Small blocks
            (64, 64, 32),  # Medium blocks
        ],
    )
    def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
        """Worker function for PyTorch distributed execution."""
        heap_size = 1 << 30
        shmem = iris.iris(heap_size)
        rank = shmem.get_rank()
        world_size = shmem.get_num_ranks()
        cu_count = shmem.get_cu_count()
    
        # GEMM
        datatype = dtype
    
        assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
        assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
        A = shmem.randn(m, k, device="cuda", dtype=datatype)
        B = shmem.randn(n, k, device="cuda", dtype=datatype).T
        C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
        M = m
        N = n
        K = k
    
        # Splitting
        rows_per_gpu = k // world_size
        k = rows_per_gpu
        start_row = rank * rows_per_gpu
        end_row = start_row + rows_per_gpu
        local_B = B[start_row:end_row, :]
        local_A = A[:, start_row:end_row]
    
        compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
        local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
        total_blocks_M = triton.cdiv(m, BLK_M)
        total_blocks_N = triton.cdiv(n, BLK_N)
        total_tiles = total_blocks_M * total_blocks_N
    
        tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    
        locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
        P = shmem.zeros(
            (288, BLK_M * BLK_N),
            device="cuda",
            dtype=torch.float32,
        )
        bias = None
        gemm_stream = torch.cuda.Stream()
        timestamps = Timestamps(num_tiles=total_tiles)
    
        def preamble():
            shmem.barrier()
            tile_completed.zero_()
            shmem.barrier()
    
        def run_experiment():
            nonlocal local_output
            nonlocal compute_buffer
    
            shmem.barrier()
    
            torch.cuda.nvtx.range_push("GEMM + Communication")
            with torch.cuda.stream(gemm_stream):
                local_output = matmul_module.matmul_reduce_scatter.apply(
                    local_A,
                    local_B,
                    compute_buffer,
                    local_output,
                    bias,
                    P,
                    locks,
                    tile_completed,
                    rank,
                    world_size,
                    288,
                    BLK_M,
                    BLK_N,
                    BLK_K,
                    6,
                    True,
                    1,
                    8,
                    0,
                    16,
                    2,
                    shmem.get_heap_bases(),
                    cu_count,
                    False,
                    timestamps.mm_begin_timestamp,
                    timestamps.mm_end_timestamp,
                )
            torch.cuda.nvtx.range_pop()
            shmem.barrier()
    
        # Synchronize across all GPUs
        shmem.barrier()
        run_experiment()
        shmem.barrier()
        preamble()
        shmem.barrier()
    
        shmem.info("Validating...")
    
        matmul_module.matmul_reduce_scatter.set_debug(False)
        # Validate global result
        success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
>       assert success, (
            f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
        )
E       AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=64, n=64, k=64, BLK_M=32, BLK_N=32, BLK_K=16
E       assert False

tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stdout call -----------------------------
M,N,K=64,64,64 ; BLK_M,N,K=32,32,16
Rank 0/1 responsible for 64 rows
total_blocks_M=2 x total_blocks_N=2 = total_tiles=4
total_tiles_streamk=4 + total_blocking_tiles=0 = total_tiles=4
total_programs_streamk=288
32 registers used, 0 spills
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO     iris::0 Validating...
______________ test_gemm_reduce_scatter[32-32-16-64-64-64-dtype1] ______________

dtype = torch.bfloat16, m = 64, n = 64, k = 64, BLK_M = 32, BLK_N = 32
BLK_K = 16

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ],
    )
    @pytest.mark.parametrize(
        "m, n, k",
        [
            (64, 64, 64),  # Very small for quick testing
            (128, 128, 128),  # Small
            (256, 256, 256),  # Medium
        ],
    )
    @pytest.mark.parametrize(
        "BLK_M, BLK_N, BLK_K",
        [
            (32, 32, 16),  # Small blocks
            (64, 64, 32),  # Medium blocks
        ],
    )
    def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
        """Worker function for PyTorch distributed execution."""
        heap_size = 1 << 30
        shmem = iris.iris(heap_size)
        rank = shmem.get_rank()
        world_size = shmem.get_num_ranks()
        cu_count = shmem.get_cu_count()
    
        # GEMM
        datatype = dtype
    
        assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
        assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
        A = shmem.randn(m, k, device="cuda", dtype=datatype)
        B = shmem.randn(n, k, device="cuda", dtype=datatype).T
        C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
        M = m
        N = n
        K = k
    
        # Splitting
        rows_per_gpu = k // world_size
        k = rows_per_gpu
        start_row = rank * rows_per_gpu
        end_row = start_row + rows_per_gpu
        local_B = B[start_row:end_row, :]
        local_A = A[:, start_row:end_row]
    
        compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
        local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
        total_blocks_M = triton.cdiv(m, BLK_M)
        total_blocks_N = triton.cdiv(n, BLK_N)
        total_tiles = total_blocks_M * total_blocks_N
    
        tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    
        locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
        P = shmem.zeros(
            (288, BLK_M * BLK_N),
            device="cuda",
            dtype=torch.float32,
        )
        bias = None
        gemm_stream = torch.cuda.Stream()
        timestamps = Timestamps(num_tiles=total_tiles)
    
        def preamble():
            shmem.barrier()
            tile_completed.zero_()
            shmem.barrier()
    
        def run_experiment():
            nonlocal local_output
            nonlocal compute_buffer
    
            shmem.barrier()
    
            torch.cuda.nvtx.range_push("GEMM + Communication")
            with torch.cuda.stream(gemm_stream):
                local_output = matmul_module.matmul_reduce_scatter.apply(
                    local_A,
                    local_B,
                    compute_buffer,
                    local_output,
                    bias,
                    P,
                    locks,
                    tile_completed,
                    rank,
                    world_size,
                    288,
                    BLK_M,
                    BLK_N,
                    BLK_K,
                    6,
                    True,
                    1,
                    8,
                    0,
                    16,
                    2,
                    shmem.get_heap_bases(),
                    cu_count,
                    False,
                    timestamps.mm_begin_timestamp,
                    timestamps.mm_end_timestamp,
                )
            torch.cuda.nvtx.range_pop()
            shmem.barrier()
    
        # Synchronize across all GPUs
        shmem.barrier()
        run_experiment()
        shmem.barrier()
        preamble()
        shmem.barrier()
    
        shmem.info("Validating...")
    
        matmul_module.matmul_reduce_scatter.set_debug(False)
        # Validate global result
        success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
>       assert success, (
            f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
        )
E       AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=64, n=64, k=64, BLK_M=32, BLK_N=32, BLK_K=16
E       assert False

tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO     iris::0 Validating...
______________ test_gemm_reduce_scatter[32-32-16-64-64-64-dtype2] ______________

dtype = torch.float32, m = 64, n = 64, k = 64, BLK_M = 32, BLK_N = 32
BLK_K = 16

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ],
    )
    @pytest.mark.parametrize(
        "m, n, k",
        [
            (64, 64, 64),  # Very small for quick testing
            (128, 128, 128),  # Small
            (256, 256, 256),  # Medium
        ],
    )
    @pytest.mark.parametrize(
        "BLK_M, BLK_N, BLK_K",
        [
            (32, 32, 16),  # Small blocks
            (64, 64, 32),  # Medium blocks
        ],
    )
    def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
        """Worker function for PyTorch distributed execution."""
        heap_size = 1 << 30
        shmem = iris.iris(heap_size)
        rank = shmem.get_rank()
        world_size = shmem.get_num_ranks()
        cu_count = shmem.get_cu_count()
    
        # GEMM
        datatype = dtype
    
        assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
        assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
        A = shmem.randn(m, k, device="cuda", dtype=datatype)
        B = shmem.randn(n, k, device="cuda", dtype=datatype).T
        C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
        M = m
        N = n
        K = k
    
        # Splitting
        rows_per_gpu = k // world_size
        k = rows_per_gpu
        start_row = rank * rows_per_gpu
        end_row = start_row + rows_per_gpu
        local_B = B[start_row:end_row, :]
        local_A = A[:, start_row:end_row]
    
        compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
        local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
        total_blocks_M = triton.cdiv(m, BLK_M)
        total_blocks_N = triton.cdiv(n, BLK_N)
        total_tiles = total_blocks_M * total_blocks_N
    
        tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    
        locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
        P = shmem.zeros(
            (288, BLK_M * BLK_N),
            device="cuda",
            dtype=torch.float32,
        )
        bias = None
        gemm_stream = torch.cuda.Stream()
        timestamps = Timestamps(num_tiles=total_tiles)
    
        def preamble():
            shmem.barrier()
            tile_completed.zero_()
            shmem.barrier()
    
        def run_experiment():
            nonlocal local_output
            nonlocal compute_buffer
    
            shmem.barrier()
    
            torch.cuda.nvtx.range_push("GEMM + Communication")
            with torch.cuda.stream(gemm_stream):
                local_output = matmul_module.matmul_reduce_scatter.apply(
                    local_A,
                    local_B,
                    compute_buffer,
                    local_output,
                    bias,
                    P,
                    locks,
                    tile_completed,
                    rank,
                    world_size,
                    288,
                    BLK_M,
                    BLK_N,
                    BLK_K,
                    6,
                    True,
                    1,
                    8,
                    0,
                    16,
                    2,
                    shmem.get_heap_bases(),
                    cu_count,
                    False,
                    timestamps.mm_begin_timestamp,
                    timestamps.mm_end_timestamp,
                )
            torch.cuda.nvtx.range_pop()
            shmem.barrier()
    
        # Synchronize across all GPUs
        shmem.barrier()
>       run_experiment()

tests/examples/test_gemm_reduce_scatter_bench.py:154: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/examples/test_gemm_reduce_scatter_bench.py:121: in run_experiment
    local_output = matmul_module.matmul_reduce_scatter.apply(
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py:574: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
examples/13_gemm_reduce_scatter/matmul_wrapper.py:179: in forward
    result = matmul_reduce_scatter._call(
examples/13_gemm_reduce_scatter/matmul_wrapper.py:99: in _call
    kk = gemm_kernel[(grids,)](
/workspace/triton/python/triton/runtime/jit.py:393: in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
/workspace/triton/python/triton/runtime/jit.py:599: in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
/workspace/triton/python/triton/runtime/jit.py:782: in _do_compile
    kernel = self.compile(src, target=target, options=options.__dict__)
/workspace/triton/python/triton/compiler/compiler.py:322: in compile
    next_module = compile_ir(module, metadata)
/workspace/triton/python/triton/backends/amd/compiler.py:449: in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

src = <triton._C.libtriton.ir.module object at 0x7faccfc95580>
metadata = {'allow_flush_denorm': False, 'allowed_dot_input_precisions': ('ieee', 'tf32'), 'arch': 'gfx942', 'backend_name': 'hip', ...}
options = HIPOptions(num_warps=8, waves_per_eu=0, num_stages=1, num_ctas=1, extern_libs=(('ocml', '/workspace/triton/python/trit...flush_denorm=False, max_num_imprecise_acc_default=0, backend_name='hip', instrumentation_mode='', schedule_hint='none')

    @staticmethod
    def make_llir(src, metadata, options):
        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        # custom_lds_size is an experimental parameter that defines amount of LDS available
        # for one thread block. Measured in bytes.
        #
        # If custom_lds_size = 0, pass will consider all LDS is available for one threads block,
        # LDS size is determined by provided arch name.
        custom_lds_size = 0
        amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
        passes.convert.add_triton_scf_to_cf(pm)
        passes.convert.add_index_to_llvmir(pm)
    
        amd.passes.ttgpuir.add_allocate_shared_memory(pm)
        # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
        if HIPBackend.instrumentation:
            HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
        ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
        ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
        ##    of the value of kernel arg `allow_flush_denorm`.
        ## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
        ##    depends on the value of kernel arg `allow_flush_denorm`.
        ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
        ##    For now it is used as a controller for developers only.
        __HIP_FTZ = True
        amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
    
        passes.convert.add_cf_to_llvmir(pm)
        passes.convert.add_arith_to_llvmir(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
    
        if options.schedule_hint.lower() != "none":
            amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
    
        # This can not be moved below the di_scope pass
        if HIPBackend.instrumentation:
            HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
    
        if not knobs.compilation.disable_line_info:
            passes.llvmir.add_di_scope(pm)
    
        amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
>       pm.run(mod)
E       RuntimeError: PassManager::run failed

/workspace/triton/python/triton/backends/amd/compiler.py:324: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /root/.triton/llvm/llvm-57088512-ubuntu-x64/include/llvm/ADT/SmallVector.h:292: T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type) [with T = mlir::Value; <template-parameter-1-2> = void; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::reference = mlir::Value&; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type = long unsigned int]: Assertion `idx < size()' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 2], instrShape = [16, 16], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @persistent_gemm_reduce_scatter(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg6: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg16: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg17: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<32> : tensor<1x32xi32, #blocked>
    %cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked>
    %c32_i32 = arith.constant 32 : i32
    %c6_i32 = arith.constant 6 : i32
    %true = arith.constant true
    %c36_i32 = arith.constant 36 : i32
    %c8_i32 = arith.constant 8 : i32
    %c15_i32 = arith.constant 15 : i32
    %c16_i32 = arith.constant 16 : i32
    %c31_i32 = arith.constant 31 : i32
    %c1_i32 = arith.constant 1 : i32
    %c288_i32 = arith.constant 288 : i32
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
    %0 = tt.get_program_id x : i32
    %1 = arith.remsi %0, %c8_i32 : i32
    %2 = arith.muli %1, %c36_i32 : i32
    %3 = arith.divsi %0, %c8_i32 : i32
    %4 = arith.addi %2, %3 : i32
    %5 = arith.addi %arg7, %c31_i32 : i32
    %6 = arith.divsi %5, %c32_i32 : i32
    %7 = arith.addi %arg8, %c31_i32 : i32
    %8 = arith.divsi %7, %c32_i32 : i32
    %9 = arith.muli %6, %8 : i32
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    %10 = arith.muli %8, %c6_i32 : i32
    %11 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %12 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %13 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
    %14 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %15 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %16 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %17 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
    %18 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %19 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %20 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
    %21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2>
    %22 = tt.broadcast %21 : tensor<16x1xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
    %23 = arith.addi %arg9, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #mma>
    %26 = tt.splat %arg8 : i32 -> tensor<1x32xi32, #mma>
    %27 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %28 = tt.expand_dims %27 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
    %29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
    %31 = arith.cmpi slt, %28, %cst_0 : tensor<32x1xi32, #blocked>
    %32 = arith.cmpi slt, %30, %cst : tensor<1x32xi32, #blocked>
    %33 = tt.broadcast %31 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
    %34 = tt.broadcast %32 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked>
    %35 = arith.andi %33, %34 : tensor<32x32xi1, #blocked>
    scf.for %arg18 = %4 to %9 step %c288_i32  : i32 {
      %36 = arith.divsi %arg18, %10 : i32
      %37 = arith.muli %36, %c6_i32 : i32
      %38 = arith.subi %6, %37 : i32
      %39 = arith.minsi %38, %c6_i32 : i32
      %40 = arith.remsi %arg18, %10 : i32
      %41 = arith.remsi %40, %39 : i32
      %42 = arith.addi %37, %41 : i32
      %43 = arith.divsi %40, %39 : i32
      %44 = arith.muli %42, %c32_i32 : i32
      %45 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
      %46 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %47 = arith.addi %45, %11 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
      %48 = arith.addi %46, %12 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %49 = arith.remsi %47, %15 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
      %50 = arith.remsi %48, %16 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %51 = arith.muli %43, %c32_i32 : i32
      %52 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
      %53 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %54 = arith.addi %52, %13 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
      %55 = arith.addi %53, %14 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %56 = arith.remsi %54, %17 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
      %57 = arith.remsi %55, %18 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %58 = tt.expand_dims %50 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
      %59 = tt.expand_dims %49 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
      %60 = tt.splat %arg10 : i32 -> tensor<32x1xi32, #blocked1>
      %61 = arith.muli %59, %60 : tensor<32x1xi32, #blocked1>
      %62 = tt.broadcast %61 : tensor<32x1xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
      %63 = tt.expand_dims %19 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1>
      %64 = tt.broadcast %63 : tensor<1x16xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
      %65 = arith.addi %64, %62 : tensor<32x16xi32, #blocked1>
      %66 = tt.expand_dims %57 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
      %67 = tt.expand_dims %56 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2>
      %68 = tt.splat %arg11 : i32 -> tensor<1x32xi32, #blocked2>
      %69 = arith.muli %67, %68 : tensor<1x32xi32, #blocked2>
      %70 = tt.broadcast %69 : tensor<1x32xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
      %71 = arith.addi %70, %22 : tensor<16x32xi32, #blocked2>
      %72:3 = scf.for %arg19 = %c0_i32 to %24 step %c1_i32 iter_args(%arg20 = %arg0, %arg21 = %cst_2, %arg22 = %arg1) -> (!tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>)  : i32 {
        %91 = tt.splat %arg20 : !tt.ptr<f32> -> tensor<32x16x!tt.ptr<f32>, #blocked1>
        %92 = tt.addptr %91, %65 : tensor<32x16x!tt.ptr<f32>, #blocked1>, tensor<32x16xi32, #blocked1>
        %93 = tt.load %92 : tensor<32x16x!tt.ptr<f32>, #blocked1>
        %94 = tt.splat %arg22 : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>, #blocked2>
        %95 = tt.addptr %94, %71 : tensor<16x32x!tt.ptr<f32>, #blocked2>, tensor<16x32xi32, #blocked2>
        %96 = tt.load %95 : tensor<16x32x!tt.ptr<f32>, #blocked2>
        %97 = ttg.local_alloc %93 : (tensor<32x16xf32, #blocked1>) -> !ttg.memdesc<32x16xf32, #shared, #smem>
        %98 = ttg.local_load %97 : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
        %99 = ttg.local_alloc %96 : (tensor<16x32xf32, #blocked2>) -> !ttg.memdesc<16x32xf32, #shared1, #smem>
        %100 = ttg.local_load %99 : !ttg.memdesc<16x32xf32, #shared1, #smem> -> tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
        %101 = tt.dot %98, %100, %arg21, inputPrecision = tf32 : tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
        %102 = tt.addptr %arg20, %c16_i32 : !tt.ptr<f32>, i32
        %103 = tt.addptr %arg22, %c16_i32 : !tt.ptr<f32>, i32
        scf.yield %102, %101, %103 : !tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>
      } {tt.divisibility_arg1 = dense<[1, 16]> : tensor<2xi32>, tt.divisibility_arg2 = dense<[1, 16]> : tensor<2xi32>}
      %73 = arith.cmpi slt, %58, %25 : tensor<32x1xi32, #mma>
      %74 = arith.cmpi slt, %66, %26 : tensor<1x32xi32, #mma>
      %75 = tt.broadcast %73 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma>
      %76 = tt.broadcast %74 : tensor<1x32xi1, #mma> -> tensor<32x32xi1, #mma>
      %77 = arith.andi %75, %76 : tensor<32x32xi1, #mma>
      %78 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #mma>
      %79 = arith.muli %58, %78 : tensor<32x1xi32, #mma>
      %80 = tt.broadcast %79 : tensor<32x1xi32, #mma> -> tensor<32x32xi32, #mma>
      %81 = tt.broadcast %66 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma>
      %82 = arith.addi %81, %80 : tensor<32x32xi32, #mma>
      %83 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #mma>
      %84 = tt.addptr %83, %82 : tensor<32x32x!tt.ptr<f32>, #mma>, tensor<32x32xi32, #mma>
      tt.store %84, %72#1, %77 : tensor<32x32x!tt.ptr<f32>, #mma>
      %85 = tt.addptr %arg6, %arg18 : !tt.ptr<i32>, i32
      %86 = tt.ptr_to_int %85 : !tt.ptr<i32> -> i64
      scf.while (%arg19 = %c0_i32) : (i32) -> () {
        %91 = arith.cmpi slt, %arg19, %c0_i32 : i32
        scf.condition(%91)
      } do {
        %91 = tt.load %arg15 : !tt.ptr<i64>
        %92 = arith.subi %86, %91 : i64
        %93 = tt.int_to_ptr %91 : i64 -> !tt.ptr<i8>
        %94 = tt.addptr %93, %92 : !tt.ptr<i8>, i64
        %95 = tt.bitcast %94 : !tt.ptr<i8> -> !tt.ptr<i32>
        %96 = tt.atomic_cas acquire, sys, %95, %c0_i32, %c0_i32 : (!tt.ptr<i32>, i32, i32) -> i32
        scf.yield %96 : i32
      }
      %87 = arith.cmpi slt, %44, %arg7 : i32
      %88 = arith.addi %44, %c32_i32 : i32
      %89 = arith.cmpi sgt, %88, %c0_i32 : i32
      %90 = arith.andi %87, %89 : i1
      scf.if %90 {
        %91 = arith.subi %c0_i32, %44 : i32
        %92 = arith.maxsi %91, %c0_i32 : i32
        %93 = arith.subi %arg7, %44 : i32
        %94 = arith.minsi %93, %c32_i32 : i32
        %95 = arith.maxsi %44, %c0_i32 : i32
        %96 = arith.muli %44, %arg12 : i32
        %97 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked>
        %98 = arith.muli %28, %97 : tensor<32x1xi32, #blocked>
        %99 = tt.broadcast %98 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
        %100 = tt.broadcast %30 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked>
        %101 = arith.addi %96, %51 : i32
        %102 = arith.addi %99, %100 : tensor<32x32xi32, #blocked>
        %103 = tt.addptr %arg2, %101 : !tt.ptr<f32>, i32
        %104 = tt.load %arg15 : !tt.ptr<i64>
        %105 = tt.splat %103 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
        %106 = tt.addptr %105, %102 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
        %107 = tt.ptr_to_int %106 : tensor<32x32x!tt.ptr<f32>, #blocked> -> tensor<32x32xi64, #blocked>
        %108 = tt.splat %104 : i64 -> tensor<32x32xi64, #blocked>
        %109 = arith.subi %107, %108 : tensor<32x32xi64, #blocked>
        %110 = tt.int_to_ptr %104 : i64 -> !tt.ptr<i8>
        %111 = tt.splat %110 : !tt.ptr<i8> -> tensor<32x32x!tt.ptr<i8>, #blocked>
        %112 = tt.addptr %111, %109 : tensor<32x32x!tt.ptr<i8>, #blocked>, tensor<32x32xi64, #blocked>
        %113 = tt.bitcast %112 : tensor<32x32x!tt.ptr<i8>, #blocked> -> tensor<32x32x!tt.ptr<f32>, #blocked>
        %114 = tt.load %113, %35 : tensor<32x32x!tt.ptr<f32>, #blocked>
        %115 = arith.addf %114, %cst_1 : tensor<32x32xf32, #blocked>
        %116 = arith.muli %95, %arg13 : i32
        %117 = tt.splat %arg13 : i32 -> tensor<32x1xi32, #blocked>
        %118 = arith.muli %28, %117 : tensor<32x1xi32, #blocked>
        %119 = tt.broadcast %118 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
        %120 = arith.addi %116, %51 : i32
        %121 = arith.addi %119, %100 : tensor<32x32xi32, #blocked>
        %122 = tt.addptr %arg3, %120 : !tt.ptr<f32>, i32
        %123 = tt.splat %92 : i32 -> tensor<32x1xi32, #blocked>
        %124 = arith.cmpi sge, %28, %123 : tensor<32x1xi32, #blocked>
        %125 = tt.splat %94 : i32 -> tensor<32x1xi32, #blocked>
        %126 = arith.cmpi slt, %28, %125 : tensor<32x1xi32, #blocked>
        %127 = arith.andi %124, %126 : tensor<32x1xi1, #blocked>
        %128 = tt.broadcast %127 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
        %129 = arith.andi %128, %34 : tensor<32x32xi1, #blocked>
        %130 = arith.andi %129, %35 : tensor<32x32xi1, #blocked>
        %131 = tt.splat %122 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
        %132 = tt.addptr %131, %121 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
        tt.store %132, %115, %130 cacheModifier = wt : tensor<32x32x!tt.ptr<f32>, #blocked>
      }
    }
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(optimize-amd-lds-usage{lds-limit=0 target-arch=gfx942}, triton-scf-to-cf, convert-index-to-llvm{index-bitwidth=0}, allocate-amdgpu-shared-memory, convert-triton-amdgpu-to-llvm{arch=gfx942 ftz=true}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-cf-to-llvm{index-bitwidth=0}, convert-arith-to-llvm{index-bitwidth=0}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info, convert-builtin-func-to-llvm{ftz=true})",
      disable_threading: false,
      verify_each: true
    }
  }
#-}
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: error: Failures have been detected while processing an MLIR pass pipeline
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: note: Pipeline failed while executing [`ConvertTritonAMDGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
____________ test_gemm_reduce_scatter[32-32-16-128-128-128-dtype0] _____________

dtype = torch.float16, m = 128, n = 128, k = 128, BLK_M = 32, BLK_N = 32
BLK_K = 16

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ],
    )
    @pytest.mark.parametrize(
        "m, n, k",
        [
            (64, 64, 64),  # Very small for quick testing
            (128, 128, 128),  # Small
            (256, 256, 256),  # Medium
        ],
    )
    @pytest.mark.parametrize(
        "BLK_M, BLK_N, BLK_K",
        [
            (32, 32, 16),  # Small blocks
            (64, 64, 32),  # Medium blocks
        ],
    )
    def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
        """Worker function for PyTorch distributed execution."""
        heap_size = 1 << 30
        shmem = iris.iris(heap_size)
        rank = shmem.get_rank()
        world_size = shmem.get_num_ranks()
        cu_count = shmem.get_cu_count()
    
        # GEMM
        datatype = dtype
    
        assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
        assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
        A = shmem.randn(m, k, device="cuda", dtype=datatype)
        B = shmem.randn(n, k, device="cuda", dtype=datatype).T
        C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
        M = m
        N = n
        K = k
    
        # Splitting
        rows_per_gpu = k // world_size
        k = rows_per_gpu
        start_row = rank * rows_per_gpu
        end_row = start_row + rows_per_gpu
        local_B = B[start_row:end_row, :]
        local_A = A[:, start_row:end_row]
    
        compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
        local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
        total_blocks_M = triton.cdiv(m, BLK_M)
        total_blocks_N = triton.cdiv(n, BLK_N)
        total_tiles = total_blocks_M * total_blocks_N
    
        tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    
        locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
        P = shmem.zeros(
            (288, BLK_M * BLK_N),
            device="cuda",
            dtype=torch.float32,
        )
        bias = None
        gemm_stream = torch.cuda.Stream()
        timestamps = Timestamps(num_tiles=total_tiles)
    
        def preamble():
            shmem.barrier()
            tile_completed.zero_()
            shmem.barrier()
    
        def run_experiment():
            nonlocal local_output
            nonlocal compute_buffer
    
            shmem.barrier()
    
            torch.cuda.nvtx.range_push("GEMM + Communication")
            with torch.cuda.stream(gemm_stream):
                local_output = matmul_module.matmul_reduce_scatter.apply(
                    local_A,
                    local_B,
                    compute_buffer,
                    local_output,
                    bias,
                    P,
                    locks,
                    tile_completed,
                    rank,
                    world_size,
                    288,
                    BLK_M,
                    BLK_N,
                    BLK_K,
                    6,
                    True,
                    1,
                    8,
                    0,
                    16,
                    2,
                    shmem.get_heap_bases(),
                    cu_count,
                    False,
                    timestamps.mm_begin_timestamp,
                    timestamps.mm_end_timestamp,
                )
            torch.cuda.nvtx.range_pop()
            shmem.barrier()
    
        # Synchronize across all GPUs
        shmem.barrier()
        run_experiment()
        shmem.barrier()
        preamble()
        shmem.barrier()
    
        shmem.info("Validating...")
    
        matmul_module.matmul_reduce_scatter.set_debug(False)
        # Validate global result
        success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
>       assert success, (
            f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
        )
E       AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=128, n=128, k=128, BLK_M=32, BLK_N=32, BLK_K=16
E       assert False

tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO     iris::0 Validating...
____________ test_gemm_reduce_scatter[32-32-16-128-128-128-dtype1] _____________

dtype = torch.bfloat16, m = 128, n = 128, k = 128, BLK_M = 32, BLK_N = 32
BLK_K = 16

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ],
    )
    @pytest.mark.parametrize(
        "m, n, k",
        [
            (64, 64, 64),  # Very small for quick testing
            (128, 128, 128),  # Small
            (256, 256, 256),  # Medium
        ],
    )
    @pytest.mark.parametrize(
        "BLK_M, BLK_N, BLK_K",
        [
            (32, 32, 16),  # Small blocks
            (64, 64, 32),  # Medium blocks
        ],
    )
    def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
        """Worker function for PyTorch distributed execution."""
        heap_size = 1 << 30
        shmem = iris.iris(heap_size)
        rank = shmem.get_rank()
        world_size = shmem.get_num_ranks()
        cu_count = shmem.get_cu_count()
    
        # GEMM
        datatype = dtype
    
        assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
        assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
        A = shmem.randn(m, k, device="cuda", dtype=datatype)
        B = shmem.randn(n, k, device="cuda", dtype=datatype).T
        C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
        M = m
        N = n
        K = k
    
        # Splitting
        rows_per_gpu = k // world_size
        k = rows_per_gpu
        start_row = rank * rows_per_gpu
        end_row = start_row + rows_per_gpu
        local_B = B[start_row:end_row, :]
        local_A = A[:, start_row:end_row]
    
        compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
        local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
        total_blocks_M = triton.cdiv(m, BLK_M)
        total_blocks_N = triton.cdiv(n, BLK_N)
        total_tiles = total_blocks_M * total_blocks_N
    
        tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    
        locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
        P = shmem.zeros(
            (288, BLK_M * BLK_N),
            device="cuda",
            dtype=torch.float32,
        )
        bias = None
        gemm_stream = torch.cuda.Stream()
        timestamps = Timestamps(num_tiles=total_tiles)
    
        def preamble():
            shmem.barrier()
            tile_completed.zero_()
            shmem.barrier()
    
        def run_experiment():
            nonlocal local_output
            nonlocal compute_buffer
    
            shmem.barrier()
    
            torch.cuda.nvtx.range_push("GEMM + Communication")
            with torch.cuda.stream(gemm_stream):
                local_output = matmul_module.matmul_reduce_scatter.apply(
                    local_A,
                    local_B,
                    compute_buffer,
                    local_output,
                    bias,
                    P,
                    locks,
                    tile_completed,
                    rank,
                    world_size,
                    288,
                    BLK_M,
                    BLK_N,
                    BLK_K,
                    6,
                    True,
                    1,
                    8,
                    0,
                    16,
                    2,
                    shmem.get_heap_bases(),
                    cu_count,
                    False,
                    timestamps.mm_begin_timestamp,
                    timestamps.mm_end_timestamp,
                )
            torch.cuda.nvtx.range_pop()
            shmem.barrier()
    
        # Synchronize across all GPUs
        shmem.barrier()
        run_experiment()
        shmem.barrier()
        preamble()
        shmem.barrier()
    
        shmem.info("Validating...")
    
        matmul_module.matmul_reduce_scatter.set_debug(False)
        # Validate global result
        success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
>       assert success, (
            f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
        )
E       AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=128, n=128, k=128, BLK_M=32, BLK_N=32, BLK_K=16
E       assert False

tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO     iris::0 Validating...
____________ test_gemm_reduce_scatter[32-32-16-128-128-128-dtype2] _____________

dtype = torch.float32, m = 128, n = 128, k = 128, BLK_M = 32, BLK_N = 32
BLK_K = 16

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ],
    )
    @pytest.mark.parametrize(
        "m, n, k",
        [
            (64, 64, 64),  # Very small for quick testing
            (128, 128, 128),  # Small
            (256, 256, 256),  # Medium
        ],
    )
    @pytest.mark.parametrize(
        "BLK_M, BLK_N, BLK_K",
        [
            (32, 32, 16),  # Small blocks
            (64, 64, 32),  # Medium blocks
        ],
    )
    def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
        """Worker function for PyTorch distributed execution."""
        heap_size = 1 << 30
        shmem = iris.iris(heap_size)
        rank = shmem.get_rank()
        world_size = shmem.get_num_ranks()
        cu_count = shmem.get_cu_count()
    
        # GEMM
        datatype = dtype
    
        assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
        assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
        A = shmem.randn(m, k, device="cuda", dtype=datatype)
        B = shmem.randn(n, k, device="cuda", dtype=datatype).T
        C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
        M = m
        N = n
        K = k
    
        # Splitting
        rows_per_gpu = k // world_size
        k = rows_per_gpu
        start_row = rank * rows_per_gpu
        end_row = start_row + rows_per_gpu
        local_B = B[start_row:end_row, :]
        local_A = A[:, start_row:end_row]
    
        compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
        local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
        total_blocks_M = triton.cdiv(m, BLK_M)
        total_blocks_N = triton.cdiv(n, BLK_N)
        total_tiles = total_blocks_M * total_blocks_N
    
        tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    
        locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
        P = shmem.zeros(
            (288, BLK_M * BLK_N),
            device="cuda",
            dtype=torch.float32,
        )
        bias = None
        gemm_stream = torch.cuda.Stream()
        timestamps = Timestamps(num_tiles=total_tiles)
    
        def preamble():
            shmem.barrier()
            tile_completed.zero_()
            shmem.barrier()
    
        def run_experiment():
            nonlocal local_output
            nonlocal compute_buffer
    
            shmem.barrier()
    
            torch.cuda.nvtx.range_push("GEMM + Communication")
            with torch.cuda.stream(gemm_stream):
                local_output = matmul_module.matmul_reduce_scatter.apply(
                    local_A,
                    local_B,
                    compute_buffer,
                    local_output,
                    bias,
                    P,
                    locks,
                    tile_completed,
                    rank,
                    world_size,
                    288,
                    BLK_M,
                    BLK_N,
                    BLK_K,
                    6,
                    True,
                    1,
                    8,
                    0,
                    16,
                    2,
                    shmem.get_heap_bases(),
                    cu_count,
                    False,
                    timestamps.mm_begin_timestamp,
                    timestamps.mm_end_timestamp,
                )
            torch.cuda.nvtx.range_pop()
            shmem.barrier()
    
        # Synchronize across all GPUs
        shmem.barrier()
>       run_experiment()

tests/examples/test_gemm_reduce_scatter_bench.py:154: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/examples/test_gemm_reduce_scatter_bench.py:121: in run_experiment
    local_output = matmul_module.matmul_reduce_scatter.apply(
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py:574: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
examples/13_gemm_reduce_scatter/matmul_wrapper.py:179: in forward
    result = matmul_reduce_scatter._call(
examples/13_gemm_reduce_scatter/matmul_wrapper.py:99: in _call
    kk = gemm_kernel[(grids,)](
/workspace/triton/python/triton/runtime/jit.py:393: in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
/workspace/triton/python/triton/runtime/jit.py:599: in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
/workspace/triton/python/triton/runtime/jit.py:782: in _do_compile
    kernel = self.compile(src, target=target, options=options.__dict__)
/workspace/triton/python/triton/compiler/compiler.py:322: in compile
    next_module = compile_ir(module, metadata)
/workspace/triton/python/triton/backends/amd/compiler.py:449: in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

src = <triton._C.libtriton.ir.module object at 0x7fa4cfa70590>
metadata = {'allow_flush_denorm': False, 'allowed_dot_input_precisions': ('ieee', 'tf32'), 'arch': 'gfx942', 'backend_name': 'hip', ...}
options = HIPOptions(num_warps=8, waves_per_eu=0, num_stages=1, num_ctas=1, extern_libs=(('ocml', '/workspace/triton/python/trit...flush_denorm=False, max_num_imprecise_acc_default=0, backend_name='hip', instrumentation_mode='', schedule_hint='none')

    @staticmethod
    def make_llir(src, metadata, options):
        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        # custom_lds_size is an experimental parameter that defines amount of LDS available
        # for one thread block. Measured in bytes.
        #
        # If custom_lds_size = 0, pass will consider all LDS is available for one threads block,
        # LDS size is determined by provided arch name.
        custom_lds_size = 0
        amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
        passes.convert.add_triton_scf_to_cf(pm)
        passes.convert.add_index_to_llvmir(pm)
    
        amd.passes.ttgpuir.add_allocate_shared_memory(pm)
        # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
        if HIPBackend.instrumentation:
            HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
        ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
        ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
        ##    of the value of kernel arg `allow_flush_denorm`.
        ## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
        ##    depends on the value of kernel arg `allow_flush_denorm`.
        ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
        ##    For now it is used as a controller for developers only.
        __HIP_FTZ = True
        amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
    
        passes.convert.add_cf_to_llvmir(pm)
        passes.convert.add_arith_to_llvmir(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
    
        if options.schedule_hint.lower() != "none":
            amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
    
        # This can not be moved below the di_scope pass
        if HIPBackend.instrumentation:
            HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
    
        if not knobs.compilation.disable_line_info:
            passes.llvmir.add_di_scope(pm)
    
        amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
>       pm.run(mod)
E       RuntimeError: PassManager::run failed

/workspace/triton/python/triton/backends/amd/compiler.py:324: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /root/.triton/llvm/llvm-57088512-ubuntu-x64/include/llvm/ADT/SmallVector.h:292: T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type) [with T = mlir::Value; <template-parameter-1-2> = void; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::reference = mlir::Value&; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type = long unsigned int]: Assertion `idx < size()' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 2], instrShape = [16, 16], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @persistent_gemm_reduce_scatter(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg6: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg16: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg17: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<32> : tensor<1x32xi32, #blocked>
    %cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked>
    %c32_i32 = arith.constant 32 : i32
    %c6_i32 = arith.constant 6 : i32
    %true = arith.constant true
    %c36_i32 = arith.constant 36 : i32
    %c8_i32 = arith.constant 8 : i32
    %c15_i32 = arith.constant 15 : i32
    %c16_i32 = arith.constant 16 : i32
    %c31_i32 = arith.constant 31 : i32
    %c1_i32 = arith.constant 1 : i32
    %c288_i32 = arith.constant 288 : i32
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
    %0 = tt.get_program_id x : i32
    %1 = arith.remsi %0, %c8_i32 : i32
    %2 = arith.muli %1, %c36_i32 : i32
    %3 = arith.divsi %0, %c8_i32 : i32
    %4 = arith.addi %2, %3 : i32
    %5 = arith.addi %arg7, %c31_i32 : i32
    %6 = arith.divsi %5, %c32_i32 : i32
    %7 = arith.addi %arg8, %c31_i32 : i32
    %8 = arith.divsi %7, %c32_i32 : i32
    %9 = arith.muli %6, %8 : i32
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    %10 = arith.muli %8, %c6_i32 : i32
    %11 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %12 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %13 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
    %14 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %15 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %16 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %17 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
    %18 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %19 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %20 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
    %21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2>
    %22 = tt.broadcast %21 : tensor<16x1xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
    %23 = arith.addi %arg9, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #mma>
    %26 = tt.splat %arg8 : i32 -> tensor<1x32xi32, #mma>
    %27 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %28 = tt.expand_dims %27 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
    %29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
    %31 = arith.cmpi slt, %28, %cst_0 : tensor<32x1xi32, #blocked>
    %32 = arith.cmpi slt, %30, %cst : tensor<1x32xi32, #blocked>
    %33 = tt.broadcast %31 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
    %34 = tt.broadcast %32 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked>
    %35 = arith.andi %33, %34 : tensor<32x32xi1, #blocked>
    scf.for %arg18 = %4 to %9 step %c288_i32  : i32 {
      %36 = arith.divsi %arg18, %10 : i32
      %37 = arith.muli %36, %c6_i32 : i32
      %38 = arith.subi %6, %37 : i32
      %39 = arith.minsi %38, %c6_i32 : i32
      %40 = arith.remsi %arg18, %10 : i32
      %41 = arith.remsi %40, %39 : i32
      %42 = arith.addi %37, %41 : i32
      %43 = arith.divsi %40, %39 : i32
      %44 = arith.muli %42, %c32_i32 : i32
      %45 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
      %46 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %47 = arith.addi %45, %11 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
      %48 = arith.addi %46, %12 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %49 = arith.remsi %47, %15 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
      %50 = arith.remsi %48, %16 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %51 = arith.muli %43, %c32_i32 : i32
      %52 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
      %53 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %54 = arith.addi %52, %13 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
      %55 = arith.addi %53, %14 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %56 = arith.remsi %54, %17 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
      %57 = arith.remsi %55, %18 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %58 = tt.expand_dims %50 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
      %59 = tt.expand_dims %49 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
      %60 = tt.splat %arg10 : i32 -> tensor<32x1xi32, #blocked1>
      %61 = arith.muli %59, %60 : tensor<32x1xi32, #blocked1>
      %62 = tt.broadcast %61 : tensor<32x1xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
      %63 = tt.expand_dims %19 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1>
      %64 = tt.broadcast %63 : tensor<1x16xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
      %65 = arith.addi %64, %62 : tensor<32x16xi32, #blocked1>
      %66 = tt.expand_dims %57 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
      %67 = tt.expand_dims %56 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2>
      %68 = tt.splat %arg11 : i32 -> tensor<1x32xi32, #blocked2>
      %69 = arith.muli %67, %68 : tensor<1x32xi32, #blocked2>
      %70 = tt.broadcast %69 : tensor<1x32xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
      %71 = arith.addi %70, %22 : tensor<16x32xi32, #blocked2>
      %72:3 = scf.for %arg19 = %c0_i32 to %24 step %c1_i32 iter_args(%arg20 = %arg0, %arg21 = %cst_2, %arg22 = %arg1) -> (!tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>)  : i32 {
        %91 = tt.splat %arg20 : !tt.ptr<f32> -> tensor<32x16x!tt.ptr<f32>, #blocked1>
        %92 = tt.addptr %91, %65 : tensor<32x16x!tt.ptr<f32>, #blocked1>, tensor<32x16xi32, #blocked1>
        %93 = tt.load %92 : tensor<32x16x!tt.ptr<f32>, #blocked1>
        %94 = tt.splat %arg22 : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>, #blocked2>
        %95 = tt.addptr %94, %71 : tensor<16x32x!tt.ptr<f32>, #blocked2>, tensor<16x32xi32, #blocked2>
        %96 = tt.load %95 : tensor<16x32x!tt.ptr<f32>, #blocked2>
        %97 = ttg.local_alloc %93 : (tensor<32x16xf32, #blocked1>) -> !ttg.memdesc<32x16xf32, #shared, #smem>
        %98 = ttg.local_load %97 : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
        %99 = ttg.local_alloc %96 : (tensor<16x32xf32, #blocked2>) -> !ttg.memdesc<16x32xf32, #shared1, #smem>
        %100 = ttg.local_load %99 : !ttg.memdesc<16x32xf32, #shared1, #smem> -> tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
        %101 = tt.dot %98, %100, %arg21, inputPrecision = tf32 : tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
        %102 = tt.addptr %arg20, %c16_i32 : !tt.ptr<f32>, i32
        %103 = tt.addptr %arg22, %c16_i32 : !tt.ptr<f32>, i32
        scf.yield %102, %101, %103 : !tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>
      } {tt.divisibility_arg1 = dense<[1, 16]> : tensor<2xi32>, tt.divisibility_arg2 = dense<[1, 16]> : tensor<2xi32>}
      %73 = arith.cmpi slt, %58, %25 : tensor<32x1xi32, #mma>
      %74 = arith.cmpi slt, %66, %26 : tensor<1x32xi32, #mma>
      %75 = tt.broadcast %73 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma>
      %76 = tt.broadcast %74 : tensor<1x32xi1, #mma> -> tensor<32x32xi1, #mma>
      %77 = arith.andi %75, %76 : tensor<32x32xi1, #mma>
      %78 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #mma>
      %79 = arith.muli %58, %78 : tensor<32x1xi32, #mma>
      %80 = tt.broadcast %79 : tensor<32x1xi32, #mma> -> tensor<32x32xi32, #mma>
      %81 = tt.broadcast %66 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma>
      %82 = arith.addi %81, %80 : tensor<32x32xi32, #mma>
      %83 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #mma>
      %84 = tt.addptr %83, %82 : tensor<32x32x!tt.ptr<f32>, #mma>, tensor<32x32xi32, #mma>
      tt.store %84, %72#1, %77 : tensor<32x32x!tt.ptr<f32>, #mma>
      %85 = tt.addptr %arg6, %arg18 : !tt.ptr<i32>, i32
      %86 = tt.ptr_to_int %85 : !tt.ptr<i32> -> i64
      scf.while (%arg19 = %c0_i32) : (i32) -> () {
        %91 = arith.cmpi slt, %arg19, %c0_i32 : i32
        scf.condition(%91)
      } do {
        %91 = tt.load %arg15 : !tt.ptr<i64>
        %92 = arith.subi %86, %91 : i64
        %93 = tt.int_to_ptr %91 : i64 -> !tt.ptr<i8>
        %94 = tt.addptr %93, %92 : !tt.ptr<i8>, i64
        %95 = tt.bitcast %94 : !tt.ptr<i8> -> !tt.ptr<i32>
        %96 = tt.atomic_cas acquire, sys, %95, %c0_i32, %c0_i32 : (!tt.ptr<i32>, i32, i32) -> i32
        scf.yield %96 : i32
      }
      %87 = arith.cmpi slt, %44, %arg7 : i32
      %88 = arith.addi %44, %c32_i32 : i32
      %89 = arith.cmpi sgt, %88, %c0_i32 : i32
      %90 = arith.andi %87, %89 : i1
      scf.if %90 {
        %91 = arith.subi %c0_i32, %44 : i32
        %92 = arith.maxsi %91, %c0_i32 : i32
        %93 = arith.subi %arg7, %44 : i32
        %94 = arith.minsi %93, %c32_i32 : i32
        %95 = arith.maxsi %44, %c0_i32 : i32
        %96 = arith.muli %44, %arg12 : i32
        %97 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked>
        %98 = arith.muli %28, %97 : tensor<32x1xi32, #blocked>
        %99 = tt.broadcast %98 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
        %100 = tt.broadcast %30 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked>
        %101 = arith.addi %96, %51 : i32
        %102 = arith.addi %99, %100 : tensor<32x32xi32, #blocked>
        %103 = tt.addptr %arg2, %101 : !tt.ptr<f32>, i32
        %104 = tt.load %arg15 : !tt.ptr<i64>
        %105 = tt.splat %103 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
        %106 = tt.addptr %105, %102 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
        %107 = tt.ptr_to_int %106 : tensor<32x32x!tt.ptr<f32>, #blocked> -> tensor<32x32xi64, #blocked>
        %108 = tt.splat %104 : i64 -> tensor<32x32xi64, #blocked>
        %109 = arith.subi %107, %108 : tensor<32x32xi64, #blocked>
        %110 = tt.int_to_ptr %104 : i64 -> !tt.ptr<i8>
        %111 = tt.splat %110 : !tt.ptr<i8> -> tensor<32x32x!tt.ptr<i8>, #blocked>
        %112 = tt.addptr %111, %109 : tensor<32x32x!tt.ptr<i8>, #blocked>, tensor<32x32xi64, #blocked>
        %113 = tt.bitcast %112 : tensor<32x32x!tt.ptr<i8>, #blocked> -> tensor<32x32x!tt.ptr<f32>, #blocked>
        %114 = tt.load %113, %35 : tensor<32x32x!tt.ptr<f32>, #blocked>
        %115 = arith.addf %114, %cst_1 : tensor<32x32xf32, #blocked>
        %116 = arith.muli %95, %arg13 : i32
        %117 = tt.splat %arg13 : i32 -> tensor<32x1xi32, #blocked>
        %118 = arith.muli %28, %117 : tensor<32x1xi32, #blocked>
        %119 = tt.broadcast %118 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
        %120 = arith.addi %116, %51 : i32
        %121 = arith.addi %119, %100 : tensor<32x32xi32, #blocked>
        %122 = tt.addptr %arg3, %120 : !tt.ptr<f32>, i32
        %123 = tt.splat %92 : i32 -> tensor<32x1xi32, #blocked>
        %124 = arith.cmpi sge, %28, %123 : tensor<32x1xi32, #blocked>
        %125 = tt.splat %94 : i32 -> tensor<32x1xi32, #blocked>
        %126 = arith.cmpi slt, %28, %125 : tensor<32x1xi32, #blocked>
        %127 = arith.andi %124, %126 : tensor<32x1xi1, #blocked>
        %128 = tt.broadcast %127 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
        %129 = arith.andi %128, %34 : tensor<32x32xi1, #blocked>
        %130 = arith.andi %129, %35 : tensor<32x32xi1, #blocked>
        %131 = tt.splat %122 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
        %132 = tt.addptr %131, %121 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
        tt.store %132, %115, %130 cacheModifier = wt : tensor<32x32x!tt.ptr<f32>, #blocked>
      }
    }
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(optimize-amd-lds-usage{lds-limit=0 target-arch=gfx942}, triton-scf-to-cf, convert-index-to-llvm{index-bitwidth=0}, allocate-amdgpu-shared-memory, convert-triton-amdgpu-to-llvm{arch=gfx942 ftz=true}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-cf-to-llvm{index-bitwidth=0}, convert-arith-to-llvm{index-bitwidth=0}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info, convert-builtin-func-to-llvm{ftz=true})",
      disable_threading: false,
      verify_each: true
    }
  }
#-}
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: error: Failures have been detected while processing an MLIR pass pipeline
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: note: Pipeline failed while executing [`ConvertTritonAMDGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
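
The failure above never reaches validation: for fp32 inputs the kernel dies in `ConvertTritonAMDGPUToLLVM` while lowering the `tt.dot ... inputPrecision = tf32` on the 32x16 by 16x32 tile for gfx942. A minimal way to probe whether the tf32 path is the culprit is to request IEEE dot inputs. This is a hedged sketch with placeholder pointers and block sizes, not the PR's `persistent_gemm_reduce_scatter` kernel, and `input_precision="ieee"` assumes a recent Triton that exposes that keyword (the log's `allowed_dot_input_precisions` lists both `ieee` and `tf32`).

```python
import triton
import triton.language as tl

@triton.jit
def dot_fp32_probe(a_ptr, b_ptr, c_ptr,
                   BLK_M: tl.constexpr, BLK_N: tl.constexpr, BLK_K: tl.constexpr):
    offs_m = tl.arange(0, BLK_M)
    offs_n = tl.arange(0, BLK_N)
    offs_k = tl.arange(0, BLK_K)
    # Row-major BLK_M x BLK_K and BLK_K x BLK_N tiles.
    a = tl.load(a_ptr + offs_m[:, None] * BLK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLK_N + offs_n[None, :])
    # "ieee" sidesteps the tf32 lowering path seen in the IR dump above.
    acc = tl.dot(a, b, input_precision="ieee")
    tl.store(c_ptr + offs_m[:, None] * BLK_N + offs_n[None, :], acc)

# Example launch (placeholder tensors, single program):
#   a = torch.randn(32, 16, device="cuda", dtype=torch.float32)
#   b = torch.randn(16, 32, device="cuda", dtype=torch.float32)
#   c = torch.empty(32, 32, device="cuda", dtype=torch.float32)
#   dot_fp32_probe[(1,)](a, b, c, BLK_M=32, BLK_N=32, BLK_K=16)
```

If an IEEE dot compiles at this tile shape where the tf32 one does not, that would narrow the fp32 failures to the dot lowering rather than the reduce-scatter epilogue.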
____________ test_gemm_reduce_scatter[32-32-16-256-256-256-dtype0] _____________

dtype = torch.float16, m = 256, n = 256, k = 256, BLK_M = 32, BLK_N = 32
BLK_K = 16

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ],
    )
    @pytest.mark.parametrize(
        "m, n, k",
        [
            (64, 64, 64),  # Very small for quick testing
            (128, 128, 128),  # Small
            (256, 256, 256),  # Medium
        ],
    )
    @pytest.mark.parametrize(
        "BLK_M, BLK_N, BLK_K",
        [
            (32, 32, 16),  # Small blocks
            (64, 64, 32),  # Medium blocks
        ],
    )
    def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
        """Worker function for PyTorch distributed execution."""
        heap_size = 1 << 30
        shmem = iris.iris(heap_size)
        rank = shmem.get_rank()
        world_size = shmem.get_num_ranks()
        cu_count = shmem.get_cu_count()
    
        # GEMM
        datatype = dtype
    
        assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
        assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
        A = shmem.randn(m, k, device="cuda", dtype=datatype)
        B = shmem.randn(n, k, device="cuda", dtype=datatype).T
        C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
        M = m
        N = n
        K = k
    
        # Splitting
        rows_per_gpu = k // world_size
        k = rows_per_gpu
        start_row = rank * rows_per_gpu
        end_row = start_row + rows_per_gpu
        local_B = B[start_row:end_row, :]
        local_A = A[:, start_row:end_row]
    
        compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
        local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
        total_blocks_M = triton.cdiv(m, BLK_M)
        total_blocks_N = triton.cdiv(n, BLK_N)
        total_tiles = total_blocks_M * total_blocks_N
    
        tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    
        locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
        P = shmem.zeros(
            (288, BLK_M * BLK_N),
            device="cuda",
            dtype=torch.float32,
        )
        bias = None
        gemm_stream = torch.cuda.Stream()
        timestamps = Timestamps(num_tiles=total_tiles)
    
        def preamble():
            shmem.barrier()
            tile_completed.zero_()
            shmem.barrier()
    
        def run_experiment():
            nonlocal local_output
            nonlocal compute_buffer
    
            shmem.barrier()
    
            torch.cuda.nvtx.range_push("GEMM + Communication")
            with torch.cuda.stream(gemm_stream):
                local_output = matmul_module.matmul_reduce_scatter.apply(
                    local_A,
                    local_B,
                    compute_buffer,
                    local_output,
                    bias,
                    P,
                    locks,
                    tile_completed,
                    rank,
                    world_size,
                    288,
                    BLK_M,
                    BLK_N,
                    BLK_K,
                    6,
                    True,
                    1,
                    8,
                    0,
                    16,
                    2,
                    shmem.get_heap_bases(),
                    cu_count,
                    False,
                    timestamps.mm_begin_timestamp,
                    timestamps.mm_end_timestamp,
                )
            torch.cuda.nvtx.range_pop()
            shmem.barrier()
    
        # Synchronize across all GPUs
        shmem.barrier()
        run_experiment()
        shmem.barrier()
        preamble()
        shmem.barrier()
    
        shmem.info("Validating...")
    
        matmul_module.matmul_reduce_scatter.set_debug(False)
        # Validate global result
        success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
>       assert success, (
            f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
        )
E       AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=256, n=256, k=256, BLK_M=32, BLK_N=32, BLK_K=16
E       assert False

tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO     iris::0 Validating...
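
The fp16 case above compiles; it fails inside `validate_gemm_reduce_scatter`, and the captured stderr shows it ran with a single rank (`[0/1]`). Under the split the test uses (K partitioned across ranks, each rank owning `m // world_size` rows of the reduced C), the expected per-rank result is just a row block of the full `A @ B`; with one rank it is all of it. A hedged reference sketch follows; it is not the repo's validation helper, and the fp32 upcast and the comparison call are assumptions for illustration.

```python
import torch

def expected_local_output(A, B, rank, world_size):
    """Row block of the fully reduced C that `rank` should own after reduce-scatter."""
    C_full = A.float() @ B.float()            # full-precision GEMM reference
    rows = C_full.shape[0] // world_size      # rows per rank (M must divide evenly)
    return C_full[rank * rows:(rank + 1) * rows]

# With world_size == 1, as in this log, the expected output is the whole A @ B:
#   torch.testing.assert_close(local_output.float(),
#                              expected_local_output(A, B, rank, world_size),
#                              atol=2, rtol=0)
```

So a mismatch at world_size == 1 points at the kernel or its wrapper rather than at the K split done in the test.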
____________ test_gemm_reduce_scatter[32-32-16-256-256-256-dtype1] _____________

dtype = torch.bfloat16, m = 256, n = 256, k = 256, BLK_M = 32, BLK_N = 32
BLK_K = 16

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ],
    )
    @pytest.mark.parametrize(
        "m, n, k",
        [
            (64, 64, 64),  # Very small for quick testing
            (128, 128, 128),  # Small
            (256, 256, 256),  # Medium
        ],
    )
    @pytest.mark.parametrize(
        "BLK_M, BLK_N, BLK_K",
        [
            (32, 32, 16),  # Small blocks
            (64, 64, 32),  # Medium blocks
        ],
    )
    def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
        """Worker function for PyTorch distributed execution."""
        heap_size = 1 << 30
        shmem = iris.iris(heap_size)
        rank = shmem.get_rank()
        world_size = shmem.get_num_ranks()
        cu_count = shmem.get_cu_count()
    
        # GEMM
        datatype = dtype
    
        assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
        assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
        A = shmem.randn(m, k, device="cuda", dtype=datatype)
        B = shmem.randn(n, k, device="cuda", dtype=datatype).T
        C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
        M = m
        N = n
        K = k
    
        # Splitting
        rows_per_gpu = k // world_size
        k = rows_per_gpu
        start_row = rank * rows_per_gpu
        end_row = start_row + rows_per_gpu
        local_B = B[start_row:end_row, :]
        local_A = A[:, start_row:end_row]
    
        compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
        local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
        total_blocks_M = triton.cdiv(m, BLK_M)
        total_blocks_N = triton.cdiv(n, BLK_N)
        total_tiles = total_blocks_M * total_blocks_N
    
        tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    
        locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
        P = shmem.zeros(
            (288, BLK_M * BLK_N),
            device="cuda",
            dtype=torch.float32,
        )
        bias = None
        gemm_stream = torch.cuda.Stream()
        timestamps = Timestamps(num_tiles=total_tiles)
    
        def preamble():
            shmem.barrier()
            tile_completed.zero_()
            shmem.barrier()
    
        def run_experiment():
            nonlocal local_output
            nonlocal compute_buffer
    
            shmem.barrier()
    
            torch.cuda.nvtx.range_push("GEMM + Communication")
            with torch.cuda.stream(gemm_stream):
                local_output = matmul_module.matmul_reduce_scatter.apply(
                    local_A,
                    local_B,
                    compute_buffer,
                    local_output,
                    bias,
                    P,
                    locks,
                    tile_completed,
                    rank,
                    world_size,
                    288,
                    BLK_M,
                    BLK_N,
                    BLK_K,
                    6,
                    True,
                    1,
                    8,
                    0,
                    16,
                    2,
                    shmem.get_heap_bases(),
                    cu_count,
                    False,
                    timestamps.mm_begin_timestamp,
                    timestamps.mm_end_timestamp,
                )
            torch.cuda.nvtx.range_pop()
            shmem.barrier()
    
        # Synchronize across all GPUs
        shmem.barrier()
        run_experiment()
        shmem.barrier()
        preamble()
        shmem.barrier()
    
        shmem.info("Validating...")
    
        matmul_module.matmul_reduce_scatter.set_debug(False)
        # Validate global result
        success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
>       assert success, (
            f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
        )
E       AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=256, n=256, k=256, BLK_M=32, BLK_N=32, BLK_K=16
E       assert False

tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO     iris::0 Validating...
____________ test_gemm_reduce_scatter[32-32-16-256-256-256-dtype2] _____________

dtype = torch.float32, m = 256, n = 256, k = 256, BLK_M = 32, BLK_N = 32
BLK_K = 16

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ],
    )
    @pytest.mark.parametrize(
        "m, n, k",
        [
            (64, 64, 64),  # Very small for quick testing
            (128, 128, 128),  # Small
            (256, 256, 256),  # Medium
        ],
    )
    @pytest.mark.parametrize(
        "BLK_M, BLK_N, BLK_K",
        [
            (32, 32, 16),  # Small blocks
            (64, 64, 32),  # Medium blocks
        ],
    )
    def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
        """Worker function for PyTorch distributed execution."""
        heap_size = 1 << 30
        shmem = iris.iris(heap_size)
        rank = shmem.get_rank()
        world_size = shmem.get_num_ranks()
        cu_count = shmem.get_cu_count()
    
        # GEMM
        datatype = dtype
    
        assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
        assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
        A = shmem.randn(m, k, device="cuda", dtype=datatype)
        B = shmem.randn(n, k, device="cuda", dtype=datatype).T
        C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
        M = m
        N = n
        K = k
    
        # Splitting
        rows_per_gpu = k // world_size
        k = rows_per_gpu
        start_row = rank * rows_per_gpu
        end_row = start_row + rows_per_gpu
        local_B = B[start_row:end_row, :]
        local_A = A[:, start_row:end_row]
    
        compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
        local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
        total_blocks_M = triton.cdiv(m, BLK_M)
        total_blocks_N = triton.cdiv(n, BLK_N)
        total_tiles = total_blocks_M * total_blocks_N
    
        tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    
        locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
        P = shmem.zeros(
            (288, BLK_M * BLK_N),
            device="cuda",
            dtype=torch.float32,
        )
        bias = None
        gemm_stream = torch.cuda.Stream()
        timestamps = Timestamps(num_tiles=total_tiles)
    
        def preamble():
            shmem.barrier()
            tile_completed.zero_()
            shmem.barrier()
    
        def run_experiment():
            nonlocal local_output
            nonlocal compute_buffer
    
            shmem.barrier()
    
            torch.cuda.nvtx.range_push("GEMM + Communication")
            with torch.cuda.stream(gemm_stream):
                local_output = matmul_module.matmul_reduce_scatter.apply(
                    local_A,
                    local_B,
                    compute_buffer,
                    local_output,
                    bias,
                    P,
                    locks,
                    tile_completed,
                    rank,
                    world_size,
                    288,
                    BLK_M,
                    BLK_N,
                    BLK_K,
                    6,
                    True,
                    1,
                    8,
                    0,
                    16,
                    2,
                    shmem.get_heap_bases(),
                    cu_count,
                    False,
                    timestamps.mm_begin_timestamp,
                    timestamps.mm_end_timestamp,
                )
            torch.cuda.nvtx.range_pop()
            shmem.barrier()
    
        # Synchronize across all GPUs
        shmem.barrier()
>       run_experiment()

tests/examples/test_gemm_reduce_scatter_bench.py:154: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/examples/test_gemm_reduce_scatter_bench.py:121: in run_experiment
    local_output = matmul_module.matmul_reduce_scatter.apply(
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py:574: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
examples/13_gemm_reduce_scatter/matmul_wrapper.py:179: in forward
    result = matmul_reduce_scatter._call(
examples/13_gemm_reduce_scatter/matmul_wrapper.py:99: in _call
    kk = gemm_kernel[(grids,)](
/workspace/triton/python/triton/runtime/jit.py:393: in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
/workspace/triton/python/triton/runtime/jit.py:599: in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
/workspace/triton/python/triton/runtime/jit.py:782: in _do_compile
    kernel = self.compile(src, target=target, options=options.__dict__)
/workspace/triton/python/triton/compiler/compiler.py:322: in compile
    next_module = compile_ir(module, metadata)
/workspace/triton/python/triton/backends/amd/compiler.py:449: in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

src = <triton._C.libtriton.ir.module object at 0x7fa4cfae3c40>
metadata = {'allow_flush_denorm': False, 'allowed_dot_input_precisions': ('ieee', 'tf32'), 'arch': 'gfx942', 'backend_name': 'hip', ...}
options = HIPOptions(num_warps=8, waves_per_eu=0, num_stages=1, num_ctas=1, extern_libs=(('ocml', '/workspace/triton/python/trit...flush_denorm=False, max_num_imprecise_acc_default=0, backend_name='hip', instrumentation_mode='', schedule_hint='none')

    @staticmethod
    def make_llir(src, metadata, options):
        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        # custom_lds_size is an experimental parameter that defines amount of LDS available
        # for one thread block. Measured in bytes.
        #
        # If custom_lds_size = 0, pass will consider all LDS is available for one threads block,
        # LDS size is determined by provided arch name.
        custom_lds_size = 0
        amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
        passes.convert.add_triton_scf_to_cf(pm)
        passes.convert.add_index_to_llvmir(pm)
    
        amd.passes.ttgpuir.add_allocate_shared_memory(pm)
        # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
        if HIPBackend.instrumentation:
            HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
        ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
        ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
        ##    of the value of kernel arg `allow_flush_denorm`.
        ## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
        ##    depends on the value of kernel arg `allow_flush_denorm`.
        ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
        ##    For now it is used as a controller for developers only.
        __HIP_FTZ = True
        amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
    
        passes.convert.add_cf_to_llvmir(pm)
        passes.convert.add_arith_to_llvmir(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
    
        if options.schedule_hint.lower() != "none":
            amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
    
        # This can not be moved below the di_scope pass
        if HIPBackend.instrumentation:
            HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
    
        if not knobs.compilation.disable_line_info:
            passes.llvmir.add_di_scope(pm)
    
        amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
>       pm.run(mod)
E       RuntimeError: PassManager::run failed

/workspace/triton/python/triton/backends/amd/compiler.py:324: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /root/.triton/llvm/llvm-57088512-ubuntu-x64/include/llvm/ADT/SmallVector.h:292: T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type) [with T = mlir::Value; <template-parameter-1-2> = void; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::reference = mlir::Value&; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type = long unsigned int]: Assertion `idx < size()' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 2], instrShape = [16, 16], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @persistent_gemm_reduce_scatter(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg6: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg16: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg17: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<32> : tensor<1x32xi32, #blocked>
    %cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked>
    %c32_i32 = arith.constant 32 : i32
    %c6_i32 = arith.constant 6 : i32
    %true = arith.constant true
    %c36_i32 = arith.constant 36 : i32
    %c8_i32 = arith.constant 8 : i32
    %c15_i32 = arith.constant 15 : i32
    %c16_i32 = arith.constant 16 : i32
    %c31_i32 = arith.constant 31 : i32
    %c1_i32 = arith.constant 1 : i32
    %c288_i32 = arith.constant 288 : i32
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
    %0 = tt.get_program_id x : i32
    %1 = arith.remsi %0, %c8_i32 : i32
    %2 = arith.muli %1, %c36_i32 : i32
    %3 = arith.divsi %0, %c8_i32 : i32
    %4 = arith.addi %2, %3 : i32
    %5 = arith.addi %arg7, %c31_i32 : i32
    %6 = arith.divsi %5, %c32_i32 : i32
    %7 = arith.addi %arg8, %c31_i32 : i32
    %8 = arith.divsi %7, %c32_i32 : i32
    %9 = arith.muli %6, %8 : i32
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    %10 = arith.muli %8, %c6_i32 : i32
    %11 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %12 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %13 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
    %14 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %15 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %16 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %17 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
    %18 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %19 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %20 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
    %21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2>
    %22 = tt.broadcast %21 : tensor<16x1xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
    %23 = arith.addi %arg9, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #mma>
    %26 = tt.splat %arg8 : i32 -> tensor<1x32xi32, #mma>
    %27 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %28 = tt.expand_dims %27 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
    %29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
    %31 = arith.cmpi slt, %28, %cst_0 : tensor<32x1xi32, #blocked>
    %32 = arith.cmpi slt, %30, %cst : tensor<1x32xi32, #blocked>
    %33 = tt.broadcast %31 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
    %34 = tt.broadcast %32 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked>
    %35 = arith.andi %33, %34 : tensor<32x32xi1, #blocked>
    scf.for %arg18 = %4 to %9 step %c288_i32  : i32 {
      %36 = arith.divsi %arg18, %10 : i32
      %37 = arith.muli %36, %c6_i32 : i32
      %38 = arith.subi %6, %37 : i32
      %39 = arith.minsi %38, %c6_i32 : i32
      %40 = arith.remsi %arg18, %10 : i32
      %41 = arith.remsi %40, %39 : i32
      %42 = arith.addi %37, %41 : i32
      %43 = arith.divsi %40, %39 : i32
      %44 = arith.muli %42, %c32_i32 : i32
      %45 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
      %46 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %47 = arith.addi %45, %11 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
      %48 = arith.addi %46, %12 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %49 = arith.remsi %47, %15 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
      %50 = arith.remsi %48, %16 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %51 = arith.muli %43, %c32_i32 : i32
      %52 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
      %53 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %54 = arith.addi %52, %13 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
      %55 = arith.addi %53, %14 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %56 = arith.remsi %54, %17 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
      %57 = arith.remsi %55, %18 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %58 = tt.expand_dims %50 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
      %59 = tt.expand_dims %49 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
      %60 = tt.splat %arg10 : i32 -> tensor<32x1xi32, #blocked1>
      %61 = arith.muli %59, %60 : tensor<32x1xi32, #blocked1>
      %62 = tt.broadcast %61 : tensor<32x1xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
      %63 = tt.expand_dims %19 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1>
      %64 = tt.broadcast %63 : tensor<1x16xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
      %65 = arith.addi %64, %62 : tensor<32x16xi32, #blocked1>
      %66 = tt.expand_dims %57 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
      %67 = tt.expand_dims %56 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2>
      %68 = tt.splat %arg11 : i32 -> tensor<1x32xi32, #blocked2>
      %69 = arith.muli %67, %68 : tensor<1x32xi32, #blocked2>
      %70 = tt.broadcast %69 : tensor<1x32xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
      %71 = arith.addi %70, %22 : tensor<16x32xi32, #blocked2>
      %72:3 = scf.for %arg19 = %c0_i32 to %24 step %c1_i32 iter_args(%arg20 = %arg0, %arg21 = %cst_2, %arg22 = %arg1) -> (!tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>)  : i32 {
        %91 = tt.splat %arg20 : !tt.ptr<f32> -> tensor<32x16x!tt.ptr<f32>, #blocked1>
        %92 = tt.addptr %91, %65 : tensor<32x16x!tt.ptr<f32>, #blocked1>, tensor<32x16xi32, #blocked1>
        %93 = tt.load %92 : tensor<32x16x!tt.ptr<f32>, #blocked1>
        %94 = tt.splat %arg22 : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>, #blocked2>
        %95 = tt.addptr %94, %71 : tensor<16x32x!tt.ptr<f32>, #blocked2>, tensor<16x32xi32, #blocked2>
        %96 = tt.load %95 : tensor<16x32x!tt.ptr<f32>, #blocked2>
        %97 = ttg.local_alloc %93 : (tensor<32x16xf32, #blocked1>) -> !ttg.memdesc<32x16xf32, #shared, #smem>
        %98 = ttg.local_load %97 : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
        %99 = ttg.local_alloc %96 : (tensor<16x32xf32, #blocked2>) -> !ttg.memdesc<16x32xf32, #shared1, #smem>
        %100 = ttg.local_load %99 : !ttg.memdesc<16x32xf32, #shared1, #smem> -> tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
        %101 = tt.dot %98, %100, %arg21, inputPrecision = tf32 : tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
        %102 = tt.addptr %arg20, %c16_i32 : !tt.ptr<f32>, i32
        %103 = tt.addptr %arg22, %c16_i32 : !tt.ptr<f32>, i32
        scf.yield %102, %101, %103 : !tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>
      } {tt.divisibility_arg1 = dense<[1, 16]> : tensor<2xi32>, tt.divisibility_arg2 = dense<[1, 16]> : tensor<2xi32>}
      %73 = arith.cmpi slt, %58, %25 : tensor<32x1xi32, #mma>
      %74 = arith.cmpi slt, %66, %26 : tensor<1x32xi32, #mma>
      %75 = tt.broadcast %73 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma>
      %76 = tt.broadcast %74 : tensor<1x32xi1, #mma> -> tensor<32x32xi1, #mma>
      %77 = arith.andi %75, %76 : tensor<32x32xi1, #mma>
      %78 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #mma>
      %79 = arith.muli %58, %78 : tensor<32x1xi32, #mma>
      %80 = tt.broadcast %79 : tensor<32x1xi32, #mma> -> tensor<32x32xi32, #mma>
      %81 = tt.broadcast %66 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma>
      %82 = arith.addi %81, %80 : tensor<32x32xi32, #mma>
      %83 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #mma>
      %84 = tt.addptr %83, %82 : tensor<32x32x!tt.ptr<f32>, #mma>, tensor<32x32xi32, #mma>
      tt.store %84, %72#1, %77 : tensor<32x32x!tt.ptr<f32>, #mma>
      %85 = tt.addptr %arg6, %arg18 : !tt.ptr<i32>, i32
      %86 = tt.ptr_to_int %85 : !tt.ptr<i32> -> i64
      scf.while (%arg19 = %c0_i32) : (i32) -> () {
        %91 = arith.cmpi slt, %arg19, %c0_i32 : i32
        scf.condition(%91)
      } do {
        %91 = tt.load %arg15 : !tt.ptr<i64>
        %92 = arith.subi %86, %91 : i64
        %93 = tt.int_to_ptr %91 : i64 -> !tt.ptr<i8>
        %94 = tt.addptr %93, %92 : !tt.ptr<i8>, i64
        %95 = tt.bitcast %94 : !tt.ptr<i8> -> !tt.ptr<i32>
        %96 = tt.atomic_cas acquire, sys, %95, %c0_i32, %c0_i32 : (!tt.ptr<i32>, i32, i32) -> i32
        scf.yield %96 : i32
      }
      %87 = arith.cmpi slt, %44, %arg7 : i32
      %88 = arith.addi %44, %c32_i32 : i32
      %89 = arith.cmpi sgt, %88, %c0_i32 : i32
      %90 = arith.andi %87, %89 : i1
      scf.if %90 {
        %91 = arith.subi %c0_i32, %44 : i32
        %92 = arith.maxsi %91, %c0_i32 : i32
        %93 = arith.subi %arg7, %44 : i32
        %94 = arith.minsi %93, %c32_i32 : i32
        %95 = arith.maxsi %44, %c0_i32 : i32
        %96 = arith.muli %44, %arg12 : i32
        %97 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked>
        %98 = arith.muli %28, %97 : tensor<32x1xi32, #blocked>
        %99 = tt.broadcast %98 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
        %100 = tt.broadcast %30 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked>
        %101 = arith.addi %96, %51 : i32
        %102 = arith.addi %99, %100 : tensor<32x32xi32, #blocked>
        %103 = tt.addptr %arg2, %101 : !tt.ptr<f32>, i32
        %104 = tt.load %arg15 : !tt.ptr<i64>
        %105 = tt.splat %103 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
        %106 = tt.addptr %105, %102 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
        %107 = tt.ptr_to_int %106 : tensor<32x32x!tt.ptr<f32>, #blocked> -> tensor<32x32xi64, #blocked>
        %108 = tt.splat %104 : i64 -> tensor<32x32xi64, #blocked>
        %109 = arith.subi %107, %108 : tensor<32x32xi64, #blocked>
        %110 = tt.int_to_ptr %104 : i64 -> !tt.ptr<i8>
        %111 = tt.splat %110 : !tt.ptr<i8> -> tensor<32x32x!tt.ptr<i8>, #blocked>
        %112 = tt.addptr %111, %109 : tensor<32x32x!tt.ptr<i8>, #blocked>, tensor<32x32xi64, #blocked>
        %113 = tt.bitcast %112 : tensor<32x32x!tt.ptr<i8>, #blocked> -> tensor<32x32x!tt.ptr<f32>, #blocked>
        %114 = tt.load %113, %35 : tensor<32x32x!tt.ptr<f32>, #blocked>
        %115 = arith.addf %114, %cst_1 : tensor<32x32xf32, #blocked>
        %116 = arith.muli %95, %arg13 : i32
        %117 = tt.splat %arg13 : i32 -> tensor<32x1xi32, #blocked>
        %118 = arith.muli %28, %117 : tensor<32x1xi32, #blocked>
        %119 = tt.broadcast %118 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
        %120 = arith.addi %116, %51 : i32
        %121 = arith.addi %119, %100 : tensor<32x32xi32, #blocked>
        %122 = tt.addptr %arg3, %120 : !tt.ptr<f32>, i32
        %123 = tt.splat %92 : i32 -> tensor<32x1xi32, #blocked>
        %124 = arith.cmpi sge, %28, %123 : tensor<32x1xi32, #blocked>
        %125 = tt.splat %94 : i32 -> tensor<32x1xi32, #blocked>
        %126 = arith.cmpi slt, %28, %125 : tensor<32x1xi32, #blocked>
        %127 = arith.andi %124, %126 : tensor<32x1xi1, #blocked>
        %128 = tt.broadcast %127 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
        %129 = arith.andi %128, %34 : tensor<32x32xi1, #blocked>
        %130 = arith.andi %129, %35 : tensor<32x32xi1, #blocked>
        %131 = tt.splat %122 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
        %132 = tt.addptr %131, %121 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
        tt.store %132, %115, %130 cacheModifier = wt : tensor<32x32x!tt.ptr<f32>, #blocked>
      }
    }
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(optimize-amd-lds-usage{lds-limit=0 target-arch=gfx942}, triton-scf-to-cf, convert-index-to-llvm{index-bitwidth=0}, allocate-amdgpu-shared-memory, convert-triton-amdgpu-to-llvm{arch=gfx942 ftz=true}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-cf-to-llvm{index-bitwidth=0}, convert-arith-to-llvm{index-bitwidth=0}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info, convert-builtin-func-to-llvm{ftz=true})",
      disable_threading: false,
      verify_each: true
    }
  }
#-}
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: error: Failures have been detected while processing an MLIR pass pipeline
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: note: Pipeline failed while executing [`ConvertTritonAMDGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
____________ test_gemm_reduce_scatter[64-64-32-128-128-128-dtype1] _____________

dtype = torch.bfloat16, m = 128, n = 128, k = 128, BLK_M = 64, BLK_N = 64
BLK_K = 32

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ],
    )
    @pytest.mark.parametrize(
        "m, n, k",
        [
            (64, 64, 64),  # Very small for quick testing
            (128, 128, 128),  # Small
            (256, 256, 256),  # Medium
        ],
    )
    @pytest.mark.parametrize(
        "BLK_M, BLK_N, BLK_K",
        [
            (32, 32, 16),  # Small blocks
            (64, 64, 32),  # Medium blocks
        ],
    )
    def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
        """Worker function for PyTorch distributed execution."""
        heap_size = 1 << 30
        shmem = iris.iris(heap_size)
        rank = shmem.get_rank()
        world_size = shmem.get_num_ranks()
        cu_count = shmem.get_cu_count()
    
        # GEMM
        datatype = dtype
    
        assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
        assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
        A = shmem.randn(m, k, device="cuda", dtype=datatype)
        B = shmem.randn(n, k, device="cuda", dtype=datatype).T
        C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
        M = m
        N = n
        K = k
    
        # Splitting
        rows_per_gpu = k // world_size
        k = rows_per_gpu
        start_row = rank * rows_per_gpu
        end_row = start_row + rows_per_gpu
        local_B = B[start_row:end_row, :]
        local_A = A[:, start_row:end_row]
    
        compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
        local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
        total_blocks_M = triton.cdiv(m, BLK_M)
        total_blocks_N = triton.cdiv(n, BLK_N)
        total_tiles = total_blocks_M * total_blocks_N
    
        tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    
        locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
        P = shmem.zeros(
            (288, BLK_M * BLK_N),
            device="cuda",
            dtype=torch.float32,
        )
        bias = None
        gemm_stream = torch.cuda.Stream()
        timestamps = Timestamps(num_tiles=total_tiles)
    
        def preamble():
            shmem.barrier()
            tile_completed.zero_()
            shmem.barrier()
    
        def run_experiment():
            nonlocal local_output
            nonlocal compute_buffer
    
            shmem.barrier()
    
            torch.cuda.nvtx.range_push("GEMM + Communication")
            with torch.cuda.stream(gemm_stream):
                local_output = matmul_module.matmul_reduce_scatter.apply(
                    local_A,
                    local_B,
                    compute_buffer,
                    local_output,
                    bias,
                    P,
                    locks,
                    tile_completed,
                    rank,
                    world_size,
                    288,
                    BLK_M,
                    BLK_N,
                    BLK_K,
                    6,
                    True,
                    1,
                    8,
                    0,
                    16,
                    2,
                    shmem.get_heap_bases(),
                    cu_count,
                    False,
                    timestamps.mm_begin_timestamp,
                    timestamps.mm_end_timestamp,
                )
            torch.cuda.nvtx.range_pop()
            shmem.barrier()
    
        # Synchronize across all GPUs
        shmem.barrier()
        run_experiment()
        shmem.barrier()
        preamble()
        shmem.barrier()
    
        shmem.info("Validating...")
    
        matmul_module.matmul_reduce_scatter.set_debug(False)
        # Validate global result
        success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
>       assert success, (
            f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
        )
E       AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=128, n=128, k=128, BLK_M=64, BLK_N=64, BLK_K=32
E       assert False

tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO     iris::0 Validating...
=========================== short test summary info ============================
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-64-64-64-dtype0] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=64, n=64, k=64, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-64-64-64-dtype1] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=64, n=64, k=64, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-64-64-64-dtype2] - RuntimeError: PassManager::run failed
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-128-128-128-dtype0] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=128, n=128, k=128, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-128-128-128-dtype1] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=128, n=128, k=128, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-128-128-128-dtype2] - RuntimeError: PassManager::run failed
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-256-256-256-dtype0] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=256, n=256, k=256, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-256-256-256-dtype1] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=256, n=256, k=256, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-256-256-256-dtype2] - RuntimeError: PassManager::run failed
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[64-64-32-128-128-128-dtype1] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=128, n=128, k=128, BLK_M=64, BLK_N=64, BLK_K=32
assert False
========================= 10 failed, 8 passed in 1.66s =========================
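For the failing sizes above, one quick way to narrow things down is to compare each rank's output slice against a plain PyTorch reference. Below is a minimal sketch, assuming the reduce-scatter semantics of this example (K partitioned across ranks, partial C tensors summed, each rank keeping its M/world_size rows); the helper name check_local_output is hypothetical and not part of the repo.

import torch

def check_local_output(A, B, local_output, rank, world_size, atol=2.0):
    # Hypothetical debugging helper: compare a rank's reduce-scatter slice
    # against a float32 PyTorch reference of C = A @ B.
    M = A.shape[0]
    rows_per_rank = M // world_size
    C_ref = torch.matmul(A.float(), B.float())
    start = rank * rows_per_rank
    expected = C_ref[start:start + rows_per_rank, :].to(local_output.dtype)
    return torch.allclose(local_output, expected, atol=atol)

Comparing against such a reference for the small shapes above can help separate kernel bugs from tolerance issues at low precision.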

@neoblizz changed the title from "add Gemm ReduceScatter" to "Atomic-based GEMM + ReduceScatter" on Sep 18, 2025
@danielhua23
Author

Hi @mawad-amd, it seems all checks have passed. We can merge this if you don't have further comments, and we can optimize performance and support more patterns in follow-up PRs.

@mawad-amd
Collaborator

@danielhua23, I am sorry for the false positives. I am fixing them now in a different PR, but there are still some tests failing. See here and here, and other ranks too. Could you please check these failing sizes again? The PR looks good to me otherwise.

@mawad-amd
Collaborator

I fixed the CI issues. Please feel free to disable tile sizes that are not suitable for the problem sizes.
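One way to do that is to skip unsuitable combinations at the top of the test. A minimal sketch, where maybe_skip_unsuitable and its conditions are assumptions and would need to match the kernel's real constraints:

import pytest

def maybe_skip_unsuitable(m, n, k, BLK_M, BLK_N, BLK_K, world_size):
    # Hypothetical guard: skip tiles that are larger than the problem, or a
    # per-rank K slice that is not a multiple of BLK_K.
    k_per_rank = k // world_size
    if BLK_M > m or BLK_N > n or BLK_K > k_per_rank:
        pytest.skip(f"tile ({BLK_M},{BLK_N},{BLK_K}) too large for m={m}, n={n}, k={k}, world_size={world_size}")
    if k_per_rank % BLK_K != 0:
        pytest.skip(f"K per rank ({k_per_rank}) is not a multiple of BLK_K ({BLK_K})")

Called at the start of test_gemm_reduce_scatter, this keeps the parametrize lists intact and marks unsupported combinations as skipped instead of failed.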

@danielhua23
Author

danielhua23 commented Sep 19, 2025

Hi @mawad-amd, something odd: all world_size=2 cases crash under the test, yet when I run benchmark.py with world_size=2 it works fine (see below). When I run test_gemm_rs_bench.py with world_size=2, it crashes. I can't figure out why. Could you please help take a look? Thanks a ton!

1. world_size=2: my local test_gemm_rs_bench.py reports the crash as well

# ... imports elided (torch, torch.distributed as dist, torch.multiprocessing as mp,
#     triton, iris, plus the example's matmul/validation modules and Timestamps) ...
def test_gemm_reduce_scatter(local_rank, world_size, init_url, dtype, m, n, k, BLK_M, BLK_N, BLK_K):
    """Worker function for PyTorch distributed execution."""
    dist.init_process_group(backend="nccl", init_method=init_url, world_size=world_size, rank=local_rank)
    heap_size = 1 << 30
    shmem = iris.iris(heap_size)
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()
    cu_count = shmem.get_cu_count()
    # GEMM
    datatype = dtype
    assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
    assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
    
    A = shmem.randn(m, k, device="cuda", dtype=datatype)
    B = shmem.randn(n, k, device="cuda", dtype=datatype).T
    C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    
    M = m
    N = n
    K = k
    # Splitting
    rows_per_gpu = k // world_size
    k = rows_per_gpu
    start_row = rank * rows_per_gpu
    end_row = start_row + rows_per_gpu
    local_B = B[start_row:end_row, :]
    local_A = A[:, start_row:end_row]

    compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
    local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
    
    total_blocks_M = triton.cdiv(m, BLK_M)
    total_blocks_N = triton.cdiv(n, BLK_N)
    total_tiles = total_blocks_M * total_blocks_N
    tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
    P = shmem.zeros(
        (288, BLK_M * BLK_N),
        device="cuda",
        dtype=torch.float32,
    )
    bias = None
    gemm_stream = torch.cuda.Stream()
    timestamps = Timestamps(num_tiles=total_tiles)

    def preamble():
        shmem.barrier()
        tile_completed.zero_()
        shmem.barrier()
       
    def run_experiment():
        nonlocal local_output
        nonlocal compute_buffer
        shmem.barrier()
        torch.cuda.nvtx.range_push("GEMM + Communication")
        with torch.cuda.stream(gemm_stream):
            local_output = matmul_module.matmul_reduce_scatter.apply(
                local_A,
                local_B,
                compute_buffer,
                local_output,
                bias,
                P,
                locks,
                tile_completed,
                rank,
                world_size,
                288,
                BLK_M,
                BLK_N,
                BLK_K,
                6,
                True,
                1,
                8,
                0,
                16,
                2,
                shmem.get_heap_bases(),
                cu_count,
                False,
                timestamps.mm_begin_timestamp,
                timestamps.mm_end_timestamp,
            )
        torch.cuda.nvtx.range_pop()
        shmem.barrier()
    # Synchronize across all GPUs
    shmem.barrier()
    run_experiment()
    shmem.barrier()
    preamble()
    shmem.barrier()
    
    shmem.info("Validating...")

    matmul_module.matmul_reduce_scatter.set_debug(False)
    # Validate global result
    success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
    assert success, (
        f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
    )
    
def main():
    num_ranks = 2

    init_url = "tcp://127.0.0.1:29500"
    dtype = torch.bfloat16

    m, n, k = (512, 512, 512)
    BLK_M, BLK_N, BLK_K = (64, 64, 32)
    mp.spawn(
        fn=test_gemm_reduce_scatter,
        args=(num_ranks, init_url, dtype, m, n, k, BLK_M, BLK_N, BLK_K),
        nprocs=num_ranks,
        join=True,
    )
if __name__ == "__main__":
    main()
2. world_size=2: run benchmark.py
root@mi300x8-:/workspace/iris/examples/13_gemm_reduce_scatter# python benchmark.py -v --datatype bf16 -r 2 --BLK_M 64 --BLK_N 64 --BLK_K 32 -m 512 -n 512 -k 512
M,N,K=512,512,256 ; BLK_M,N,K=64,64,32
Rank 1/2 responsible for 256 rows
total_blocks_M=8 x total_blocks_N=8 = total_tiles=64
total_tiles_streamk=64 + total_blocking_tiles=0 = total_tiles=64
M,N,K=512,512,256 ; BLK_M,N,K=64,64,32
total_programs_streamk=288
Rank 0/2 responsible for 256 rows
total_blocks_M=8 x total_blocks_N=8 = total_tiles=64
total_tiles_streamk=64 + total_blocking_tiles=0 = total_tiles=64
total_programs_streamk=288
69 registers used, 0 spills
68 registers used, 0 spills
[Iris] [1/2] Validating...
[Iris] [0/2] Validating...
[Iris] [1/2] Final C validation **passed**.
[Iris] [0/2] Final C validation **passed**.
[Iris] [0/2] Validation completed 

@mawad-amd
Collaborator

Sounds good. I will take a look.

@hebais

hebais commented Sep 24, 2025

Hello, it seems the idea here is different from the Triton-Distributed example: this example's GEMM kernel also fuses the final reduce step. Have you considered an implementation that uses two streams to overlap the GEMM and the reduce?

@mawad-amd
Collaborator

Hi @hebais! Great observation. The implementation in this PR is a fused sequential implementation. There are different ways to do the reduction (e.g., atomics, one-shot, ring) and different ways to achieve compute/communication overlap; see the chart below for the taxonomy we came up with. The two-stream solution would be the unfused producer-consumer variant. Iris's goal is to facilitate implementing all of these variants. FWIW, we implemented all the variants for GEMM + AllScatter (see Workgroup Specialization, Producer-Consumer, Bulk-Synchronous, and Fused-Sequential). There are generally different tradeoffs between the techniques, and I recommend you watch our talk here.

If you’d like to implement the producer-consumer pattern with Iris, you can start from this, and feel free to open draft PRs; we'd be happy to review and support you along the way! Also see this video where @neoblizz implements all the variants for GEMM + AllScatter.

image
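As a rough illustration of the unfused producer-consumer variant mentioned above, the two-stream plumbing could look like the sketch below. It only shows the stream/event synchronization, not the Iris implementation; launch_gemm_chunk and launch_reduce_scatter_chunk are hypothetical placeholders for separate GEMM and communication launches.

import torch

def producer_consumer_step(launch_gemm_chunk, launch_reduce_scatter_chunk, num_chunks):
    # Hypothetical two-stream overlap: GEMM chunks (producer) are launched on
    # one stream, and the reduce-scatter of each finished chunk (consumer) on
    # another, synchronized per chunk with CUDA events rather than being fused
    # into a single kernel as in this PR.
    gemm_stream = torch.cuda.Stream()
    comm_stream = torch.cuda.Stream()
    chunk_ready = [torch.cuda.Event() for _ in range(num_chunks)]

    for i in range(num_chunks):
        with torch.cuda.stream(gemm_stream):
            launch_gemm_chunk(i)                     # produce partial C for chunk i
            chunk_ready[i].record(gemm_stream)
        with torch.cuda.stream(comm_stream):
            comm_stream.wait_event(chunk_ready[i])   # consume chunk i only when ready
            launch_reduce_scatter_chunk(i)           # overlaps with GEMM of chunk i+1

    torch.cuda.synchronize()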

Successfully merging this pull request may close these issues.

[Feature]: GEMM + ReduceScatter