Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions examples/04_atomic_add/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ SPDX-License-Identifier: MIT
Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
-->

# Load benchmark
# Atomic Add benchmark

Load benchmark using Iris.
Atomic Add benchmark using Iris.

## Usage

Expand All @@ -24,4 +24,14 @@ GPU 04 -> 15.58 15.68 15.21 15.32 769.53 15.67 15.58
GPU 05 -> 15.59 15.49 15.24 15.50 15.57 773.01 15.67 15.59
GPU 06 -> 15.41 15.41 15.15 15.06 15.50 15.67 778.30 15.58
GPU 07 -> 15.22 15.33 15.07 15.06 15.66 15.54 15.56 765.45
```
```

## Testing

The benchmark can be tested using pytest:

```terminal
pytest tests/examples/test_atomic_add_bench.py
```

This will run parametrized tests across different data types (int32, int64, fp16, bf16, fp32), buffer sizes, and block sizes to ensure the atomic add functionality works correctly.
44 changes: 42 additions & 2 deletions examples/04_atomic_add/atomic_add_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def torch_dtype_from_str(datatype: str) -> torch.dtype:
dtype_map = {
"fp16": torch.float16,
"fp32": torch.float32,
"int8": torch.int8,
"int32": torch.int32,
"int64": torch.int64,
"bf16": torch.bfloat16,
}
try:
Expand All @@ -65,7 +66,7 @@ def parse_args():
"--datatype",
type=str,
default="fp16",
choices=["fp16", "fp32", "int8", "bf16"],
choices=["fp16", "fp32", "int32", "int64", "bf16"],
help="Datatype of computation",
)
parser.add_argument("-z", "--buffer_size", type=int, default=1 << 32, help="Buffer Size")
Expand All @@ -82,6 +83,45 @@ def parse_args():
return vars(parser.parse_args())


def bench_atomic_add(
    shmem,
    source_rank,
    destination_rank,
    source_buffer,
    result_buffer,
    BLOCK_SIZE,
    dtype,
    verbose=False,
    validate=False,
    num_experiments=1,
    num_warmup=0,
):
    """Run a single atomic-add experiment from explicit arguments.

    Thin wrapper around ``run_experiment`` intended for tests: it packs the
    keyword-style options into the ``args`` dict that ``run_experiment``
    expects. Per the original author, the signature mirrors ``bench_load``.

    Args:
        shmem: Iris context, forwarded unchanged to ``run_experiment``.
        source_rank: Rank forwarded as the experiment's source.
        destination_rank: Rank forwarded as the experiment's destination.
        source_buffer: Source buffer, forwarded unchanged.
        result_buffer: Result buffer, forwarded unchanged.
        BLOCK_SIZE: Stored under ``args["block_size"]``.
        dtype: A ``torch.dtype``; one of int32, int64, float16, bfloat16 or
            float32. It is translated to the string name used by the
            ``--datatype`` command-line option.
        verbose: Stored under ``args["verbose"]``.
        validate: Stored under ``args["validate"]``.
        num_experiments: Stored under ``args["num_experiments"]``.
        num_warmup: Stored under ``args["num_warmup"]``.

    Returns:
        Whatever ``run_experiment`` returns (callers in the test suite treat
        it as a bandwidth figure -- confirm against ``run_experiment``).

    Raises:
        ValueError: If ``dtype`` is not one of the supported dtypes.
    """
    # Map torch dtypes back to the string names accepted by --datatype.
    dtype_str_map = {
        torch.int32: "int32",
        torch.int64: "int64",
        torch.float16: "fp16",
        torch.bfloat16: "bf16",
        torch.float32: "fp32",
    }
    try:
        datatype_str = dtype_str_map[dtype]
    except (KeyError, TypeError):
        # Do not silently fall back to fp16: the experiment would then be
        # configured with a datatype that does not match the caller's buffers.
        supported = ", ".join(str(key) for key in dtype_str_map)
        raise ValueError(
            f"Unsupported dtype: {dtype}. Supported dtypes: {supported}"
        ) from None

    # Create args dict as expected by run_experiment
    args = {
        "datatype": datatype_str,
        "block_size": BLOCK_SIZE,
        "verbose": verbose,
        "validate": validate,
        "num_experiments": num_experiments,
        "num_warmup": num_warmup,
    }

    return run_experiment(shmem, args, source_rank, destination_rank, source_buffer, result_buffer)


def run_experiment(shmem, args, source_rank, destination_rank, source_buffer, result_buffer):
dtype = torch_dtype_from_str(args["datatype"])
cur_rank = shmem.get_rank()
Expand Down
63 changes: 63 additions & 0 deletions tests/examples/test_atomic_add_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import pytest
import torch
import triton
import triton.language as tl
import numpy as np
import iris

import importlib.util
from pathlib import Path

# Load the benchmark script directly from its file path. Its directory name
# ("04_atomic_add") starts with a digit, so it is presumably not importable
# through a regular `import` statement.
current_dir = Path(__file__).parent
file_path = (current_dir / "../../examples/04_atomic_add/atomic_add_bench.py").resolve()
module_name = "atomic_add_bench"
# Build an import spec for the file, create an empty module object from it,
# then execute the file's top-level code to populate `module`.
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)


@pytest.mark.parametrize(
    "dtype",
    [torch.int32, torch.int64, torch.float16, torch.bfloat16, torch.float32],
)
@pytest.mark.parametrize("buffer_size, heap_size", [((1 << 32), (1 << 33))])
@pytest.mark.parametrize("block_size", [512, 1024])
def test_atomic_add_bench(dtype, buffer_size, heap_size, block_size):
    """Smoke-test ``bench_atomic_add`` for every (source, destination) rank pair.

    The test allocates a source buffer of ones and a zeroed result buffer on
    the Iris heap, then invokes the benchmark once per rank pair and records
    each returned value. It makes no assertions; it passes as long as no call
    raises.
    """
    shmem = iris.iris(heap_size)
    world_size = shmem.get_num_ranks()

    results = np.zeros((world_size, world_size), dtype=np.float32)

    # buffer_size is in bytes; convert it to an element count for this dtype.
    bytes_per_element = torch.tensor([], dtype=dtype).element_size()
    num_elements = buffer_size // bytes_per_element
    src = shmem.ones(num_elements, dtype=dtype)
    dst = shmem.zeros_like(src)

    shmem.barrier()

    for src_rank in range(world_size):
        for dst_rank in range(world_size):
            results[src_rank, dst_rank] = module.bench_atomic_add(
                shmem, src_rank, dst_rank, src, dst, block_size, dtype
            )
            shmem.barrier()