Merged

23 commits
d312f5e  Initial plan (Copilot, Oct 6, 2025)
c255dee  Add CUDA backend support with runtime detection (Copilot, Oct 6, 2025)
259afe6  Add unit tests and fix linting issues for CUDA backend (Copilot, Oct 6, 2025)
3f3e78e  Update README with CUDA backend documentation (Copilot, Oct 6, 2025)
c6a9aa3  Add CUDA backend example script (Copilot, Oct 6, 2025)
0ba8573  Fix linting issues in example script (Copilot, Oct 6, 2025)
28965e6  Switch to build-time backend configuration with --config-settings (Copilot, Oct 6, 2025)
0a5992d  Remove redundant line from cuda_backend_example.py docstring (Copilot, Oct 6, 2025)
698604c  Add build() function to setup.py for config_settings handling (Copilot, Oct 6, 2025)
fcfd0cc  Remove example and test files, rename cuda.py to _cuda.py (Copilot, Oct 6, 2025)
a162b9e  Replace env var approach with _backend_selected.py file (Copilot, Oct 6, 2025)
eaf980f  Remove get_rocm_version() function from _cuda.py (Copilot, Oct 6, 2025)
ee8a18a  Simplify backend loading and remove unused functions (Copilot, Oct 6, 2025)
cde3ab8  Remove build_backend.py and all environment variable support (Copilot, Oct 6, 2025)
96f462d  Revert README.md to original state (Copilot, Oct 6, 2025)
c64e4ed  Simplify to auto-detection only, remove build() hook (Copilot, Oct 6, 2025)
da69d1b  Consolidate backends into single hip.py with conditional branching (Copilot, Oct 6, 2025)
5f53ebe  Add `setuptools` requirements (mawad-amd, Oct 8, 2025)
f06972a  Name generic functions `gpu*` (mawad-amd, Oct 8, 2025)
caf6ed0  Add necessary conversion for RCCL (mawad-amd, Oct 8, 2025)
73b0835  Apply Ruff auto-fixes (github-actions[bot], Oct 8, 2025)
bd42b7d  Merge branch 'main' into copilot/fix-4d83afe6-045c-4573-a6ec-6f6dd80f… (mawad-amd, Oct 8, 2025)
abaa853  Remove the git ignore comment (mawad-amd, Oct 8, 2025)
2 changes: 1 addition & 1 deletion README.md
@@ -103,7 +103,7 @@ if __name__ == "__main__":
 ### Quick Installation
 
 > [!NOTE]
-> **Requirements**: Python 3.10+, PyTorch 2.0+ (ROCm version), ROCm 6.3.1+ HIP runtime, and Triton
+> **Requirements**: Python 3.10+, PyTorch 2.0+ (ROCm version), ROCm 6.3.1+ HIP runtime, Triton, and setuptools>=61
 
 For a quick installation directly from the repository:
 
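The README change raises the build-requirement floor to setuptools>=61. A quick, illustrative way to check whether an existing environment already satisfies it (this snippet is not part of the PR):

    # Illustrative check for the setuptools>=61 requirement.
    import setuptools
    print(setuptools.__version__)  # expect "61.0" or newer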
190 changes: 141 additions & 49 deletions iris/hip.py
@@ -8,89 +8,152 @@
 import subprocess
 import os
 
-rt_path = "libamdhip64.so"
-hip_runtime = ctypes.cdll.LoadLibrary(rt_path)
+# Auto-detect backend
+_is_amd_backend = True
+try:
+    rt_path = "libamdhip64.so"
+    gpu_runtime = ctypes.cdll.LoadLibrary(rt_path)
+except OSError:
+    try:
+        rt_path = "libcudart.so"
+        gpu_runtime = ctypes.cdll.LoadLibrary(rt_path)
+        _is_amd_backend = False
+    except OSError:
+        rt_path = "libamdhip64.so"
+        gpu_runtime = ctypes.cdll.LoadLibrary(rt_path)
 
 
-def hip_try(err):
+def gpu_try(err):
     if err != 0:
-        hip_runtime.hipGetErrorString.restype = ctypes.c_char_p
-        error_string = hip_runtime.hipGetErrorString(ctypes.c_int(err)).decode("utf-8")
-        raise RuntimeError(f"HIP error code {err}: {error_string}")
+        if _is_amd_backend:
+            gpu_runtime.hipGetErrorString.restype = ctypes.c_char_p
+            error_string = gpu_runtime.hipGetErrorString(ctypes.c_int(err)).decode("utf-8")
+            raise RuntimeError(f"HIP error code {err}: {error_string}")
+        else:
+            gpu_runtime.cudaGetErrorString.restype = ctypes.c_char_p
+            error_string = gpu_runtime.cudaGetErrorString(ctypes.c_int(err)).decode("utf-8")
+            raise RuntimeError(f"CUDA error code {err}: {error_string}")
 
 
-class hipIpcMemHandle_t(ctypes.Structure):
-    _fields_ = [("reserved", ctypes.c_char * 64)]
+def get_ipc_handle_size():
+    """Return the IPC handle size for the current backend."""
+    return 64 if _is_amd_backend else 128
+
+
+class gpuIpcMemHandle_t(ctypes.Structure):
+    _fields_ = [("reserved", ctypes.c_char * get_ipc_handle_size())]
 
 
 def open_ipc_handle(ipc_handle_data, rank):
     ptr = ctypes.c_void_p()
-    hipIpcMemLazyEnablePeerAccess = ctypes.c_uint(1)
-    hip_runtime.hipIpcOpenMemHandle.argtypes = [
-        ctypes.POINTER(ctypes.c_void_p),
-        hipIpcMemHandle_t,
-        ctypes.c_uint,
-    ]
+    handle_size = get_ipc_handle_size()
+
+    if _is_amd_backend:
+        hipIpcMemLazyEnablePeerAccess = ctypes.c_uint(1)
+        gpu_runtime.hipIpcOpenMemHandle.argtypes = [
+            ctypes.POINTER(ctypes.c_void_p),
+            gpuIpcMemHandle_t,
+            ctypes.c_uint,
+        ]
+    else:
+        gpu_runtime.cudaIpcOpenMemHandle.argtypes = [
+            ctypes.POINTER(ctypes.c_void_p),
+            gpuIpcMemHandle_t,
+            ctypes.c_uint,
+        ]
+        cudaIpcMemLazyEnablePeerAccess = ctypes.c_uint(1)
 
     if isinstance(ipc_handle_data, np.ndarray):
-        if ipc_handle_data.dtype != np.uint8 or ipc_handle_data.size != 64:
-            raise ValueError("ipc_handle_data must be a 64-element uint8 numpy array")
+        if ipc_handle_data.dtype != np.uint8 or ipc_handle_data.size != handle_size:
+            raise ValueError(f"ipc_handle_data must be a {handle_size}-element uint8 numpy array")
         ipc_handle_bytes = ipc_handle_data.tobytes()
-        ipc_handle_data = (ctypes.c_char * 64).from_buffer_copy(ipc_handle_bytes)
+        ipc_handle_data = (ctypes.c_char * handle_size).from_buffer_copy(ipc_handle_bytes)
     else:
-        raise TypeError("ipc_handle_data must be a numpy.ndarray of dtype uint8 with 64 elements")
+        raise TypeError(f"ipc_handle_data must be a numpy.ndarray of dtype uint8 with {handle_size} elements")
 
-    raw_memory = ctypes.create_string_buffer(64)
-    ctypes.memset(raw_memory, 0x00, 64)
-    ipc_handle_struct = hipIpcMemHandle_t.from_buffer(raw_memory)
+    raw_memory = ctypes.create_string_buffer(handle_size)
+    ctypes.memset(raw_memory, 0x00, handle_size)
+    ipc_handle_struct = gpuIpcMemHandle_t.from_buffer(raw_memory)
     ipc_handle_data_bytes = bytes(ipc_handle_data)
-    ctypes.memmove(raw_memory, ipc_handle_data_bytes, 64)
-
-    hip_try(
-        hip_runtime.hipIpcOpenMemHandle(
-            ctypes.byref(ptr),
-            ipc_handle_struct,
-            hipIpcMemLazyEnablePeerAccess,
-        )
-    )
+    ctypes.memmove(raw_memory, ipc_handle_data_bytes, handle_size)
+
+    if _is_amd_backend:
+        gpu_try(
+            gpu_runtime.hipIpcOpenMemHandle(
+                ctypes.byref(ptr),
+                ipc_handle_struct,
+                hipIpcMemLazyEnablePeerAccess,
+            )
+        )
+    else:
+        gpu_try(
+            gpu_runtime.cudaIpcOpenMemHandle(
+                ctypes.byref(ptr),
+                ipc_handle_struct,
+                cudaIpcMemLazyEnablePeerAccess,
+            )
+        )
 
     return ptr.value
 
 
 def get_ipc_handle(ptr, rank):
-    ipc_handle = hipIpcMemHandle_t()
-    hip_try(hip_runtime.hipIpcGetMemHandle(ctypes.byref(ipc_handle), ptr))
+    ipc_handle = gpuIpcMemHandle_t()
+    if _is_amd_backend:
+        gpu_try(gpu_runtime.hipIpcGetMemHandle(ctypes.byref(ipc_handle), ptr))
+    else:
+        gpu_try(gpu_runtime.cudaIpcGetMemHandle(ctypes.byref(ipc_handle), ptr))
     return ipc_handle
 
 
 def count_devices():
     device_count = ctypes.c_int()
-    hip_try(hip_runtime.hipGetDeviceCount(ctypes.byref(device_count)))
+    if _is_amd_backend:
+        gpu_try(gpu_runtime.hipGetDeviceCount(ctypes.byref(device_count)))
+    else:
+        gpu_try(gpu_runtime.cudaGetDeviceCount(ctypes.byref(device_count)))
     return device_count.value
 
 
 def set_device(gpu_id):
-    hip_try(hip_runtime.hipSetDevice(gpu_id))
+    if _is_amd_backend:
+        gpu_try(gpu_runtime.hipSetDevice(gpu_id))
+    else:
+        gpu_try(gpu_runtime.cudaSetDevice(gpu_id))
 
 
 def get_device_id():
     device_id = ctypes.c_int()
-    hip_try(hip_runtime.hipGetDevice(ctypes.byref(device_id)))
+    if _is_amd_backend:
+        gpu_try(gpu_runtime.hipGetDevice(ctypes.byref(device_id)))
+    else:
+        gpu_try(gpu_runtime.cudaGetDevice(ctypes.byref(device_id)))
     return device_id.value
 
 
 def get_cu_count(device_id=None):
     if device_id is None:
         device_id = get_device_id()
 
-    hipDeviceAttributeMultiprocessorCount = 63
     cu_count = ctypes.c_int()
 
-    hip_try(hip_runtime.hipDeviceGetAttribute(ctypes.byref(cu_count), hipDeviceAttributeMultiprocessorCount, device_id))
+    if _is_amd_backend:
+        hipDeviceAttributeMultiprocessorCount = 63
+        gpu_try(
+            gpu_runtime.hipDeviceGetAttribute(ctypes.byref(cu_count), hipDeviceAttributeMultiprocessorCount, device_id)
+        )
+    else:
+        cudaDevAttrMultiProcessorCount = 16
+        gpu_try(gpu_runtime.cudaDeviceGetAttribute(ctypes.byref(cu_count), cudaDevAttrMultiProcessorCount, device_id))
 
     return cu_count.value
 
 
 def get_rocm_version():
+    if not _is_amd_backend:
+        # Not applicable for CUDA
+        return (-1, -1)
+
     major, minor = -1, -1
 
     # Try hipconfig --path first
@@ -119,47 +182,76 @@ def get_rocm_version():


 def get_wall_clock_rate(device_id):
-    hipDeviceAttributeWallClockRate = 10017
     wall_clock_rate = ctypes.c_int()
-    status = hip_runtime.hipDeviceGetAttribute(
-        ctypes.byref(wall_clock_rate), hipDeviceAttributeWallClockRate, device_id
-    )
-    hip_try(status)
-
+
+    if _is_amd_backend:
+        hipDeviceAttributeWallClockRate = 10017
+        status = gpu_runtime.hipDeviceGetAttribute(
+            ctypes.byref(wall_clock_rate), hipDeviceAttributeWallClockRate, device_id
+        )
+    else:
+        cudaDevAttrClockRate = 13
+        status = gpu_runtime.cudaDeviceGetAttribute(ctypes.byref(wall_clock_rate), cudaDevAttrClockRate, device_id)
+
+    gpu_try(status)
     return wall_clock_rate.value
 
 
 def get_arch_string(device_id=None):
     if device_id is None:
         device_id = get_device_id()
-    arch_full = torch.cuda.get_device_properties(device_id).gcnArchName
-    arch_name = arch_full.split(":")[0]
-    return arch_name
+
+    if _is_amd_backend:
+        arch_full = torch.cuda.get_device_properties(device_id).gcnArchName
+        arch_name = arch_full.split(":")[0]
+        return arch_name
+    else:
+        # For CUDA, return compute capability
+        props = torch.cuda.get_device_properties(device_id)
+        return f"sm_{props.major}{props.minor}"
 
 
 def get_num_xcc(device_id=None):
     if device_id is None:
         device_id = get_device_id()
 
+    if not _is_amd_backend:
+        # XCC is AMD-specific, return 1 for CUDA
+        return 1
+
     rocm_major, _ = get_rocm_version()
     if rocm_major < 7:
         return 8
     hipDeviceAttributeNumberOfXccs = 10018
     xcc_count = ctypes.c_int()
-    hip_try(hip_runtime.hipDeviceGetAttribute(ctypes.byref(xcc_count), hipDeviceAttributeNumberOfXccs, device_id))
+    gpu_try(gpu_runtime.hipDeviceGetAttribute(ctypes.byref(xcc_count), hipDeviceAttributeNumberOfXccs, device_id))
     return xcc_count.value
 
 
 def malloc_fine_grained(size):
-    hipDeviceMallocFinegrained = 0x1
     ptr = ctypes.c_void_p()
-    hip_try(hip_runtime.hipExtMallocWithFlags(ctypes.byref(ptr), size, hipDeviceMallocFinegrained))
+
+    if _is_amd_backend:
+        hipDeviceMallocFinegrained = 0x1
+        gpu_try(gpu_runtime.hipExtMallocWithFlags(ctypes.byref(ptr), size, hipDeviceMallocFinegrained))
+    else:
+        # CUDA doesn't have direct equivalent, use regular malloc
+        gpu_try(gpu_runtime.cudaMalloc(ctypes.byref(ptr), size))
+
     return ptr
 
 
 def hip_malloc(size):
     ptr = ctypes.c_void_p()
-    hip_try(hip_runtime.hipMalloc(ctypes.byref(ptr), size))
+    if _is_amd_backend:
+        gpu_try(gpu_runtime.hipMalloc(ctypes.byref(ptr), size))
+    else:
+        gpu_try(gpu_runtime.cudaMalloc(ctypes.byref(ptr), size))
     return ptr
 
 
 def hip_free(ptr):
-    hip_try(hip_runtime.hipFree(ptr))
+    if _is_amd_backend:
+        gpu_try(gpu_runtime.hipFree(ptr))
+    else:
+        gpu_try(gpu_runtime.cudaFree(ptr))
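Taken together, the hip.py changes make backend selection an import-time side effect: the module first tries the HIP runtime and only then falls back to the CUDA runtime. A minimal, self-contained sketch of that detection pattern (load_gpu_runtime is a hypothetical helper used only for illustration, not part of the module):

    # Hypothetical sketch of the auto-detection pattern in iris/hip.py:
    # prefer the HIP runtime, fall back to CUDA, record which one loaded.
    import ctypes

    def load_gpu_runtime():
        """Return (runtime, is_amd); raises OSError if neither runtime is present."""
        try:
            return ctypes.cdll.LoadLibrary("libamdhip64.so"), True  # AMD / HIP
        except OSError:
            return ctypes.cdll.LoadLibrary("libcudart.so"), False  # NVIDIA / CUDA

    runtime, is_amd = load_gpu_runtime()
    print("HIP backend" if is_amd else "CUDA backend")

One subtlety the sketch omits: in the diff, when neither library loads, the module retries libamdhip64.so, so the failure surfaces as the original HIP OSError rather than a CUDA one.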
10 changes: 7 additions & 3 deletions iris/iris.py
@@ -39,6 +39,7 @@
     get_ipc_handle,
     open_ipc_handle,
     get_wall_clock_rate,
+    get_ipc_handle_size,
 )
 import numpy as np
 import math
@@ -89,13 +90,16 @@ def __init__(self, heap_size=1 << 30):

         heap_bases = np.zeros(num_ranks, dtype=np.uint64)
         heap_bases[cur_rank] = heap_base
-        ipc_handles = np.zeros((num_ranks, 64), dtype=np.uint8)
+        ipc_handle_size = get_ipc_handle_size()
+        ipc_handles = np.zeros((num_ranks, ipc_handle_size), dtype=np.uint8)
         ipc_handle = get_ipc_handle(heap_base_ptr, cur_rank)
 
         distributed_barrier()
 
-        all_ipc_handles = distributed_allgather(np.frombuffer(ipc_handle, dtype=np.uint8))
-        all_heap_bases = distributed_allgather(np.array([heap_bases[cur_rank]], dtype=np.uint64))
+        all_ipc_handles = distributed_allgather(np.frombuffer(ipc_handle, dtype=np.uint8).copy())
+        heap_base_bytes = np.array([heap_bases[cur_rank]], dtype=np.uint64).tobytes()
+        all_heap_bases_bytes = distributed_allgather(np.frombuffer(heap_base_bytes, dtype=np.uint8).copy())
+        all_heap_bases = np.frombuffer(all_heap_bases_bytes.tobytes(), dtype=np.uint64).reshape(num_ranks, -1)
 
         distributed_barrier()
 
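The second iris.py hunk adds .copy() after each np.frombuffer and round-trips the heap base through raw uint8 bytes. A standalone sketch of the likely rationale (our inference; the commit log only says "Add necessary conversion for RCCL"): np.frombuffer over a bytes object yields a read-only view, and a writable uint8 buffer is the safer input to an allgather collective:

    # np.frombuffer over bytes yields a read-only view; .copy() makes it writable.
    import numpy as np

    handle_bytes = bytes(64)  # stand-in for an IPC handle's raw bytes
    view = np.frombuffer(handle_bytes, dtype=np.uint8)
    assert not view.flags.writeable  # views over immutable bytes are read-only
    buf = view.copy()  # independent, writable array, safe to hand to a collective
    assert buf.flags.writeable

    # The heap-base round trip from the diff: uint64 -> uint8 bytes -> (gather) -> uint64.
    heap_base = np.array([0x7F0000000000], dtype=np.uint64)
    as_bytes = np.frombuffer(heap_base.tobytes(), dtype=np.uint8).copy()
    restored = np.frombuffer(as_bytes.tobytes(), dtype=np.uint64)
    assert restored[0] == heap_base[0]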