107 changes: 94 additions & 13 deletions iris/iris.py
@@ -1521,7 +1521,7 @@ def __translate(ptr, from_rank, to_rank, heap_bases):


@triton.jit
def load(pointer, to_rank, from_rank, heap_bases, mask=None):
def load(pointer, to_rank, from_rank, heap_bases, mask=None, cache_modifier=None, volatile=False):
"""
Loads a value from the specified rank's memory location.

@@ -1530,12 +1530,28 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None):
data from the target memory location. If the from_rank and to_rank are the same,
this function performs a local load operation.

The `cache_modifier` parameter controls instruction-level cache behavior
by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits
in the global load instruction. These affect cache usage across the CU (L1),
L2, and last-level (LLC) caches.

Args:
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the to_rank's address space that will be translated to the from_rank's address space. Must be local to the current rank.
to_rank (int): The rank ID where the pointer is local. Must be the current rank.
from_rank (int): The rank ID from which to read the data.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None.
cache_modifier (str, optional): Controls cache behavior of the load.

Supported values:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.
Ensures global coherence by invalidating stale GPU cache lines.

volatile (bool, optional): If True, disables compiler optimizations that
could reorder or eliminate the load. Defaults to False.

Returns:
Block: The loaded value from the target memory location.
@@ -1550,12 +1566,12 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None):
>>> return data
"""
translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases)
result = tl.load(translated_ptr, mask=mask)
result = tl.load(translated_ptr, mask=mask, cache_modifier=cache_modifier, volatile=volatile)
return result
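
A minimal usage sketch (not part of this diff) of the new load arguments. The kernel name, remote_rank, and BLOCK_SIZE are illustrative; the rank arguments follow the Args description above (to_rank is the current rank where the pointer is local, from_rank is the rank being read).

import triton
import triton.language as tl
import iris

@triton.jit
def read_remote_kernel(src_ptr, out_ptr, cur_rank: tl.constexpr, remote_rank: tl.constexpr,
                       heap_bases, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    # ".cg" streams the remote data through L2 without retaining it in the CU (L1) cache;
    # volatile=True keeps the compiler from reordering or eliminating the load.
    data = iris.load(src_ptr + offsets, cur_rank, remote_rank, heap_bases,
                     cache_modifier=".cg", volatile=True)
    tl.store(out_ptr + offsets, data)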


@triton.jit
def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modifier=None):
"""
Writes data to the specified rank's memory location.

@@ -1564,13 +1580,25 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
the provided data to the target memory location. If the from_rank and to_rank are the same,
this function performs a local store operation.

The `cache_modifier` parameter controls instruction-level cache behavior
by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits
in the global store instruction. These affect cache usage across the CU (L1),
L2, and last-level cache (LLC), following the CDNA ISA.

Args:
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's address space that will be translated to the to_rank's address space. Must be the current rank where the pointer is local.
value (Block): The tensor of elements to be stored.
from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local.
to_rank (int): The rank ID to which the data will be written.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None.
cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:

- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None
@@ -1585,11 +1613,21 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
>>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases)
"""
translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases)
tl.store(translated_ptr, value, mask=mask)
tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier)
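
A similar hedged sketch for the store path, assuming the imports and setup from the load sketch above; ".wt" is just one of the values listed in the docstring.

@triton.jit
def write_remote_kernel(dst_ptr, cur_rank: tl.constexpr, remote_rank: tl.constexpr,
                        heap_bases, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    value = tl.full((BLOCK_SIZE,), 1.0, dtype=tl.float32)
    # ".wt" writes through, bypassing L1 and L2 on the way to the remote rank's memory.
    iris.store(dst_ptr + offsets, value, cur_rank, remote_rank, heap_bases,
               cache_modifier=".wt")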


@triton.jit
def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
def copy(
src_ptr,
dst_ptr,
from_rank,
to_rank,
cur_rank,
heap_bases,
mask=None,
load_cache_modifier=None,
store_cache_modifier=None,
):
"""
Copies data from the specified rank's memory into the destination rank's memory.
This function performs the transfer by translating src_ptr from the from_rank's address
@@ -1607,6 +1645,19 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None

@@ -1635,12 +1686,14 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype)
translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype)

data = tl.load(translated_src, mask=mask)
tl.store(translated_dst, data, mask=mask)
data = tl.load(translated_src, mask=mask, cache_modifier=load_cache_modifier)
tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier)
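
A hedged sketch for copy that mirrors the call pattern used by the new unit test below; the rank names and the particular modifiers are illustrative.

@triton.jit
def copy_remote_kernel(src_ptr, dst_ptr, src_rank: tl.constexpr, dst_rank: tl.constexpr,
                       cur_rank: tl.constexpr, heap_bases, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < BLOCK_SIZE
    # Streaming modifiers on both sides (".cg" load, ".cs" store) keep a one-shot
    # copy from displacing hot lines in the CU (L1) cache.
    iris.copy(src_ptr + offsets, dst_ptr + offsets, src_rank, dst_rank, cur_rank,
              heap_bases, mask=mask, load_cache_modifier=".cg", store_cache_modifier=".cs")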


@triton.jit
def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
def get(
from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, load_cache_modifier=None, store_cache_modifier=None
):
"""
Copies data from the specified rank's memory to the current rank's local memory.

@@ -1657,6 +1710,19 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None

@@ -1669,13 +1735,15 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
"""
translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases)

data = tl.load(translated_from_ptr, mask=mask)
data = tl.load(translated_from_ptr, mask=mask, cache_modifier=load_cache_modifier)

tl.store(to_ptr, data, mask=mask)
tl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier)


@triton.jit
def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
def put(
from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, load_cache_modifier=None, store_cache_modifier=None
):
"""
Copies data from the current rank's local memory to the specified rank's memory.
This function performs a memory write operation by loading data from the current
@@ -1691,6 +1759,19 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None

@@ -1703,9 +1784,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
"""
translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases)

data = tl.load(from_ptr, mask=mask)
data = tl.load(from_ptr, mask=mask, cache_modifier=load_cache_modifier)

tl.store(translated_to_ptr, data, mask=mask)
tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier)
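
And a hedged sketch for put, again assuming the setup from the earlier sketches; the rank order (current rank first, destination rank second) follows the store docstring example above, and get is symmetric with the remote access on the load side instead of the store side.

@triton.jit
def put_remote_kernel(local_src, dst_ptr, cur_rank: tl.constexpr, remote_rank: tl.constexpr,
                      heap_bases, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    # Push a local staging buffer to the remote rank with a streaming store (".cs"),
    # so the one-shot transfer bypasses L1 and is not retained in the LLC.
    iris.put(local_src + offsets, dst_ptr + offsets, cur_rank, remote_rank, heap_bases,
             store_cache_modifier=".cs")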


@triton.jit
104 changes: 104 additions & 0 deletions tests/unittests/test_copy_cache_modifiers.py
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import torch
import triton
import triton.language as tl
import pytest
import iris
from itertools import product


@triton.jit
def copy_kernel(
data,
results,
cur_rank: tl.constexpr,
num_ranks: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
heap_bases: tl.tensor,
load_cache_modifier: tl.constexpr,
store_cache_modifier: tl.constexpr,
):
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < BLOCK_SIZE

# Test copy with cache modifiers - copy from current rank to other ranks
for target_rank in range(num_ranks):
src_data = data + BLOCK_SIZE * cur_rank
dest_data = results + BLOCK_SIZE * target_rank
if load_cache_modifier is None and store_cache_modifier is None:
iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, cur_rank, heap_bases, mask=mask)
elif load_cache_modifier is None:
iris.copy(
src_data + offsets,
dest_data + offsets,
cur_rank,
target_rank,
cur_rank,
heap_bases,
mask=mask,
store_cache_modifier=store_cache_modifier,
)
elif store_cache_modifier is None:
iris.copy(
src_data + offsets,
dest_data + offsets,
cur_rank,
target_rank,
cur_rank,
heap_bases,
mask=mask,
load_cache_modifier=load_cache_modifier,
)
else:
iris.copy(
src_data + offsets,
dest_data + offsets,
cur_rank,
target_rank,
cur_rank,
heap_bases,
mask=mask,
load_cache_modifier=load_cache_modifier,
store_cache_modifier=store_cache_modifier,
)


# Define cache modifiers for load and store operations
LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"]
STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"]


@pytest.mark.parametrize(
"load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS))
)
def test_copy_cache_modifiers(load_cache_modifier, store_cache_modifier):
"""Test copy operation with various cache modifiers"""
shmem = iris.iris(1 << 20)
num_ranks = shmem.get_num_ranks()
heap_bases = shmem.get_heap_bases()
cur_rank = shmem.get_rank()

BLOCK_SIZE = 16
data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32)
base = cur_rank + num_ranks
for i in range(num_ranks):
data[i, :] = base * (i + 1)

results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32)
grid = lambda meta: (1,)
copy_kernel[grid](
data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier
)

shmem.barrier()

# Verify results - each rank should have copied its data to all ranks
for i in range(num_ranks):
expected_value = base * (cur_rank + 1)
assert torch.allclose(results[i], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32)), (
f"Mismatch at rank {cur_rank}, target {i} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}"
)
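
For reference, a hedged way to run just this test locally (the exact launcher, single-rank or multi-rank, follows the repository's usual test setup rather than anything in this diff):

pytest -q tests/unittests/test_copy_cache_modifiers.py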