Skip to content

Commit bd48cf9

Browse files
committed
Implement DeviceMemoryResource.peer_accessible_by
1 parent eb774e7 commit bd48cf9

File tree

4 files changed

+295
-0
lines changed

4 files changed

+295
-0
lines changed

cuda_core/cuda/core/experimental/_device.pyx

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,27 @@ class Device:
10351035
bus_id = handle_return(runtime.cudaDeviceGetPCIBusId(13, self._id))
10361036
return bus_id[:12].decode()
10371037

1038+
def can_access_peer(self, peer: Device | int) -> bool:
    """Check if this device can access memory from the specified peer device.

    Queries whether peer-to-peer memory access is supported between this
    device and the specified peer device.

    Parameters
    ----------
    peer : Device | int
        The peer device to check accessibility to. Can be a Device object or device ID.

    Returns
    -------
    bool
        True if peer access is supported (trivially True when ``peer`` is
        this device itself), False otherwise.
    """
    # Normalize the argument to a Device object so device_id is available.
    peer = Device(peer)
    cdef int d1 = <int> self.device_id
    cdef int d2 = <int> peer.device_id
    # Same device: its own memory is always accessible, so skip the driver
    # query and answer directly.
    if d1 == d2:
        return True
    cdef int value = 0
    # Driver call releases the GIL; HANDLE_RETURN raises on a CUDA error.
    with nogil:
        HANDLE_RETURN(cydriver.cuDeviceCanAccessPeer(&value, d1, d2))
    return bool(value)
1058+
10381059
@property
10391060
def uuid(self) -> str:
10401061
"""Return a UUID for the device.

cuda_core/cuda/core/experimental/_memory/_device_memory_resource.pxd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,8 @@ cdef class DeviceMemoryResource(MemoryResource):
1414
bint _mempool_owned
1515
IPCData _ipc_data
1616
object _attributes
17+
object _peer_accessible_by
1718
object __weakref__
19+
20+
21+
cpdef DMR_mempool_get_access(DeviceMemoryResource, int)

cuda_core/cuda/core/experimental/_memory/_device_memory_resource.pyx

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ from __future__ import annotations
66

77
from libc.limits cimport ULLONG_MAX
88
from libc.stdint cimport uintptr_t
9+
from libc.stdlib cimport malloc, free
910
from libc.string cimport memset
1011

1112
from cuda.bindings cimport cydriver
@@ -222,6 +223,7 @@ cdef class DeviceMemoryResource(MemoryResource):
222223
self._mempool_owned = False
223224
self._ipc_data = None
224225
self._attributes = None
226+
self._peer_accessible_by = ()
225227

226228
def __init__(self, device_id: Device | int, options=None):
227229
from .._device import Device
@@ -408,6 +410,69 @@ cdef class DeviceMemoryResource(MemoryResource):
408410
"""
409411
return getattr(self._ipc_data, 'uuid', None)
410412

413+
@property
def peer_accessible_by(self):
    """
    Get or set which devices may access allocations from this memory pool.

    Reading returns a tuple of sorted device IDs that currently hold peer
    access to allocations from this pool. Assigning accepts any sequence
    of Device objects or integer device IDs; the change applies
    immediately to every allocation from this pool, and assigning an
    empty sequence revokes all peer access.

    Examples
    --------
    >>> dmr = DeviceMemoryResource(0)
    >>> dmr.peer_accessible_by = [1]  # Grant access to device 1
    >>> assert dmr.peer_accessible_by == (1,)
    >>> dmr.peer_accessible_by = []  # Revoke access
    """
    peers = self._peer_accessible_by
    return peers
434+
435+
@peer_accessible_by.setter
def peer_accessible_by(self, devices):
    """Set which devices can access this memory pool.

    Parameters
    ----------
    devices : sequence of Device | int
        Devices to grant peer access to. This device itself is silently
        excluded if present; an empty sequence revokes all peer access.

    Raises
    ------
    MemoryError
        If the temporary access-descriptor array cannot be allocated.
    """
    from .._device import Device

    # Convert all devices to device IDs.
    cdef set target_ids = set([Device(dev).device_id for dev in devices])
    target_ids.discard(self._dev_id)  # exclude this device from peer access list
    cdef set cur_ids = set(self._peer_accessible_by)
    cdef set to_add = target_ids - cur_ids
    cdef set to_rm = cur_ids - target_ids
    cdef size_t count = len(to_add) + len(to_rm)  # transaction size
    cdef cydriver.CUmemAccessDesc* access_desc = NULL
    cdef size_t i = 0

    if count == 0:
        # Nothing to grant or revoke: avoid a pointless driver call with a
        # NULL descriptor array; just normalize the cached value.
        self._peer_accessible_by = tuple(sorted(target_ids))
        return

    access_desc = <cydriver.CUmemAccessDesc*>malloc(count * sizeof(cydriver.CUmemAccessDesc))
    if access_desc == NULL:
        raise MemoryError("Failed to allocate memory for access descriptors")

    try:
        # Grant read/write access to newly added peers...
        for dev_id in to_add:
            access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
            access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
            access_desc[i].location.id = dev_id
            i += 1

        # ...and revoke access from peers no longer in the target set.
        for dev_id in to_rm:
            access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_NONE
            access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
            access_desc[i].location.id = dev_id
            i += 1

        # Apply grants and revocations as one driver transaction.
        with nogil:
            HANDLE_RETURN(cydriver.cuMemPoolSetAccess(self._handle, access_desc, count))
    finally:
        free(access_desc)

    # Fix: store a *sorted* tuple. The getter documents (and the tests
    # assert) a sorted tuple, but tuple(set) iteration order is not
    # guaranteed to be ascending.
    self._peer_accessible_by = tuple(sorted(target_ids))
475+
411476

412477
# DeviceMemoryResource Implementation
413478
# -----------------------------------
@@ -515,6 +580,11 @@ cdef inline DMR_close(DeviceMemoryResource self):
515580
if self._handle == NULL:
516581
return
517582

583+
# This works around nvbug 5698116. When a memory pool handle is recycled
584+
# the new handle inherits the peer access state of the previous handle.
585+
if self._peer_accessible_by:
586+
self.peer_accessible_by = []
587+
518588
try:
519589
if self._mempool_owned:
520590
with nogil:
@@ -525,3 +595,40 @@ cdef inline DMR_close(DeviceMemoryResource self):
525595
self._attributes = None
526596
self._mempool_owned = False
527597
self._ipc_data = None
598+
self._peer_accessible_by = ()
599+
600+
601+
# Note: this is referenced in instructions to debug nvbug 5698116.
cpdef DMR_mempool_get_access(DeviceMemoryResource dmr, int device_id):
    """
    Probes peer access from the given device using cuMemPoolGetAccess.

    Parameters
    ----------
    dmr : DeviceMemoryResource
        The memory resource whose underlying pool is queried.
    device_id : int or Device
        The device to query access for.

    Returns
    -------
    str
        Access permissions: "rw" for read-write, "r" for read-only, "" for no access.
    """
    from .._device import Device

    # Normalize to a plain device ID.
    cdef int dev_id = Device(device_id).device_id
    cdef cydriver.CUmemAccess_flags flags
    cdef cydriver.CUmemLocation location

    location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    location.id = dev_id

    # `flags` is an out-parameter filled by the driver call below.
    with nogil:
        HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&flags, dmr._handle, &location))

    if flags == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE:
        return "rw"
    elif flags == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ:
        return "r"
    else:
        return ""
634+
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from cuda.core.experimental import Device, DeviceMemoryResource
2+
from cuda.core.experimental._utils.cuda_utils import CUDAError
3+
from helpers.buffers import compare_buffer_to_constant, make_scratch_buffer, PatternGen
4+
import cuda.core.experimental
5+
import pytest
6+
import itertools
7+
8+
NBYTES = 1024
9+
10+
def _mempool_device_impl(num):
    """Return *num* devices with mutual peer access and mempool support.

    Skips the calling test when fewer than *num* GPUs are present, when the
    GPUs lack mutual peer access, or when memory pools are unsupported.

    Parameters
    ----------
    num : int
        Number of devices required.
    """
    num_devices = len(cuda.core.experimental.system.devices)
    if num_devices < num:
        # Fix: the original lacked the f-prefix, printing a literal "{num}".
        pytest.skip(f"Test requires at least {num} GPUs")

    devs = [Device(i) for i in range(num)]
    # Initialize each device's context; reversed so device 0 ends up current.
    for i in reversed(range(num)):
        devs[i].set_current()

    if not all(devs[i].can_access_peer(j) for i in range(num) for j in range(num)):
        pytest.skip("Test requires GPUs with peer access")

    if not all(devs[i].properties.memory_pools_supported for i in range(num)):
        pytest.skip("Device does not support mempool operations")

    return devs
26+
27+
@pytest.fixture
def mempool_device_x2():
    """Provide two peer-capable devices, skipping the test if unavailable."""
    devices = _mempool_device_impl(2)
    return devices
31+
32+
@pytest.fixture
def mempool_device_x3():
    """Provide three peer-capable devices, skipping the test if unavailable."""
    devices = _mempool_device_impl(3)
    return devices
36+
37+
38+
def test_peer_access_basic(mempool_device_x2):
    """Basic tests for dmr.peer_accessible_by.

    Grants and revokes device 0's access to a pool on device 1, checking
    that copies succeed only while access is granted.
    """
    dev0, dev1 = mempool_device_x2
    # Scratch buffers on device 0, filled with constants 0 and 1.
    zero_on_dev0 = make_scratch_buffer(dev0, 0, NBYTES)
    one_on_dev0 = make_scratch_buffer(dev0, 1, NBYTES)
    stream_on_dev0 = dev0.create_stream()
    # Pool on device 1; its allocations are the peer-access targets.
    dmr_on_dev1 = DeviceMemoryResource(dev1)
    buf_on_dev1 = dmr_on_dev1.allocate(NBYTES)

    # No access at first.
    assert 0 not in dmr_on_dev1.peer_accessible_by
    with pytest.raises(CUDAError, match="CUDA_ERROR_INVALID_VALUE"):
        one_on_dev0.copy_to(buf_on_dev1, stream=stream_on_dev0)

    with pytest.raises(CUDAError, match="CUDA_ERROR_INVALID_VALUE"):
        zero_on_dev0.copy_from(buf_on_dev1, stream=stream_on_dev0)

    # Allow access to device 1's allocations from device 0.
    dmr_on_dev1.peer_accessible_by = [dev0]
    assert 0 in dmr_on_dev1.peer_accessible_by
    compare_buffer_to_constant(zero_on_dev0, 0)
    # Round-trip 1 through device 1's buffer back into the zero buffer.
    one_on_dev0.copy_to(buf_on_dev1, stream=stream_on_dev0)
    zero_on_dev0.copy_from(buf_on_dev1, stream=stream_on_dev0)
    stream_on_dev0.sync()
    compare_buffer_to_constant(zero_on_dev0, 1)

    # Revoke access
    dmr_on_dev1.peer_accessible_by = []
    assert 0 not in dmr_on_dev1.peer_accessible_by
    with pytest.raises(CUDAError, match="CUDA_ERROR_INVALID_VALUE"):
        one_on_dev0.copy_to(buf_on_dev1, stream=stream_on_dev0)

    with pytest.raises(CUDAError, match="CUDA_ERROR_INVALID_VALUE"):
        zero_on_dev0.copy_from(buf_on_dev1, stream=stream_on_dev0)
72+
73+
74+
def test_peer_access_property_x2(mempool_device_x2):
    """Test the dmr.peer_accessible_by property (but not its functionality)."""
    # The peer access list is a sorted tuple and always excludes the self
    # device.
    dev0, dev1 = mempool_device_x2
    dmr = DeviceMemoryResource(dev0)

    def check(expected):
        # The property must always read back as a tuple with the exact
        # expected contents.
        assert isinstance(dmr.peer_accessible_by, tuple)
        assert dmr.peer_accessible_by == expected

    # No access to begin with.
    check(expected=())
    dmr.peer_accessible_by = (0,) ; check(expected=())
    dmr.peer_accessible_by = (1,) ; check(expected=(1,))
    dmr.peer_accessible_by = (0,1) ; check(expected=(1,))
    dmr.peer_accessible_by = () ; check(expected=())
    dmr.peer_accessible_by = [0,1] ; check(expected=(1,)) # list
    dmr.peer_accessible_by = set() ; check(expected=()) # set
    dmr.peer_accessible_by = [1,1,1,1,1] ; check(expected=(1,))

    with pytest.raises(ValueError, match=r"device_id must be \>\= 0"):
        dmr.peer_accessible_by = [-1] # device ID out of bounds

    num_devices = len(cuda.core.experimental.system.devices)

    with pytest.raises(ValueError, match=r"device_id must be within \[0, \d+\)"):
        dmr.peer_accessible_by = [num_devices] # device ID out of bounds
102+
103+
104+
def test_peer_access_transitions(mempool_device_x3):
    """Advanced tests for dmr.peer_accessible_by."""

    # Check all transitions between peer access states. The implementation
    # performs transactions that add or remove access as needed. This test
    # ensures that is working as expected.

    # Doing everything from the point-of-view of device 0, there are four
    # access states:
    #
    # [(), (1,), (2,), (1, 2)]
    #
    # giving 4 x 4 = 16 ordered pairs, of which the 12 with distinct
    # endpoints are exercised below.

    devs = mempool_device_x3  # Three devices

    # Allocate per-device resources.
    streams = [dev.create_stream() for dev in devs]
    pgens = [PatternGen(devs[i], NBYTES, streams[i]) for i in range(3)]
    dmrs = [DeviceMemoryResource(dev) for dev in devs]
    bufs = [dmr.allocate(NBYTES) for dmr in dmrs]

    def verify_state(state, pattern_seed):
        """
        Verify an access state from the POV of device 0. E.g., (1,) means
        device 1 has access but device 2 does not.
        """
        # Populate device 0's buffer with a new pattern.
        devs[0].set_current()
        pgens[0].fill_buffer(bufs[0], seed=pattern_seed)
        streams[0].sync()

        for peer in [1, 2]:
            devs[peer].set_current()
            if peer in state:
                # Peer device has access to 0's allocation
                bufs[peer].copy_from(bufs[0], stream=streams[peer])
                # Check the result on the peer device.
                pgens[peer].verify_buffer(bufs[peer], seed=pattern_seed)
            else:
                # Peer device has no access to 0's allocation
                with pytest.raises(CUDAError, match="CUDA_ERROR_INVALID_VALUE"):
                    bufs[peer].copy_from(bufs[0], stream=streams[peer])

    # For each transition, set the access state before and after, checking for
    # the expected peer access capabilities at each stop.
    pattern_seed = 0
    states = [(), (1,), (2,), (1, 2)]
    transitions = [(s0, s1) for s0 in states for s1 in states if s0 != s1]
    for init_state, final_state in transitions:
        dmrs[0].peer_accessible_by = init_state
        assert dmrs[0].peer_accessible_by == init_state
        verify_state(init_state, pattern_seed)
        pattern_seed += 1

        dmrs[0].peer_accessible_by = final_state
        assert dmrs[0].peer_accessible_by == final_state
        verify_state(final_state, pattern_seed)
        pattern_seed += 1
163+

0 commit comments

Comments
 (0)