Atomic-based GEMM + ReduceScatter #169
base: main
Conversation
Thanks, Daniel, for your contributions. On a first pass this is looking mostly good, but could you please add a test similar to
https://github.com/ROCm/iris/blob/main/tests/examples/test_all_load_bench.py
You will need to call the _worker function directly and pass the test arguments, and you will need to modify _worker to return local_output so that you can validate it at the call site in the test.
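Something along these lines might work as a starting point. This is only a rough sketch, loosely modeled on test_all_load_bench.py: the import path, the parse_args helper, and the exact _worker signature are assumptions and would need to be adapted to the 13_gemm_reduce_scatter example.

```python
# Sketch of the requested test (hypothetical names, adapt to the actual example).
# Assumes the benchmark exposes parse_args() and that _worker is modified to
# return local_output instead of only benchmarking.
import pytest
import torch

from examples.gemm_reduce_scatter import benchmark  # hypothetical import path


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("m, n, k", [(64, 64, 64), (128, 128, 128), (256, 256, 256)])
@pytest.mark.parametrize("BLK_M, BLK_N, BLK_K", [(32, 32, 16), (64, 64, 32)])
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
    # Build the same argument namespace the benchmark's main() would build.
    args = benchmark.parse_args([])
    args.datatype = dtype
    args.m, args.n, args.k = m, n, k
    args.BLK_M, args.BLK_N, args.BLK_K = BLK_M, BLK_N, BLK_K

    # Call the worker directly (single process here) instead of spawning it,
    # and have it hand local_output back for validation at the call site.
    local_output = benchmark._worker(rank=0, world_size=1, args=args)

    # Basic sanity checks; for a full numerical check, _worker could also
    # return A and B so the test can compare against torch.matmul(A, B).
    assert local_output is not None
    assert local_output.shape == (m, n)  # m // world_size == m on a single rank
    assert torch.isfinite(local_output.float()).all()
```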
@danielhua23 I just tested this and some tests pass while others fail. I am not sure whether all of these configurations are supported, or whether we need to increase the tolerance for some data types; a rough per-dtype tolerance idea is sketched below, ahead of the full log.
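For the tolerance question, one option would be to scale atol with the data type before calling the validator. The values below are only illustrative, not tuned; the test in the log currently hard-codes atol=2 for every dtype.

```python
# Illustrative per-dtype tolerances (hypothetical values, not tuned).
import torch

ATOL = {
    torch.float32: 1e-3,
    torch.float16: 1e-1,
    torch.bfloat16: 5e-1,
}

# success = validation_module.validate_gemm_reduce_scatter(
#     A, B, local_output, rank, world_size, shmem, atol=ATOL[dtype]
# )
```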
Running tests/examples/test_gemm_reduce_scatter_bench.py with 1 ranks
args=['tests/examples/test_gemm_reduce_scatter_bench.py'], test_file=tests/examples/test_gemm_reduce_scatter_bench.py, pytest_args=[]
============================= test session starts ==============================
platform linux -- Python 3.10.18, pytest-7.3.2, pluggy-1.5.0
rootdir: /work1/amd/muhaawad/git/amd/pdp/iris
plugins: rerunfailures-14.0, xdoctest-1.2.0, hypothesis-5.35.1, xdist-3.3.1, cpp-2.3.0, anyio-4.10.0, flakefinder-1.1.0, typeguard-4.3.0
collected 18 items
tests/examples/test_gemm_reduce_scatter_bench.py FFFFFFFFF....F.... [100%]
=================================== FAILURES ===================================
______________ test_gemm_reduce_scatter[32-32-16-64-64-64-dtype0] ______________
dtype = torch.float16, m = 64, n = 64, k = 64, BLK_M = 32, BLK_N = 32
BLK_K = 16
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"m, n, k",
[
(64, 64, 64), # Very small for quick testing
(128, 128, 128), # Small
(256, 256, 256), # Medium
],
)
@pytest.mark.parametrize(
"BLK_M, BLK_N, BLK_K",
[
(32, 32, 16), # Small blocks
(64, 64, 32), # Medium blocks
],
)
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
"""Worker function for PyTorch distributed execution."""
heap_size = 1 << 30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()
# GEMM
datatype = dtype
assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
A = shmem.randn(m, k, device="cuda", dtype=datatype)
B = shmem.randn(n, k, device="cuda", dtype=datatype).T
C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
M = m
N = n
K = k
# Splitting
rows_per_gpu = k // world_size
k = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]
compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
P = shmem.zeros(
(288, BLK_M * BLK_N),
device="cuda",
dtype=torch.float32,
)
bias = None
gemm_stream = torch.cuda.Stream()
timestamps = Timestamps(num_tiles=total_tiles)
def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()
def run_experiment():
nonlocal local_output
nonlocal compute_buffer
shmem.barrier()
torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
local_output = matmul_module.matmul_reduce_scatter.apply(
local_A,
local_B,
compute_buffer,
local_output,
bias,
P,
locks,
tile_completed,
rank,
world_size,
288,
BLK_M,
BLK_N,
BLK_K,
6,
True,
1,
8,
0,
16,
2,
shmem.get_heap_bases(),
cu_count,
False,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
torch.cuda.nvtx.range_pop()
shmem.barrier()
# Synchronize across all GPUs
shmem.barrier()
run_experiment()
shmem.barrier()
preamble()
shmem.barrier()
shmem.info("Validating...")
matmul_module.matmul_reduce_scatter.set_debug(False)
# Validate global result
success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
> assert success, (
f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
)
E AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=64, n=64, k=64, BLK_M=32, BLK_N=32, BLK_K=16
E assert False
tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stdout call -----------------------------
M,N,K=64,64,64 ; BLK_M,N,K=32,32,16
Rank 0/1 responsible for 64 rows
total_blocks_M=2 x total_blocks_N=2 = total_tiles=4
total_tiles_streamk=4 + total_blocking_tiles=0 = total_tiles=4
total_programs_streamk=288
32 registers used, 0 spills
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO iris::0 Validating...
______________ test_gemm_reduce_scatter[32-32-16-64-64-64-dtype1] ______________
dtype = torch.bfloat16, m = 64, n = 64, k = 64, BLK_M = 32, BLK_N = 32
BLK_K = 16
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"m, n, k",
[
(64, 64, 64), # Very small for quick testing
(128, 128, 128), # Small
(256, 256, 256), # Medium
],
)
@pytest.mark.parametrize(
"BLK_M, BLK_N, BLK_K",
[
(32, 32, 16), # Small blocks
(64, 64, 32), # Medium blocks
],
)
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
"""Worker function for PyTorch distributed execution."""
heap_size = 1 << 30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()
# GEMM
datatype = dtype
assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
A = shmem.randn(m, k, device="cuda", dtype=datatype)
B = shmem.randn(n, k, device="cuda", dtype=datatype).T
C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
M = m
N = n
K = k
# Splitting
rows_per_gpu = k // world_size
k = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]
compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
P = shmem.zeros(
(288, BLK_M * BLK_N),
device="cuda",
dtype=torch.float32,
)
bias = None
gemm_stream = torch.cuda.Stream()
timestamps = Timestamps(num_tiles=total_tiles)
def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()
def run_experiment():
nonlocal local_output
nonlocal compute_buffer
shmem.barrier()
torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
local_output = matmul_module.matmul_reduce_scatter.apply(
local_A,
local_B,
compute_buffer,
local_output,
bias,
P,
locks,
tile_completed,
rank,
world_size,
288,
BLK_M,
BLK_N,
BLK_K,
6,
True,
1,
8,
0,
16,
2,
shmem.get_heap_bases(),
cu_count,
False,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
torch.cuda.nvtx.range_pop()
shmem.barrier()
# Synchronize across all GPUs
shmem.barrier()
run_experiment()
shmem.barrier()
preamble()
shmem.barrier()
shmem.info("Validating...")
matmul_module.matmul_reduce_scatter.set_debug(False)
# Validate global result
success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
> assert success, (
f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
)
E AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=64, n=64, k=64, BLK_M=32, BLK_N=32, BLK_K=16
E assert False
tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO iris::0 Validating...
______________ test_gemm_reduce_scatter[32-32-16-64-64-64-dtype2] ______________
dtype = torch.float32, m = 64, n = 64, k = 64, BLK_M = 32, BLK_N = 32
BLK_K = 16
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"m, n, k",
[
(64, 64, 64), # Very small for quick testing
(128, 128, 128), # Small
(256, 256, 256), # Medium
],
)
@pytest.mark.parametrize(
"BLK_M, BLK_N, BLK_K",
[
(32, 32, 16), # Small blocks
(64, 64, 32), # Medium blocks
],
)
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
"""Worker function for PyTorch distributed execution."""
heap_size = 1 << 30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()
# GEMM
datatype = dtype
assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
A = shmem.randn(m, k, device="cuda", dtype=datatype)
B = shmem.randn(n, k, device="cuda", dtype=datatype).T
C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
M = m
N = n
K = k
# Splitting
rows_per_gpu = k // world_size
k = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]
compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
P = shmem.zeros(
(288, BLK_M * BLK_N),
device="cuda",
dtype=torch.float32,
)
bias = None
gemm_stream = torch.cuda.Stream()
timestamps = Timestamps(num_tiles=total_tiles)
def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()
def run_experiment():
nonlocal local_output
nonlocal compute_buffer
shmem.barrier()
torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
local_output = matmul_module.matmul_reduce_scatter.apply(
local_A,
local_B,
compute_buffer,
local_output,
bias,
P,
locks,
tile_completed,
rank,
world_size,
288,
BLK_M,
BLK_N,
BLK_K,
6,
True,
1,
8,
0,
16,
2,
shmem.get_heap_bases(),
cu_count,
False,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
torch.cuda.nvtx.range_pop()
shmem.barrier()
# Synchronize across all GPUs
shmem.barrier()
> run_experiment()
tests/examples/test_gemm_reduce_scatter_bench.py:154:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/examples/test_gemm_reduce_scatter_bench.py:121: in run_experiment
local_output = matmul_module.matmul_reduce_scatter.apply(
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py:574: in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
examples/13_gemm_reduce_scatter/matmul_wrapper.py:179: in forward
result = matmul_reduce_scatter._call(
examples/13_gemm_reduce_scatter/matmul_wrapper.py:99: in _call
kk = gemm_kernel[(grids,)](
/workspace/triton/python/triton/runtime/jit.py:393: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
/workspace/triton/python/triton/runtime/jit.py:599: in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
/workspace/triton/python/triton/runtime/jit.py:782: in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
/workspace/triton/python/triton/compiler/compiler.py:322: in compile
next_module = compile_ir(module, metadata)
/workspace/triton/python/triton/backends/amd/compiler.py:449: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src = <triton._C.libtriton.ir.module object at 0x7faccfc95580>
metadata = {'allow_flush_denorm': False, 'allowed_dot_input_precisions': ('ieee', 'tf32'), 'arch': 'gfx942', 'backend_name': 'hip', ...}
options = HIPOptions(num_warps=8, waves_per_eu=0, num_stages=1, num_ctas=1, extern_libs=(('ocml', '/workspace/triton/python/trit...flush_denorm=False, max_num_imprecise_acc_default=0, backend_name='hip', instrumentation_mode='', schedule_hint='none')
@staticmethod
def make_llir(src, metadata, options):
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
# custom_lds_size is an experimental parameter that defines amount of LDS available
# for one thread block. Measured in bytes.
#
# If custom_lds_size = 0, pass will consider all LDS is available for one threads block,
# LDS size is determined by provided arch name.
custom_lds_size = 0
amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
passes.convert.add_triton_scf_to_cf(pm)
passes.convert.add_index_to_llvmir(pm)
amd.passes.ttgpuir.add_allocate_shared_memory(pm)
# instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
if HIPBackend.instrumentation:
HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
## of the value of kernel arg `allow_flush_denorm`.
## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
## depends on the value of kernel arg `allow_flush_denorm`.
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
## For now it is used as a controller for developers only.
__HIP_FTZ = True
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.convert.add_cf_to_llvmir(pm)
passes.convert.add_arith_to_llvmir(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if options.schedule_hint.lower() != "none":
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
# This can not be moved below the di_scope pass
if HIPBackend.instrumentation:
HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
if not knobs.compilation.disable_line_info:
passes.llvmir.add_di_scope(pm)
amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
> pm.run(mod)
E RuntimeError: PassManager::run failed
/workspace/triton/python/triton/backends/amd/compiler.py:324: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /root/.triton/llvm/llvm-57088512-ubuntu-x64/include/llvm/ADT/SmallVector.h:292: T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type) [with T = mlir::Value; <template-parameter-1-2> = void; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::reference = mlir::Value&; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type = long unsigned int]: Assertion `idx < size()' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 2], instrShape = [16, 16], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @persistent_gemm_reduce_scatter(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg6: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg16: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg17: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<32> : tensor<1x32xi32, #blocked>
%cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked>
%c32_i32 = arith.constant 32 : i32
%c6_i32 = arith.constant 6 : i32
%true = arith.constant true
%c36_i32 = arith.constant 36 : i32
%c8_i32 = arith.constant 8 : i32
%c15_i32 = arith.constant 15 : i32
%c16_i32 = arith.constant 16 : i32
%c31_i32 = arith.constant 31 : i32
%c1_i32 = arith.constant 1 : i32
%c288_i32 = arith.constant 288 : i32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
%0 = tt.get_program_id x : i32
%1 = arith.remsi %0, %c8_i32 : i32
%2 = arith.muli %1, %c36_i32 : i32
%3 = arith.divsi %0, %c8_i32 : i32
%4 = arith.addi %2, %3 : i32
%5 = arith.addi %arg7, %c31_i32 : i32
%6 = arith.divsi %5, %c32_i32 : i32
%7 = arith.addi %arg8, %c31_i32 : i32
%8 = arith.divsi %7, %c32_i32 : i32
%9 = arith.muli %6, %8 : i32
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
%10 = arith.muli %8, %c6_i32 : i32
%11 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%12 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%13 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%14 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%15 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%16 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%17 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%18 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%19 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%20 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2>
%22 = tt.broadcast %21 : tensor<16x1xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
%23 = arith.addi %arg9, %c15_i32 : i32
%24 = arith.divsi %23, %c16_i32 : i32
%25 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #mma>
%26 = tt.splat %arg8 : i32 -> tensor<1x32xi32, #mma>
%27 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%28 = tt.expand_dims %27 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
%29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
%31 = arith.cmpi slt, %28, %cst_0 : tensor<32x1xi32, #blocked>
%32 = arith.cmpi slt, %30, %cst : tensor<1x32xi32, #blocked>
%33 = tt.broadcast %31 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
%34 = tt.broadcast %32 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked>
%35 = arith.andi %33, %34 : tensor<32x32xi1, #blocked>
scf.for %arg18 = %4 to %9 step %c288_i32 : i32 {
%36 = arith.divsi %arg18, %10 : i32
%37 = arith.muli %36, %c6_i32 : i32
%38 = arith.subi %6, %37 : i32
%39 = arith.minsi %38, %c6_i32 : i32
%40 = arith.remsi %arg18, %10 : i32
%41 = arith.remsi %40, %39 : i32
%42 = arith.addi %37, %41 : i32
%43 = arith.divsi %40, %39 : i32
%44 = arith.muli %42, %c32_i32 : i32
%45 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%46 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%47 = arith.addi %45, %11 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%48 = arith.addi %46, %12 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%49 = arith.remsi %47, %15 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%50 = arith.remsi %48, %16 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%51 = arith.muli %43, %c32_i32 : i32
%52 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%53 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%54 = arith.addi %52, %13 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%55 = arith.addi %53, %14 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%56 = arith.remsi %54, %17 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%57 = arith.remsi %55, %18 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%58 = tt.expand_dims %50 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
%59 = tt.expand_dims %49 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
%60 = tt.splat %arg10 : i32 -> tensor<32x1xi32, #blocked1>
%61 = arith.muli %59, %60 : tensor<32x1xi32, #blocked1>
%62 = tt.broadcast %61 : tensor<32x1xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
%63 = tt.expand_dims %19 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1>
%64 = tt.broadcast %63 : tensor<1x16xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
%65 = arith.addi %64, %62 : tensor<32x16xi32, #blocked1>
%66 = tt.expand_dims %57 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
%67 = tt.expand_dims %56 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2>
%68 = tt.splat %arg11 : i32 -> tensor<1x32xi32, #blocked2>
%69 = arith.muli %67, %68 : tensor<1x32xi32, #blocked2>
%70 = tt.broadcast %69 : tensor<1x32xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
%71 = arith.addi %70, %22 : tensor<16x32xi32, #blocked2>
%72:3 = scf.for %arg19 = %c0_i32 to %24 step %c1_i32 iter_args(%arg20 = %arg0, %arg21 = %cst_2, %arg22 = %arg1) -> (!tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>) : i32 {
%91 = tt.splat %arg20 : !tt.ptr<f32> -> tensor<32x16x!tt.ptr<f32>, #blocked1>
%92 = tt.addptr %91, %65 : tensor<32x16x!tt.ptr<f32>, #blocked1>, tensor<32x16xi32, #blocked1>
%93 = tt.load %92 : tensor<32x16x!tt.ptr<f32>, #blocked1>
%94 = tt.splat %arg22 : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>, #blocked2>
%95 = tt.addptr %94, %71 : tensor<16x32x!tt.ptr<f32>, #blocked2>, tensor<16x32xi32, #blocked2>
%96 = tt.load %95 : tensor<16x32x!tt.ptr<f32>, #blocked2>
%97 = ttg.local_alloc %93 : (tensor<32x16xf32, #blocked1>) -> !ttg.memdesc<32x16xf32, #shared, #smem>
%98 = ttg.local_load %97 : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%99 = ttg.local_alloc %96 : (tensor<16x32xf32, #blocked2>) -> !ttg.memdesc<16x32xf32, #shared1, #smem>
%100 = ttg.local_load %99 : !ttg.memdesc<16x32xf32, #shared1, #smem> -> tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%101 = tt.dot %98, %100, %arg21, inputPrecision = tf32 : tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
%102 = tt.addptr %arg20, %c16_i32 : !tt.ptr<f32>, i32
%103 = tt.addptr %arg22, %c16_i32 : !tt.ptr<f32>, i32
scf.yield %102, %101, %103 : !tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>
} {tt.divisibility_arg1 = dense<[1, 16]> : tensor<2xi32>, tt.divisibility_arg2 = dense<[1, 16]> : tensor<2xi32>}
%73 = arith.cmpi slt, %58, %25 : tensor<32x1xi32, #mma>
%74 = arith.cmpi slt, %66, %26 : tensor<1x32xi32, #mma>
%75 = tt.broadcast %73 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma>
%76 = tt.broadcast %74 : tensor<1x32xi1, #mma> -> tensor<32x32xi1, #mma>
%77 = arith.andi %75, %76 : tensor<32x32xi1, #mma>
%78 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #mma>
%79 = arith.muli %58, %78 : tensor<32x1xi32, #mma>
%80 = tt.broadcast %79 : tensor<32x1xi32, #mma> -> tensor<32x32xi32, #mma>
%81 = tt.broadcast %66 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma>
%82 = arith.addi %81, %80 : tensor<32x32xi32, #mma>
%83 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #mma>
%84 = tt.addptr %83, %82 : tensor<32x32x!tt.ptr<f32>, #mma>, tensor<32x32xi32, #mma>
tt.store %84, %72#1, %77 : tensor<32x32x!tt.ptr<f32>, #mma>
%85 = tt.addptr %arg6, %arg18 : !tt.ptr<i32>, i32
%86 = tt.ptr_to_int %85 : !tt.ptr<i32> -> i64
scf.while (%arg19 = %c0_i32) : (i32) -> () {
%91 = arith.cmpi slt, %arg19, %c0_i32 : i32
scf.condition(%91)
} do {
%91 = tt.load %arg15 : !tt.ptr<i64>
%92 = arith.subi %86, %91 : i64
%93 = tt.int_to_ptr %91 : i64 -> !tt.ptr<i8>
%94 = tt.addptr %93, %92 : !tt.ptr<i8>, i64
%95 = tt.bitcast %94 : !tt.ptr<i8> -> !tt.ptr<i32>
%96 = tt.atomic_cas acquire, sys, %95, %c0_i32, %c0_i32 : (!tt.ptr<i32>, i32, i32) -> i32
scf.yield %96 : i32
}
%87 = arith.cmpi slt, %44, %arg7 : i32
%88 = arith.addi %44, %c32_i32 : i32
%89 = arith.cmpi sgt, %88, %c0_i32 : i32
%90 = arith.andi %87, %89 : i1
scf.if %90 {
%91 = arith.subi %c0_i32, %44 : i32
%92 = arith.maxsi %91, %c0_i32 : i32
%93 = arith.subi %arg7, %44 : i32
%94 = arith.minsi %93, %c32_i32 : i32
%95 = arith.maxsi %44, %c0_i32 : i32
%96 = arith.muli %44, %arg12 : i32
%97 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked>
%98 = arith.muli %28, %97 : tensor<32x1xi32, #blocked>
%99 = tt.broadcast %98 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
%100 = tt.broadcast %30 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked>
%101 = arith.addi %96, %51 : i32
%102 = arith.addi %99, %100 : tensor<32x32xi32, #blocked>
%103 = tt.addptr %arg2, %101 : !tt.ptr<f32>, i32
%104 = tt.load %arg15 : !tt.ptr<i64>
%105 = tt.splat %103 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
%106 = tt.addptr %105, %102 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
%107 = tt.ptr_to_int %106 : tensor<32x32x!tt.ptr<f32>, #blocked> -> tensor<32x32xi64, #blocked>
%108 = tt.splat %104 : i64 -> tensor<32x32xi64, #blocked>
%109 = arith.subi %107, %108 : tensor<32x32xi64, #blocked>
%110 = tt.int_to_ptr %104 : i64 -> !tt.ptr<i8>
%111 = tt.splat %110 : !tt.ptr<i8> -> tensor<32x32x!tt.ptr<i8>, #blocked>
%112 = tt.addptr %111, %109 : tensor<32x32x!tt.ptr<i8>, #blocked>, tensor<32x32xi64, #blocked>
%113 = tt.bitcast %112 : tensor<32x32x!tt.ptr<i8>, #blocked> -> tensor<32x32x!tt.ptr<f32>, #blocked>
%114 = tt.load %113, %35 : tensor<32x32x!tt.ptr<f32>, #blocked>
%115 = arith.addf %114, %cst_1 : tensor<32x32xf32, #blocked>
%116 = arith.muli %95, %arg13 : i32
%117 = tt.splat %arg13 : i32 -> tensor<32x1xi32, #blocked>
%118 = arith.muli %28, %117 : tensor<32x1xi32, #blocked>
%119 = tt.broadcast %118 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
%120 = arith.addi %116, %51 : i32
%121 = arith.addi %119, %100 : tensor<32x32xi32, #blocked>
%122 = tt.addptr %arg3, %120 : !tt.ptr<f32>, i32
%123 = tt.splat %92 : i32 -> tensor<32x1xi32, #blocked>
%124 = arith.cmpi sge, %28, %123 : tensor<32x1xi32, #blocked>
%125 = tt.splat %94 : i32 -> tensor<32x1xi32, #blocked>
%126 = arith.cmpi slt, %28, %125 : tensor<32x1xi32, #blocked>
%127 = arith.andi %124, %126 : tensor<32x1xi1, #blocked>
%128 = tt.broadcast %127 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
%129 = arith.andi %128, %34 : tensor<32x32xi1, #blocked>
%130 = arith.andi %129, %35 : tensor<32x32xi1, #blocked>
%131 = tt.splat %122 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
%132 = tt.addptr %131, %121 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
tt.store %132, %115, %130 cacheModifier = wt : tensor<32x32x!tt.ptr<f32>, #blocked>
}
}
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(optimize-amd-lds-usage{lds-limit=0 target-arch=gfx942}, triton-scf-to-cf, convert-index-to-llvm{index-bitwidth=0}, allocate-amdgpu-shared-memory, convert-triton-amdgpu-to-llvm{arch=gfx942 ftz=true}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-cf-to-llvm{index-bitwidth=0}, convert-arith-to-llvm{index-bitwidth=0}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info, convert-builtin-func-to-llvm{ftz=true})",
disable_threading: false,
verify_each: true
}
}
#-}
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: error: Failures have been detected while processing an MLIR pass pipeline
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: note: Pipeline failed while executing [`ConvertTritonAMDGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
____________ test_gemm_reduce_scatter[32-32-16-128-128-128-dtype0] _____________
dtype = torch.float16, m = 128, n = 128, k = 128, BLK_M = 32, BLK_N = 32
BLK_K = 16
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"m, n, k",
[
(64, 64, 64), # Very small for quick testing
(128, 128, 128), # Small
(256, 256, 256), # Medium
],
)
@pytest.mark.parametrize(
"BLK_M, BLK_N, BLK_K",
[
(32, 32, 16), # Small blocks
(64, 64, 32), # Medium blocks
],
)
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
"""Worker function for PyTorch distributed execution."""
heap_size = 1 << 30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()
# GEMM
datatype = dtype
assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
A = shmem.randn(m, k, device="cuda", dtype=datatype)
B = shmem.randn(n, k, device="cuda", dtype=datatype).T
C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
M = m
N = n
K = k
# Splitting
rows_per_gpu = k // world_size
k = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]
compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
P = shmem.zeros(
(288, BLK_M * BLK_N),
device="cuda",
dtype=torch.float32,
)
bias = None
gemm_stream = torch.cuda.Stream()
timestamps = Timestamps(num_tiles=total_tiles)
def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()
def run_experiment():
nonlocal local_output
nonlocal compute_buffer
shmem.barrier()
torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
local_output = matmul_module.matmul_reduce_scatter.apply(
local_A,
local_B,
compute_buffer,
local_output,
bias,
P,
locks,
tile_completed,
rank,
world_size,
288,
BLK_M,
BLK_N,
BLK_K,
6,
True,
1,
8,
0,
16,
2,
shmem.get_heap_bases(),
cu_count,
False,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
torch.cuda.nvtx.range_pop()
shmem.barrier()
# Synchronize across all GPUs
shmem.barrier()
run_experiment()
shmem.barrier()
preamble()
shmem.barrier()
shmem.info("Validating...")
matmul_module.matmul_reduce_scatter.set_debug(False)
# Validate global result
success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
> assert success, (
f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
)
E AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=128, n=128, k=128, BLK_M=32, BLK_N=32, BLK_K=16
E assert False
tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO iris::0 Validating...
____________ test_gemm_reduce_scatter[32-32-16-128-128-128-dtype1] _____________
dtype = torch.bfloat16, m = 128, n = 128, k = 128, BLK_M = 32, BLK_N = 32
BLK_K = 16
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"m, n, k",
[
(64, 64, 64), # Very small for quick testing
(128, 128, 128), # Small
(256, 256, 256), # Medium
],
)
@pytest.mark.parametrize(
"BLK_M, BLK_N, BLK_K",
[
(32, 32, 16), # Small blocks
(64, 64, 32), # Medium blocks
],
)
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
"""Worker function for PyTorch distributed execution."""
heap_size = 1 << 30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()
# GEMM
datatype = dtype
assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
A = shmem.randn(m, k, device="cuda", dtype=datatype)
B = shmem.randn(n, k, device="cuda", dtype=datatype).T
C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
M = m
N = n
K = k
# Splitting
rows_per_gpu = k // world_size
k = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]
compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
P = shmem.zeros(
(288, BLK_M * BLK_N),
device="cuda",
dtype=torch.float32,
)
bias = None
gemm_stream = torch.cuda.Stream()
timestamps = Timestamps(num_tiles=total_tiles)
def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()
def run_experiment():
nonlocal local_output
nonlocal compute_buffer
shmem.barrier()
torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
local_output = matmul_module.matmul_reduce_scatter.apply(
local_A,
local_B,
compute_buffer,
local_output,
bias,
P,
locks,
tile_completed,
rank,
world_size,
288,
BLK_M,
BLK_N,
BLK_K,
6,
True,
1,
8,
0,
16,
2,
shmem.get_heap_bases(),
cu_count,
False,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
torch.cuda.nvtx.range_pop()
shmem.barrier()
# Synchronize across all GPUs
shmem.barrier()
run_experiment()
shmem.barrier()
preamble()
shmem.barrier()
shmem.info("Validating...")
matmul_module.matmul_reduce_scatter.set_debug(False)
# Validate global result
success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
> assert success, (
f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
)
E AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=128, n=128, k=128, BLK_M=32, BLK_N=32, BLK_K=16
E assert False
tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO iris::0 Validating...
____________ test_gemm_reduce_scatter[32-32-16-128-128-128-dtype2] _____________
dtype = torch.float32, m = 128, n = 128, k = 128, BLK_M = 32, BLK_N = 32
BLK_K = 16
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"m, n, k",
[
(64, 64, 64), # Very small for quick testing
(128, 128, 128), # Small
(256, 256, 256), # Medium
],
)
@pytest.mark.parametrize(
"BLK_M, BLK_N, BLK_K",
[
(32, 32, 16), # Small blocks
(64, 64, 32), # Medium blocks
],
)
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
"""Worker function for PyTorch distributed execution."""
heap_size = 1 << 30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()
# GEMM
datatype = dtype
assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
A = shmem.randn(m, k, device="cuda", dtype=datatype)
B = shmem.randn(n, k, device="cuda", dtype=datatype).T
C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
M = m
N = n
K = k
# Splitting
rows_per_gpu = k // world_size
k = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]
compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
P = shmem.zeros(
(288, BLK_M * BLK_N),
device="cuda",
dtype=torch.float32,
)
bias = None
gemm_stream = torch.cuda.Stream()
timestamps = Timestamps(num_tiles=total_tiles)
def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()
def run_experiment():
nonlocal local_output
nonlocal compute_buffer
shmem.barrier()
torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
local_output = matmul_module.matmul_reduce_scatter.apply(
local_A,
local_B,
compute_buffer,
local_output,
bias,
P,
locks,
tile_completed,
rank,
world_size,
288,
BLK_M,
BLK_N,
BLK_K,
6,
True,
1,
8,
0,
16,
2,
shmem.get_heap_bases(),
cu_count,
False,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
torch.cuda.nvtx.range_pop()
shmem.barrier()
# Synchronize across all GPUs
shmem.barrier()
> run_experiment()
tests/examples/test_gemm_reduce_scatter_bench.py:154:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/examples/test_gemm_reduce_scatter_bench.py:121: in run_experiment
local_output = matmul_module.matmul_reduce_scatter.apply(
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py:574: in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
examples/13_gemm_reduce_scatter/matmul_wrapper.py:179: in forward
result = matmul_reduce_scatter._call(
examples/13_gemm_reduce_scatter/matmul_wrapper.py:99: in _call
kk = gemm_kernel[(grids,)](
/workspace/triton/python/triton/runtime/jit.py:393: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
/workspace/triton/python/triton/runtime/jit.py:599: in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
/workspace/triton/python/triton/runtime/jit.py:782: in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
/workspace/triton/python/triton/compiler/compiler.py:322: in compile
next_module = compile_ir(module, metadata)
/workspace/triton/python/triton/backends/amd/compiler.py:449: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src = <triton._C.libtriton.ir.module object at 0x7fa4cfa70590>
metadata = {'allow_flush_denorm': False, 'allowed_dot_input_precisions': ('ieee', 'tf32'), 'arch': 'gfx942', 'backend_name': 'hip', ...}
options = HIPOptions(num_warps=8, waves_per_eu=0, num_stages=1, num_ctas=1, extern_libs=(('ocml', '/workspace/triton/python/trit...flush_denorm=False, max_num_imprecise_acc_default=0, backend_name='hip', instrumentation_mode='', schedule_hint='none')
@staticmethod
def make_llir(src, metadata, options):
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
# custom_lds_size is an experimental parameter that defines amount of LDS available
# for one thread block. Measured in bytes.
#
# If custom_lds_size = 0, pass will consider all LDS is available for one threads block,
# LDS size is determined by provided arch name.
custom_lds_size = 0
amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
passes.convert.add_triton_scf_to_cf(pm)
passes.convert.add_index_to_llvmir(pm)
amd.passes.ttgpuir.add_allocate_shared_memory(pm)
# instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
if HIPBackend.instrumentation:
HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
## of the value of kernel arg `allow_flush_denorm`.
## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
## depends on the value of kernel arg `allow_flush_denorm`.
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
## For now it is used as a controller for developers only.
__HIP_FTZ = True
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.convert.add_cf_to_llvmir(pm)
passes.convert.add_arith_to_llvmir(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if options.schedule_hint.lower() != "none":
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
# This can not be moved below the di_scope pass
if HIPBackend.instrumentation:
HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
if not knobs.compilation.disable_line_info:
passes.llvmir.add_di_scope(pm)
amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
> pm.run(mod)
E RuntimeError: PassManager::run failed
/workspace/triton/python/triton/backends/amd/compiler.py:324: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /root/.triton/llvm/llvm-57088512-ubuntu-x64/include/llvm/ADT/SmallVector.h:292: T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type) [with T = mlir::Value; <template-parameter-1-2> = void; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::reference = mlir::Value&; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type = long unsigned int]: Assertion `idx < size()' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 2], instrShape = [16, 16], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @persistent_gemm_reduce_scatter(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg6: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg16: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg17: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<32> : tensor<1x32xi32, #blocked>
%cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked>
%c32_i32 = arith.constant 32 : i32
%c6_i32 = arith.constant 6 : i32
%true = arith.constant true
%c36_i32 = arith.constant 36 : i32
%c8_i32 = arith.constant 8 : i32
%c15_i32 = arith.constant 15 : i32
%c16_i32 = arith.constant 16 : i32
%c31_i32 = arith.constant 31 : i32
%c1_i32 = arith.constant 1 : i32
%c288_i32 = arith.constant 288 : i32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
%0 = tt.get_program_id x : i32
%1 = arith.remsi %0, %c8_i32 : i32
%2 = arith.muli %1, %c36_i32 : i32
%3 = arith.divsi %0, %c8_i32 : i32
%4 = arith.addi %2, %3 : i32
%5 = arith.addi %arg7, %c31_i32 : i32
%6 = arith.divsi %5, %c32_i32 : i32
%7 = arith.addi %arg8, %c31_i32 : i32
%8 = arith.divsi %7, %c32_i32 : i32
%9 = arith.muli %6, %8 : i32
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
%10 = arith.muli %8, %c6_i32 : i32
%11 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%12 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%13 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%14 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%15 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%16 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%17 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%18 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%19 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%20 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2>
%22 = tt.broadcast %21 : tensor<16x1xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
%23 = arith.addi %arg9, %c15_i32 : i32
%24 = arith.divsi %23, %c16_i32 : i32
%25 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #mma>
%26 = tt.splat %arg8 : i32 -> tensor<1x32xi32, #mma>
%27 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%28 = tt.expand_dims %27 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
%29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
%31 = arith.cmpi slt, %28, %cst_0 : tensor<32x1xi32, #blocked>
%32 = arith.cmpi slt, %30, %cst : tensor<1x32xi32, #blocked>
%33 = tt.broadcast %31 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
%34 = tt.broadcast %32 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked>
%35 = arith.andi %33, %34 : tensor<32x32xi1, #blocked>
scf.for %arg18 = %4 to %9 step %c288_i32 : i32 {
%36 = arith.divsi %arg18, %10 : i32
%37 = arith.muli %36, %c6_i32 : i32
%38 = arith.subi %6, %37 : i32
%39 = arith.minsi %38, %c6_i32 : i32
%40 = arith.remsi %arg18, %10 : i32
%41 = arith.remsi %40, %39 : i32
%42 = arith.addi %37, %41 : i32
%43 = arith.divsi %40, %39 : i32
%44 = arith.muli %42, %c32_i32 : i32
%45 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%46 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%47 = arith.addi %45, %11 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%48 = arith.addi %46, %12 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%49 = arith.remsi %47, %15 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%50 = arith.remsi %48, %16 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%51 = arith.muli %43, %c32_i32 : i32
%52 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%53 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%54 = arith.addi %52, %13 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%55 = arith.addi %53, %14 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%56 = arith.remsi %54, %17 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%57 = arith.remsi %55, %18 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%58 = tt.expand_dims %50 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
%59 = tt.expand_dims %49 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
%60 = tt.splat %arg10 : i32 -> tensor<32x1xi32, #blocked1>
%61 = arith.muli %59, %60 : tensor<32x1xi32, #blocked1>
%62 = tt.broadcast %61 : tensor<32x1xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
%63 = tt.expand_dims %19 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1>
%64 = tt.broadcast %63 : tensor<1x16xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
%65 = arith.addi %64, %62 : tensor<32x16xi32, #blocked1>
%66 = tt.expand_dims %57 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
%67 = tt.expand_dims %56 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2>
%68 = tt.splat %arg11 : i32 -> tensor<1x32xi32, #blocked2>
%69 = arith.muli %67, %68 : tensor<1x32xi32, #blocked2>
%70 = tt.broadcast %69 : tensor<1x32xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
%71 = arith.addi %70, %22 : tensor<16x32xi32, #blocked2>
%72:3 = scf.for %arg19 = %c0_i32 to %24 step %c1_i32 iter_args(%arg20 = %arg0, %arg21 = %cst_2, %arg22 = %arg1) -> (!tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>) : i32 {
%91 = tt.splat %arg20 : !tt.ptr<f32> -> tensor<32x16x!tt.ptr<f32>, #blocked1>
%92 = tt.addptr %91, %65 : tensor<32x16x!tt.ptr<f32>, #blocked1>, tensor<32x16xi32, #blocked1>
%93 = tt.load %92 : tensor<32x16x!tt.ptr<f32>, #blocked1>
%94 = tt.splat %arg22 : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>, #blocked2>
%95 = tt.addptr %94, %71 : tensor<16x32x!tt.ptr<f32>, #blocked2>, tensor<16x32xi32, #blocked2>
%96 = tt.load %95 : tensor<16x32x!tt.ptr<f32>, #blocked2>
%97 = ttg.local_alloc %93 : (tensor<32x16xf32, #blocked1>) -> !ttg.memdesc<32x16xf32, #shared, #smem>
%98 = ttg.local_load %97 : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%99 = ttg.local_alloc %96 : (tensor<16x32xf32, #blocked2>) -> !ttg.memdesc<16x32xf32, #shared1, #smem>
%100 = ttg.local_load %99 : !ttg.memdesc<16x32xf32, #shared1, #smem> -> tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%101 = tt.dot %98, %100, %arg21, inputPrecision = tf32 : tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
%102 = tt.addptr %arg20, %c16_i32 : !tt.ptr<f32>, i32
%103 = tt.addptr %arg22, %c16_i32 : !tt.ptr<f32>, i32
scf.yield %102, %101, %103 : !tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>
} {tt.divisibility_arg1 = dense<[1, 16]> : tensor<2xi32>, tt.divisibility_arg2 = dense<[1, 16]> : tensor<2xi32>}
%73 = arith.cmpi slt, %58, %25 : tensor<32x1xi32, #mma>
%74 = arith.cmpi slt, %66, %26 : tensor<1x32xi32, #mma>
%75 = tt.broadcast %73 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma>
%76 = tt.broadcast %74 : tensor<1x32xi1, #mma> -> tensor<32x32xi1, #mma>
%77 = arith.andi %75, %76 : tensor<32x32xi1, #mma>
%78 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #mma>
%79 = arith.muli %58, %78 : tensor<32x1xi32, #mma>
%80 = tt.broadcast %79 : tensor<32x1xi32, #mma> -> tensor<32x32xi32, #mma>
%81 = tt.broadcast %66 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma>
%82 = arith.addi %81, %80 : tensor<32x32xi32, #mma>
%83 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #mma>
%84 = tt.addptr %83, %82 : tensor<32x32x!tt.ptr<f32>, #mma>, tensor<32x32xi32, #mma>
tt.store %84, %72#1, %77 : tensor<32x32x!tt.ptr<f32>, #mma>
%85 = tt.addptr %arg6, %arg18 : !tt.ptr<i32>, i32
%86 = tt.ptr_to_int %85 : !tt.ptr<i32> -> i64
scf.while (%arg19 = %c0_i32) : (i32) -> () {
%91 = arith.cmpi slt, %arg19, %c0_i32 : i32
scf.condition(%91)
} do {
%91 = tt.load %arg15 : !tt.ptr<i64>
%92 = arith.subi %86, %91 : i64
%93 = tt.int_to_ptr %91 : i64 -> !tt.ptr<i8>
%94 = tt.addptr %93, %92 : !tt.ptr<i8>, i64
%95 = tt.bitcast %94 : !tt.ptr<i8> -> !tt.ptr<i32>
%96 = tt.atomic_cas acquire, sys, %95, %c0_i32, %c0_i32 : (!tt.ptr<i32>, i32, i32) -> i32
scf.yield %96 : i32
}
%87 = arith.cmpi slt, %44, %arg7 : i32
%88 = arith.addi %44, %c32_i32 : i32
%89 = arith.cmpi sgt, %88, %c0_i32 : i32
%90 = arith.andi %87, %89 : i1
scf.if %90 {
%91 = arith.subi %c0_i32, %44 : i32
%92 = arith.maxsi %91, %c0_i32 : i32
%93 = arith.subi %arg7, %44 : i32
%94 = arith.minsi %93, %c32_i32 : i32
%95 = arith.maxsi %44, %c0_i32 : i32
%96 = arith.muli %44, %arg12 : i32
%97 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked>
%98 = arith.muli %28, %97 : tensor<32x1xi32, #blocked>
%99 = tt.broadcast %98 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
%100 = tt.broadcast %30 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked>
%101 = arith.addi %96, %51 : i32
%102 = arith.addi %99, %100 : tensor<32x32xi32, #blocked>
%103 = tt.addptr %arg2, %101 : !tt.ptr<f32>, i32
%104 = tt.load %arg15 : !tt.ptr<i64>
%105 = tt.splat %103 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
%106 = tt.addptr %105, %102 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
%107 = tt.ptr_to_int %106 : tensor<32x32x!tt.ptr<f32>, #blocked> -> tensor<32x32xi64, #blocked>
%108 = tt.splat %104 : i64 -> tensor<32x32xi64, #blocked>
%109 = arith.subi %107, %108 : tensor<32x32xi64, #blocked>
%110 = tt.int_to_ptr %104 : i64 -> !tt.ptr<i8>
%111 = tt.splat %110 : !tt.ptr<i8> -> tensor<32x32x!tt.ptr<i8>, #blocked>
%112 = tt.addptr %111, %109 : tensor<32x32x!tt.ptr<i8>, #blocked>, tensor<32x32xi64, #blocked>
%113 = tt.bitcast %112 : tensor<32x32x!tt.ptr<i8>, #blocked> -> tensor<32x32x!tt.ptr<f32>, #blocked>
%114 = tt.load %113, %35 : tensor<32x32x!tt.ptr<f32>, #blocked>
%115 = arith.addf %114, %cst_1 : tensor<32x32xf32, #blocked>
%116 = arith.muli %95, %arg13 : i32
%117 = tt.splat %arg13 : i32 -> tensor<32x1xi32, #blocked>
%118 = arith.muli %28, %117 : tensor<32x1xi32, #blocked>
%119 = tt.broadcast %118 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
%120 = arith.addi %116, %51 : i32
%121 = arith.addi %119, %100 : tensor<32x32xi32, #blocked>
%122 = tt.addptr %arg3, %120 : !tt.ptr<f32>, i32
%123 = tt.splat %92 : i32 -> tensor<32x1xi32, #blocked>
%124 = arith.cmpi sge, %28, %123 : tensor<32x1xi32, #blocked>
%125 = tt.splat %94 : i32 -> tensor<32x1xi32, #blocked>
%126 = arith.cmpi slt, %28, %125 : tensor<32x1xi32, #blocked>
%127 = arith.andi %124, %126 : tensor<32x1xi1, #blocked>
%128 = tt.broadcast %127 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
%129 = arith.andi %128, %34 : tensor<32x32xi1, #blocked>
%130 = arith.andi %129, %35 : tensor<32x32xi1, #blocked>
%131 = tt.splat %122 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
%132 = tt.addptr %131, %121 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
tt.store %132, %115, %130 cacheModifier = wt : tensor<32x32x!tt.ptr<f32>, #blocked>
}
}
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(optimize-amd-lds-usage{lds-limit=0 target-arch=gfx942}, triton-scf-to-cf, convert-index-to-llvm{index-bitwidth=0}, allocate-amdgpu-shared-memory, convert-triton-amdgpu-to-llvm{arch=gfx942 ftz=true}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-cf-to-llvm{index-bitwidth=0}, convert-arith-to-llvm{index-bitwidth=0}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info, convert-builtin-func-to-llvm{ftz=true})",
disable_threading: false,
verify_each: true
}
}
#-}
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: error: Failures have been detected while processing an MLIR pass pipeline
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: note: Pipeline failed while executing [`ConvertTritonAMDGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
____________ test_gemm_reduce_scatter[32-32-16-256-256-256-dtype0] _____________
dtype = torch.float16, m = 256, n = 256, k = 256, BLK_M = 32, BLK_N = 32
BLK_K = 16
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"m, n, k",
[
(64, 64, 64), # Very small for quick testing
(128, 128, 128), # Small
(256, 256, 256), # Medium
],
)
@pytest.mark.parametrize(
"BLK_M, BLK_N, BLK_K",
[
(32, 32, 16), # Small blocks
(64, 64, 32), # Medium blocks
],
)
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
"""Worker function for PyTorch distributed execution."""
heap_size = 1 << 30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()
# GEMM
datatype = dtype
assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
A = shmem.randn(m, k, device="cuda", dtype=datatype)
B = shmem.randn(n, k, device="cuda", dtype=datatype).T
C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
M = m
N = n
K = k
# Splitting
rows_per_gpu = k // world_size
k = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]
compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
P = shmem.zeros(
(288, BLK_M * BLK_N),
device="cuda",
dtype=torch.float32,
)
bias = None
gemm_stream = torch.cuda.Stream()
timestamps = Timestamps(num_tiles=total_tiles)
def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()
def run_experiment():
nonlocal local_output
nonlocal compute_buffer
shmem.barrier()
torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
local_output = matmul_module.matmul_reduce_scatter.apply(
local_A,
local_B,
compute_buffer,
local_output,
bias,
P,
locks,
tile_completed,
rank,
world_size,
288,
BLK_M,
BLK_N,
BLK_K,
6,
True,
1,
8,
0,
16,
2,
shmem.get_heap_bases(),
cu_count,
False,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
torch.cuda.nvtx.range_pop()
shmem.barrier()
# Synchronize across all GPUs
shmem.barrier()
run_experiment()
shmem.barrier()
preamble()
shmem.barrier()
shmem.info("Validating...")
matmul_module.matmul_reduce_scatter.set_debug(False)
# Validate global result
success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
> assert success, (
f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
)
E AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=256, n=256, k=256, BLK_M=32, BLK_N=32, BLK_K=16
E assert False
tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO iris::0 Validating...
____________ test_gemm_reduce_scatter[32-32-16-256-256-256-dtype1] _____________
dtype = torch.bfloat16, m = 256, n = 256, k = 256, BLK_M = 32, BLK_N = 32
BLK_K = 16
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"m, n, k",
[
(64, 64, 64), # Very small for quick testing
(128, 128, 128), # Small
(256, 256, 256), # Medium
],
)
@pytest.mark.parametrize(
"BLK_M, BLK_N, BLK_K",
[
(32, 32, 16), # Small blocks
(64, 64, 32), # Medium blocks
],
)
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
"""Worker function for PyTorch distributed execution."""
heap_size = 1 << 30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()
# GEMM
datatype = dtype
assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
A = shmem.randn(m, k, device="cuda", dtype=datatype)
B = shmem.randn(n, k, device="cuda", dtype=datatype).T
C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
M = m
N = n
K = k
# Splitting
rows_per_gpu = k // world_size
k = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]
compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
P = shmem.zeros(
(288, BLK_M * BLK_N),
device="cuda",
dtype=torch.float32,
)
bias = None
gemm_stream = torch.cuda.Stream()
timestamps = Timestamps(num_tiles=total_tiles)
def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()
def run_experiment():
nonlocal local_output
nonlocal compute_buffer
shmem.barrier()
torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
local_output = matmul_module.matmul_reduce_scatter.apply(
local_A,
local_B,
compute_buffer,
local_output,
bias,
P,
locks,
tile_completed,
rank,
world_size,
288,
BLK_M,
BLK_N,
BLK_K,
6,
True,
1,
8,
0,
16,
2,
shmem.get_heap_bases(),
cu_count,
False,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
torch.cuda.nvtx.range_pop()
shmem.barrier()
# Synchronize across all GPUs
shmem.barrier()
run_experiment()
shmem.barrier()
preamble()
shmem.barrier()
shmem.info("Validating...")
matmul_module.matmul_reduce_scatter.set_debug(False)
# Validate global result
success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
> assert success, (
f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
)
E AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=256, n=256, k=256, BLK_M=32, BLK_N=32, BLK_K=16
E assert False
tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO iris::0 Validating...
____________ test_gemm_reduce_scatter[32-32-16-256-256-256-dtype2] _____________
dtype = torch.float32, m = 256, n = 256, k = 256, BLK_M = 32, BLK_N = 32
BLK_K = 16
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"m, n, k",
[
(64, 64, 64), # Very small for quick testing
(128, 128, 128), # Small
(256, 256, 256), # Medium
],
)
@pytest.mark.parametrize(
"BLK_M, BLK_N, BLK_K",
[
(32, 32, 16), # Small blocks
(64, 64, 32), # Medium blocks
],
)
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
"""Worker function for PyTorch distributed execution."""
heap_size = 1 << 30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()
# GEMM
datatype = dtype
assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
A = shmem.randn(m, k, device="cuda", dtype=datatype)
B = shmem.randn(n, k, device="cuda", dtype=datatype).T
C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
M = m
N = n
K = k
# Splitting
rows_per_gpu = k // world_size
k = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]
compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
P = shmem.zeros(
(288, BLK_M * BLK_N),
device="cuda",
dtype=torch.float32,
)
bias = None
gemm_stream = torch.cuda.Stream()
timestamps = Timestamps(num_tiles=total_tiles)
def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()
def run_experiment():
nonlocal local_output
nonlocal compute_buffer
shmem.barrier()
torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
local_output = matmul_module.matmul_reduce_scatter.apply(
local_A,
local_B,
compute_buffer,
local_output,
bias,
P,
locks,
tile_completed,
rank,
world_size,
288,
BLK_M,
BLK_N,
BLK_K,
6,
True,
1,
8,
0,
16,
2,
shmem.get_heap_bases(),
cu_count,
False,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
torch.cuda.nvtx.range_pop()
shmem.barrier()
# Synchronize across all GPUs
shmem.barrier()
> run_experiment()
tests/examples/test_gemm_reduce_scatter_bench.py:154:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/examples/test_gemm_reduce_scatter_bench.py:121: in run_experiment
local_output = matmul_module.matmul_reduce_scatter.apply(
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py:574: in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
examples/13_gemm_reduce_scatter/matmul_wrapper.py:179: in forward
result = matmul_reduce_scatter._call(
examples/13_gemm_reduce_scatter/matmul_wrapper.py:99: in _call
kk = gemm_kernel[(grids,)](
/workspace/triton/python/triton/runtime/jit.py:393: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
/workspace/triton/python/triton/runtime/jit.py:599: in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
/workspace/triton/python/triton/runtime/jit.py:782: in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
/workspace/triton/python/triton/compiler/compiler.py:322: in compile
next_module = compile_ir(module, metadata)
/workspace/triton/python/triton/backends/amd/compiler.py:449: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src = <triton._C.libtriton.ir.module object at 0x7fa4cfae3c40>
metadata = {'allow_flush_denorm': False, 'allowed_dot_input_precisions': ('ieee', 'tf32'), 'arch': 'gfx942', 'backend_name': 'hip', ...}
options = HIPOptions(num_warps=8, waves_per_eu=0, num_stages=1, num_ctas=1, extern_libs=(('ocml', '/workspace/triton/python/trit...flush_denorm=False, max_num_imprecise_acc_default=0, backend_name='hip', instrumentation_mode='', schedule_hint='none')
@staticmethod
def make_llir(src, metadata, options):
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
# custom_lds_size is an experimental parameter that defines amount of LDS available
# for one thread block. Measured in bytes.
#
# If custom_lds_size = 0, pass will consider all LDS is available for one threads block,
# LDS size is determined by provided arch name.
custom_lds_size = 0
amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
passes.convert.add_triton_scf_to_cf(pm)
passes.convert.add_index_to_llvmir(pm)
amd.passes.ttgpuir.add_allocate_shared_memory(pm)
# instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
if HIPBackend.instrumentation:
HIPBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
## of the value of kernel arg `allow_flush_denorm`.
## 2. If __HIP_FTZ = 0, whether exp2 flushes denorms in input and output
## depends on the value of kernel arg `allow_flush_denorm`.
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
## For now it is used as a controller for developers only.
__HIP_FTZ = True
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.convert.add_cf_to_llvmir(pm)
passes.convert.add_arith_to_llvmir(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if options.schedule_hint.lower() != "none":
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
# This can not be moved below the di_scope pass
if HIPBackend.instrumentation:
HIPBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
if not knobs.compilation.disable_line_info:
passes.llvmir.add_di_scope(pm)
amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
> pm.run(mod)
E RuntimeError: PassManager::run failed
/workspace/triton/python/triton/backends/amd/compiler.py:324: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /root/.triton/llvm/llvm-57088512-ubuntu-x64/include/llvm/ADT/SmallVector.h:292: T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type) [with T = mlir::Value; <template-parameter-1-2> = void; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::reference = mlir::Value&; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type = long unsigned int]: Assertion `idx < size()' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 2], instrShape = [16, 16], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @persistent_gemm_reduce_scatter(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg6: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg16: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg17: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<32> : tensor<1x32xi32, #blocked>
%cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked>
%c32_i32 = arith.constant 32 : i32
%c6_i32 = arith.constant 6 : i32
%true = arith.constant true
%c36_i32 = arith.constant 36 : i32
%c8_i32 = arith.constant 8 : i32
%c15_i32 = arith.constant 15 : i32
%c16_i32 = arith.constant 16 : i32
%c31_i32 = arith.constant 31 : i32
%c1_i32 = arith.constant 1 : i32
%c288_i32 = arith.constant 288 : i32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
%0 = tt.get_program_id x : i32
%1 = arith.remsi %0, %c8_i32 : i32
%2 = arith.muli %1, %c36_i32 : i32
%3 = arith.divsi %0, %c8_i32 : i32
%4 = arith.addi %2, %3 : i32
%5 = arith.addi %arg7, %c31_i32 : i32
%6 = arith.divsi %5, %c32_i32 : i32
%7 = arith.addi %arg8, %c31_i32 : i32
%8 = arith.divsi %7, %c32_i32 : i32
%9 = arith.muli %6, %8 : i32
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
%10 = arith.muli %8, %c6_i32 : i32
%11 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%12 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%13 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%14 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%15 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%16 = tt.splat %arg7 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%17 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%18 = tt.splat %arg8 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%19 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%20 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi32, #blocked2>
%22 = tt.broadcast %21 : tensor<16x1xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
%23 = arith.addi %arg9, %c15_i32 : i32
%24 = arith.divsi %23, %c16_i32 : i32
%25 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #mma>
%26 = tt.splat %arg8 : i32 -> tensor<1x32xi32, #mma>
%27 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%28 = tt.expand_dims %27 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
%29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
%31 = arith.cmpi slt, %28, %cst_0 : tensor<32x1xi32, #blocked>
%32 = arith.cmpi slt, %30, %cst : tensor<1x32xi32, #blocked>
%33 = tt.broadcast %31 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
%34 = tt.broadcast %32 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked>
%35 = arith.andi %33, %34 : tensor<32x32xi1, #blocked>
scf.for %arg18 = %4 to %9 step %c288_i32 : i32 {
%36 = arith.divsi %arg18, %10 : i32
%37 = arith.muli %36, %c6_i32 : i32
%38 = arith.subi %6, %37 : i32
%39 = arith.minsi %38, %c6_i32 : i32
%40 = arith.remsi %arg18, %10 : i32
%41 = arith.remsi %40, %39 : i32
%42 = arith.addi %37, %41 : i32
%43 = arith.divsi %40, %39 : i32
%44 = arith.muli %42, %c32_i32 : i32
%45 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%46 = tt.splat %44 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%47 = arith.addi %45, %11 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%48 = arith.addi %46, %12 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%49 = arith.remsi %47, %15 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%50 = arith.remsi %48, %16 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%51 = arith.muli %43, %c32_i32 : i32
%52 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%53 = tt.splat %51 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%54 = arith.addi %52, %13 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%55 = arith.addi %53, %14 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%56 = arith.remsi %54, %17 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%57 = arith.remsi %55, %18 {tt.contiguity = dense<32> : tensor<1xi32>, tt.divisibility = dense<32> : tensor<1xi32>} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%58 = tt.expand_dims %50 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
%59 = tt.expand_dims %49 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
%60 = tt.splat %arg10 : i32 -> tensor<32x1xi32, #blocked1>
%61 = arith.muli %59, %60 : tensor<32x1xi32, #blocked1>
%62 = tt.broadcast %61 : tensor<32x1xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
%63 = tt.expand_dims %19 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1>
%64 = tt.broadcast %63 : tensor<1x16xi32, #blocked1> -> tensor<32x16xi32, #blocked1>
%65 = arith.addi %64, %62 : tensor<32x16xi32, #blocked1>
%66 = tt.expand_dims %57 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
%67 = tt.expand_dims %56 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2>
%68 = tt.splat %arg11 : i32 -> tensor<1x32xi32, #blocked2>
%69 = arith.muli %67, %68 : tensor<1x32xi32, #blocked2>
%70 = tt.broadcast %69 : tensor<1x32xi32, #blocked2> -> tensor<16x32xi32, #blocked2>
%71 = arith.addi %70, %22 : tensor<16x32xi32, #blocked2>
%72:3 = scf.for %arg19 = %c0_i32 to %24 step %c1_i32 iter_args(%arg20 = %arg0, %arg21 = %cst_2, %arg22 = %arg1) -> (!tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>) : i32 {
%91 = tt.splat %arg20 : !tt.ptr<f32> -> tensor<32x16x!tt.ptr<f32>, #blocked1>
%92 = tt.addptr %91, %65 : tensor<32x16x!tt.ptr<f32>, #blocked1>, tensor<32x16xi32, #blocked1>
%93 = tt.load %92 : tensor<32x16x!tt.ptr<f32>, #blocked1>
%94 = tt.splat %arg22 : !tt.ptr<f32> -> tensor<16x32x!tt.ptr<f32>, #blocked2>
%95 = tt.addptr %94, %71 : tensor<16x32x!tt.ptr<f32>, #blocked2>, tensor<16x32xi32, #blocked2>
%96 = tt.load %95 : tensor<16x32x!tt.ptr<f32>, #blocked2>
%97 = ttg.local_alloc %93 : (tensor<32x16xf32, #blocked1>) -> !ttg.memdesc<32x16xf32, #shared, #smem>
%98 = ttg.local_load %97 : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%99 = ttg.local_alloc %96 : (tensor<16x32xf32, #blocked2>) -> !ttg.memdesc<16x32xf32, #shared1, #smem>
%100 = ttg.local_load %99 : !ttg.memdesc<16x32xf32, #shared1, #smem> -> tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%101 = tt.dot %98, %100, %arg21, inputPrecision = tf32 : tensor<32x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
%102 = tt.addptr %arg20, %c16_i32 : !tt.ptr<f32>, i32
%103 = tt.addptr %arg22, %c16_i32 : !tt.ptr<f32>, i32
scf.yield %102, %101, %103 : !tt.ptr<f32>, tensor<32x32xf32, #mma>, !tt.ptr<f32>
} {tt.divisibility_arg1 = dense<[1, 16]> : tensor<2xi32>, tt.divisibility_arg2 = dense<[1, 16]> : tensor<2xi32>}
%73 = arith.cmpi slt, %58, %25 : tensor<32x1xi32, #mma>
%74 = arith.cmpi slt, %66, %26 : tensor<1x32xi32, #mma>
%75 = tt.broadcast %73 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma>
%76 = tt.broadcast %74 : tensor<1x32xi1, #mma> -> tensor<32x32xi1, #mma>
%77 = arith.andi %75, %76 : tensor<32x32xi1, #mma>
%78 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #mma>
%79 = arith.muli %58, %78 : tensor<32x1xi32, #mma>
%80 = tt.broadcast %79 : tensor<32x1xi32, #mma> -> tensor<32x32xi32, #mma>
%81 = tt.broadcast %66 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma>
%82 = arith.addi %81, %80 : tensor<32x32xi32, #mma>
%83 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #mma>
%84 = tt.addptr %83, %82 : tensor<32x32x!tt.ptr<f32>, #mma>, tensor<32x32xi32, #mma>
tt.store %84, %72#1, %77 : tensor<32x32x!tt.ptr<f32>, #mma>
%85 = tt.addptr %arg6, %arg18 : !tt.ptr<i32>, i32
%86 = tt.ptr_to_int %85 : !tt.ptr<i32> -> i64
scf.while (%arg19 = %c0_i32) : (i32) -> () {
%91 = arith.cmpi slt, %arg19, %c0_i32 : i32
scf.condition(%91)
} do {
%91 = tt.load %arg15 : !tt.ptr<i64>
%92 = arith.subi %86, %91 : i64
%93 = tt.int_to_ptr %91 : i64 -> !tt.ptr<i8>
%94 = tt.addptr %93, %92 : !tt.ptr<i8>, i64
%95 = tt.bitcast %94 : !tt.ptr<i8> -> !tt.ptr<i32>
%96 = tt.atomic_cas acquire, sys, %95, %c0_i32, %c0_i32 : (!tt.ptr<i32>, i32, i32) -> i32
scf.yield %96 : i32
}
%87 = arith.cmpi slt, %44, %arg7 : i32
%88 = arith.addi %44, %c32_i32 : i32
%89 = arith.cmpi sgt, %88, %c0_i32 : i32
%90 = arith.andi %87, %89 : i1
scf.if %90 {
%91 = arith.subi %c0_i32, %44 : i32
%92 = arith.maxsi %91, %c0_i32 : i32
%93 = arith.subi %arg7, %44 : i32
%94 = arith.minsi %93, %c32_i32 : i32
%95 = arith.maxsi %44, %c0_i32 : i32
%96 = arith.muli %44, %arg12 : i32
%97 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked>
%98 = arith.muli %28, %97 : tensor<32x1xi32, #blocked>
%99 = tt.broadcast %98 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
%100 = tt.broadcast %30 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked>
%101 = arith.addi %96, %51 : i32
%102 = arith.addi %99, %100 : tensor<32x32xi32, #blocked>
%103 = tt.addptr %arg2, %101 : !tt.ptr<f32>, i32
%104 = tt.load %arg15 : !tt.ptr<i64>
%105 = tt.splat %103 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
%106 = tt.addptr %105, %102 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
%107 = tt.ptr_to_int %106 : tensor<32x32x!tt.ptr<f32>, #blocked> -> tensor<32x32xi64, #blocked>
%108 = tt.splat %104 : i64 -> tensor<32x32xi64, #blocked>
%109 = arith.subi %107, %108 : tensor<32x32xi64, #blocked>
%110 = tt.int_to_ptr %104 : i64 -> !tt.ptr<i8>
%111 = tt.splat %110 : !tt.ptr<i8> -> tensor<32x32x!tt.ptr<i8>, #blocked>
%112 = tt.addptr %111, %109 : tensor<32x32x!tt.ptr<i8>, #blocked>, tensor<32x32xi64, #blocked>
%113 = tt.bitcast %112 : tensor<32x32x!tt.ptr<i8>, #blocked> -> tensor<32x32x!tt.ptr<f32>, #blocked>
%114 = tt.load %113, %35 : tensor<32x32x!tt.ptr<f32>, #blocked>
%115 = arith.addf %114, %cst_1 : tensor<32x32xf32, #blocked>
%116 = arith.muli %95, %arg13 : i32
%117 = tt.splat %arg13 : i32 -> tensor<32x1xi32, #blocked>
%118 = arith.muli %28, %117 : tensor<32x1xi32, #blocked>
%119 = tt.broadcast %118 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
%120 = arith.addi %116, %51 : i32
%121 = arith.addi %119, %100 : tensor<32x32xi32, #blocked>
%122 = tt.addptr %arg3, %120 : !tt.ptr<f32>, i32
%123 = tt.splat %92 : i32 -> tensor<32x1xi32, #blocked>
%124 = arith.cmpi sge, %28, %123 : tensor<32x1xi32, #blocked>
%125 = tt.splat %94 : i32 -> tensor<32x1xi32, #blocked>
%126 = arith.cmpi slt, %28, %125 : tensor<32x1xi32, #blocked>
%127 = arith.andi %124, %126 : tensor<32x1xi1, #blocked>
%128 = tt.broadcast %127 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
%129 = arith.andi %128, %34 : tensor<32x32xi1, #blocked>
%130 = arith.andi %129, %35 : tensor<32x32xi1, #blocked>
%131 = tt.splat %122 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
%132 = tt.addptr %131, %121 : tensor<32x32x!tt.ptr<f32>, #blocked>, tensor<32x32xi32, #blocked>
tt.store %132, %115, %130 cacheModifier = wt : tensor<32x32x!tt.ptr<f32>, #blocked>
}
}
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(optimize-amd-lds-usage{lds-limit=0 target-arch=gfx942}, triton-scf-to-cf, convert-index-to-llvm{index-bitwidth=0}, allocate-amdgpu-shared-memory, convert-triton-amdgpu-to-llvm{arch=gfx942 ftz=true}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-cf-to-llvm{index-bitwidth=0}, convert-arith-to-llvm{index-bitwidth=0}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info, convert-builtin-func-to-llvm{ftz=true})",
disable_threading: false,
verify_each: true
}
}
#-}
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: error: Failures have been detected while processing an MLIR pass pipeline
/work1/amd/muhaawad/git/amd/pdp/iris/tests/examples/../../examples/13_gemm_reduce_scatter/gemm_reduce_scatter.py:98:0: note: Pipeline failed while executing [`ConvertTritonAMDGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
____________ test_gemm_reduce_scatter[64-64-32-128-128-128-dtype1] _____________
dtype = torch.bfloat16, m = 128, n = 128, k = 128, BLK_M = 64, BLK_N = 64
BLK_K = 32
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"m, n, k",
[
(64, 64, 64), # Very small for quick testing
(128, 128, 128), # Small
(256, 256, 256), # Medium
],
)
@pytest.mark.parametrize(
"BLK_M, BLK_N, BLK_K",
[
(32, 32, 16), # Small blocks
(64, 64, 32), # Medium blocks
],
)
def test_gemm_reduce_scatter(dtype, m, n, k, BLK_M, BLK_N, BLK_K):
"""Worker function for PyTorch distributed execution."""
heap_size = 1 << 30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
cu_count = shmem.get_cu_count()
# GEMM
datatype = dtype
assert m % world_size == 0, f"M ({m}) must be divisible by world size ({world_size})."
assert k % world_size == 0, f"K ({k}) must be divisible by world size ({world_size})."
A = shmem.randn(m, k, device="cuda", dtype=datatype)
B = shmem.randn(n, k, device="cuda", dtype=datatype).T
C = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
M = m
N = n
K = k
# Splitting
rows_per_gpu = k // world_size
k = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]
compute_buffer = shmem.zeros((m, n), device="cuda", dtype=A.dtype)
local_output = shmem.zeros((m // world_size, n), device="cuda", dtype=A.dtype)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((288,), device="cuda", dtype=torch.int32)
P = shmem.zeros(
(288, BLK_M * BLK_N),
device="cuda",
dtype=torch.float32,
)
bias = None
gemm_stream = torch.cuda.Stream()
timestamps = Timestamps(num_tiles=total_tiles)
def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()
def run_experiment():
nonlocal local_output
nonlocal compute_buffer
shmem.barrier()
torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
local_output = matmul_module.matmul_reduce_scatter.apply(
local_A,
local_B,
compute_buffer,
local_output,
bias,
P,
locks,
tile_completed,
rank,
world_size,
288,
BLK_M,
BLK_N,
BLK_K,
6,
True,
1,
8,
0,
16,
2,
shmem.get_heap_bases(),
cu_count,
False,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
torch.cuda.nvtx.range_pop()
shmem.barrier()
# Synchronize across all GPUs
shmem.barrier()
run_experiment()
shmem.barrier()
preamble()
shmem.barrier()
shmem.info("Validating...")
matmul_module.matmul_reduce_scatter.set_debug(False)
# Validate global result
success = validation_module.validate_gemm_reduce_scatter(A, B, local_output, rank, world_size, shmem, atol=2)
> assert success, (
f"GEMM reduce-scatter validation failed for dtype={dtype}, m={m}, n={n}, k={k}, BLK_M={BLK_M}, BLK_N={BLK_N}, BLK_K={BLK_K}"
)
E AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=128, n=128, k=128, BLK_M=64, BLK_N=64, BLK_K=32
E assert False
tests/examples/test_gemm_reduce_scatter_bench.py:164: AssertionError
----------------------------- Captured stderr call -----------------------------
[Iris] [0/1] Validating...
------------------------------ Captured log call -------------------------------
INFO iris::0 Validating...
=========================== short test summary info ============================
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-64-64-64-dtype0] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=64, n=64, k=64, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-64-64-64-dtype1] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=64, n=64, k=64, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-64-64-64-dtype2] - RuntimeError: PassManager::run failed
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-128-128-128-dtype0] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=128, n=128, k=128, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-128-128-128-dtype1] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=128, n=128, k=128, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-128-128-128-dtype2] - RuntimeError: PassManager::run failed
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-256-256-256-dtype0] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.float16, m=256, n=256, k=256, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-256-256-256-dtype1] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=256, n=256, k=256, BLK_M=32, BLK_N=32, BLK_K=16
assert False
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[32-32-16-256-256-256-dtype2] - RuntimeError: PassManager::run failed
FAILED tests/examples/test_gemm_reduce_scatter_bench.py::test_gemm_reduce_scatter[64-64-32-128-128-128-dtype1] - AssertionError: GEMM reduce-scatter validation failed for dtype=torch.bfloat16, m=128, n=128, k=128, BLK_M=64, BLK_N=64, BLK_K=32
assert False
========================= 10 failed, 8 passed in 1.66s =========================
Hi @mawad-amd, it looks like all checks passed. We can merge if you have no further comments, and we can optimize performance and support more patterns in follow-up PRs.
@danielhua23, I am sorry for the false positives. I am fixing them now in a different PR, but there are still some tests failing. See here and here, and other ranks too. Could you please check these failing sizes again? The PR looks good to me otherwise.
I fixed the CI issues. Please feel free to disable tile sizes that are not suitable for these problem sizes.
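For example (a hypothetical sketch, not part of this PR), the parametrized test could skip block shapes that cannot tile the problem instead of reporting them as failures; the helper name and condition below are assumptions for illustration:

```python
# Hypothetical helper for the parametrized test; skips configurations where the
# block shape is larger than the problem size. Not the actual fix in this PR.
import pytest


def maybe_skip_config(m, n, k, BLK_M, BLK_N, BLK_K):
    if BLK_M > m or BLK_N > n or BLK_K > k:
        pytest.skip(f"block {BLK_M}x{BLK_N}x{BLK_K} larger than problem {m}x{n}x{k}")
```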
Hi @mawad-amd, it's strange that all world_size=2 cases crash. When I run benchmark.py with world_size=2 it works fine (see below), but running test_gemm_rs_bench.py with world_size=2 crashes. This has me stuck and I can't figure it out. Could you please help take a look? Thanks a ton!
1. world_size=2: my local test_gemm_rs_bench.py reports the crash as well
Sounds good. I will take a look.
Hello, it seems the idea differs from the Triton-Distributed example: this example's GEMM kernel also fuses the final reduce part. Have you considered an implementation that uses two streams to overlap the GEMM and the reduce?
Hi @hebais! Great observation. The implementation in this PR is a fused sequential one. There are different ways to do the reduction (e.g., atomics, one-shot, ring) and different ways to achieve compute/communication overlap; see the chart below for the taxonomy we came up with. The two-stream solution would be the unfused producer-consumer variant. Iris's goal is to facilitate implementing all of these variants. FWIW, we implemented all the variants for GEMM + AllScatter (see Workgroup Specialization, Producer-Consumer, and Bulk-Synchronous, or Fused-Sequential). There are generally different trade-offs between the techniques, and I recommend watching our talk here. If you'd like to implement the producer-consumer pattern with Iris, you can start from this; feel free to open draft PRs and we'd be happy to review and support you along the way! Also see this video where @neoblizz implements all variants for GEMM + AllScatter. (taxonomy chart)
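For reference, here is a minimal, hypothetical sketch of that unfused two-stream producer-consumer pattern using plain PyTorch collectives rather than Iris primitives. It assumes `torch.distributed` is already initialized and simply chunks B along N so the reduce-scatter of finished chunks can overlap with the GEMM of later ones; names and chunking strategy are illustrative, not this PR's implementation.

```python
# Hypothetical sketch of the unfused producer-consumer variant with two CUDA streams.
import torch
import torch.distributed as dist


def gemm_reduce_scatter_two_streams(local_A, local_B, world_size, n_chunks=4):
    gemm_stream = torch.cuda.Stream()
    comm_stream = torch.cuda.Stream()
    M = local_A.shape[0]
    outputs = []
    for B_chunk in local_B.chunk(n_chunks, dim=1):        # split B column-wise
        with torch.cuda.stream(gemm_stream):
            partial = local_A @ B_chunk                    # partial C chunk on this rank
        partial.record_stream(comm_stream)                 # keep memory alive for the comm stream
        comm_stream.wait_stream(gemm_stream)               # comm waits only for this chunk's GEMM
        with torch.cuda.stream(comm_stream):
            out = torch.empty(M // world_size, B_chunk.shape[1],
                              device=partial.device, dtype=partial.dtype)
            dist.reduce_scatter_tensor(out, partial)       # sum over ranks, scatter row blocks
        outputs.append(out)
    torch.cuda.current_stream().wait_stream(comm_stream)
    return torch.cat(outputs, dim=1)                       # (M // world_size, N)
```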
Motivation
This GEMM + ReduceScatter version partitions the K dimension of A and B and then performs the matmul; for C, each rank owns a local_output of shape (M / world_size, N) after the reduce-scatter.
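As a single-process illustration of that partitioning (shapes, the explicit rank loop, and tolerances below are assumptions for clarity, not the kernel itself):

```python
# Illustrative, single-process model of the K-split GEMM followed by a reduce-scatter.
import torch

M, N, K, world_size = 256, 256, 256, 4
A = torch.randn(M, K)
B = torch.randn(K, N)

# Each rank holds a K-slice of A and B and produces a full (M, N) partial result.
k_per_rank = K // world_size
partials = [
    A[:, r * k_per_rank:(r + 1) * k_per_rank] @ B[r * k_per_rank:(r + 1) * k_per_rank, :]
    for r in range(world_size)
]

# Reduce: summing the partials reconstructs C = A @ B.
C = sum(partials)
assert torch.allclose(C, A @ B, atol=1e-3)

# Scatter: each rank keeps only its (M // world_size, N) row block of the reduced C.
rows = M // world_size
local_outputs = [C[r * rows:(r + 1) * rows, :] for r in range(world_size)]
```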

Closes #105
Test Result
This is my log: