Skip to content

Commit a4b285a

Browse files
authored
Add peer access control for DeviceMemoryResource (#1289)
* Ignore .cursorrules * Implement DeviceMemoryResource.peer_accessible_by * Add a check for device accessibility in peer_accessible_by.
1 parent ee03396 commit a4b285a

File tree

5 files changed

+307
-0
lines changed

5 files changed

+307
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,6 @@ cython_debug/
185185
# pixi environments
186186
.pixi/*
187187
!.pixi/config.toml
188+
189+
# Cursor
190+
.cursorrules

cuda_core/cuda/core/experimental/_device.pyx

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,27 @@ class Device:
10351035
bus_id = handle_return(runtime.cudaDeviceGetPCIBusId(13, self._id))
10361036
return bus_id[:12].decode()
10371037

1038+
def can_access_peer(self, peer: Device | int) -> bool:
    """Check if this device can access memory from the specified peer device.

    Queries whether peer-to-peer memory access is supported between this
    device and the specified peer device.

    Parameters
    ----------
    peer : Device | int
        The peer device to check accessibility to. Can be a Device object or device ID.

    Returns
    -------
    bool
        True if this device can access the peer's memory, False otherwise.
    """
    cdef int self_id = <int> self.device_id
    cdef int peer_id = <int> Device(peer).device_id
    cdef int supported = 0
    # A device trivially "peers" with itself; skip the driver query.
    if self_id == peer_id:
        return True
    with nogil:
        HANDLE_RETURN(cydriver.cuDeviceCanAccessPeer(&supported, self_id, peer_id))
    return bool(supported)
1058+
10381059
@property
10391060
def uuid(self) -> str:
10401061
"""Return a UUID for the device.

cuda_core/cuda/core/experimental/_memory/_device_memory_resource.pxd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,8 @@ cdef class DeviceMemoryResource(MemoryResource):
1414
bint _mempool_owned
1515
IPCData _ipc_data
1616
object _attributes
17+
object _peer_accessible_by
1718
object __weakref__
19+
20+
21+
cpdef DMR_mempool_get_access(DeviceMemoryResource, int)

cuda_core/cuda/core/experimental/_memory/_device_memory_resource.pyx

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ from __future__ import annotations
66

77
from libc.limits cimport ULLONG_MAX
88
from libc.stdint cimport uintptr_t
9+
from libc.stdlib cimport malloc, free
910
from libc.string cimport memset
1011

1112
from cuda.bindings cimport cydriver
@@ -222,6 +223,7 @@ cdef class DeviceMemoryResource(MemoryResource):
222223
self._mempool_owned = False
223224
self._ipc_data = None
224225
self._attributes = None
226+
self._peer_accessible_by = ()
225227

226228
def __init__(self, device_id: Device | int, options=None):
227229
from .._device import Device
@@ -408,6 +410,73 @@ cdef class DeviceMemoryResource(MemoryResource):
408410
"""
409411
return getattr(self._ipc_data, 'uuid', None)
410412

413+
@property
414+
def peer_accessible_by(self):
415+
"""
416+
Get or set the devices that can access allocations from this memory
417+
pool. Access can be modified at any time and affects all allocations
418+
from this memory pool.
419+
420+
Returns a tuple of sorted device IDs that currently have peer access to
421+
allocations from this memory pool.
422+
423+
When setting, accepts a sequence of Device objects or device IDs.
424+
Setting to an empty sequence revokes all peer access.
425+
426+
Examples
427+
--------
428+
>>> dmr = DeviceMemoryResource(0)
429+
>>> dmr.peer_accessible_by = [1] # Grant access to device 1
430+
>>> assert dmr.peer_accessible_by == (1,)
431+
>>> dmr.peer_accessible_by = [] # Revoke access
432+
"""
433+
return self._peer_accessible_by
434+
435+
@peer_accessible_by.setter
def peer_accessible_by(self, devices):
    """Set which devices can access this memory pool.

    Parameters
    ----------
    devices : sequence of Device | int
        The devices to grant peer access to. Duplicates are collapsed and
        this pool's own device is ignored. An empty sequence revokes all
        peer access.

    Raises
    ------
    ValueError
        If any requested device cannot peer-access this pool's device.
    MemoryError
        If the access-descriptor array cannot be allocated.
    """
    from .._device import Device

    # Normalize inputs to a unique set of device IDs; the pool's own device
    # always has access, so it never appears in the peer list.
    cdef set[int] target_ids = {Device(dev).device_id for dev in devices}
    target_ids.discard(self._dev_id)  # exclude this device from peer access list
    this_dev = Device(self._dev_id)
    cdef list bad = [dev for dev in target_ids if not this_dev.can_access_peer(dev)]
    if bad:
        raise ValueError(f"Device {self._dev_id} cannot access peer(s): {', '.join(map(str, bad))}")
    # Compute the delta against the current state so we only submit the
    # grants/revocations that actually change anything.
    cdef set[int] cur_ids = set(self._peer_accessible_by)
    cdef set[int] to_add = target_ids - cur_ids
    cdef set[int] to_rm = cur_ids - target_ids
    cdef size_t count = len(to_add) + len(to_rm)  # transaction size
    cdef cydriver.CUmemAccessDesc* access_desc = NULL
    cdef size_t i = 0

    if count > 0:
        access_desc = <cydriver.CUmemAccessDesc*>malloc(count * sizeof(cydriver.CUmemAccessDesc))
        if access_desc == NULL:
            raise MemoryError("Failed to allocate memory for access descriptors")

        try:
            for dev_id in to_add:
                access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
                access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
                access_desc[i].location.id = dev_id
                i += 1

            for dev_id in to_rm:
                access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_NONE
                access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
                access_desc[i].location.id = dev_id
                i += 1

            with nogil:
                HANDLE_RETURN(cydriver.cuMemPoolSetAccess(self._handle, access_desc, count))
        finally:
            if access_desc != NULL:
                free(access_desc)

    # BUG FIX: the getter documents a *sorted* tuple, but tuple(set) iterates
    # in arbitrary (implementation-defined) order. Sort explicitly so the
    # documented contract holds on every interpreter.
    self._peer_accessible_by = tuple(sorted(target_ids))
479+
411480

412481
# DeviceMemoryResource Implementation
413482
# -----------------------------------
@@ -515,6 +584,11 @@ cdef inline DMR_close(DeviceMemoryResource self):
515584
if self._handle == NULL:
516585
return
517586

587+
# This works around nvbug 5698116. When a memory pool handle is recycled
588+
# the new handle inherits the peer access state of the previous handle.
589+
if self._peer_accessible_by:
590+
self.peer_accessible_by = []
591+
518592
try:
519593
if self._mempool_owned:
520594
with nogil:
@@ -525,3 +599,39 @@ cdef inline DMR_close(DeviceMemoryResource self):
525599
self._attributes = None
526600
self._mempool_owned = False
527601
self._ipc_data = None
602+
self._peer_accessible_by = ()
603+
604+
605+
# Note: this is referenced in instructions to debug nvbug 5698116.
606+
cpdef DMR_mempool_get_access(DeviceMemoryResource dmr, int device_id):
    """
    Probes peer access from the given device using cuMemPoolGetAccess.

    Parameters
    ----------
    device_id : int or Device
        The device to query access for.

    Returns
    -------
    str
        Access permissions: "rw" for read-write, "r" for read-only, "" for no access.
    """
    from .._device import Device

    cdef int dev = Device(device_id).device_id
    cdef cydriver.CUmemAccess_flags access_flags
    cdef cydriver.CUmemLocation loc

    # Build the location descriptor for the queried device.
    loc.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    loc.id = dev

    with nogil:
        HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&access_flags, dmr._handle, &loc))

    # Map the driver flag to the short string form used by callers.
    if access_flags == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE:
        return "rw"
    if access_flags == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ:
        return "r"
    return ""
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import cuda.core.experimental
5+
import pytest
6+
from cuda.core.experimental import Device, DeviceMemoryResource
7+
from cuda.core.experimental._utils.cuda_utils import CUDAError
8+
from helpers.buffers import PatternGen, compare_buffer_to_constant, make_scratch_buffer
9+
10+
NBYTES = 1024
11+
12+
13+
def _mempool_device_impl(num):
    """Return *num* Device objects suitable for mempool peer-access tests.

    Skips the calling test when there are fewer than *num* GPUs, when any
    pair of them lacks peer access, or when any lacks mempool support.
    """
    num_devices = len(cuda.core.experimental.system.devices)
    if num_devices < num:
        # BUG FIX: this was a plain string literal, so the skip reason
        # printed the literal text "{num}" instead of the required count.
        pytest.skip(f"Test requires at least {num} GPUs")

    devs = [Device(i) for i in range(num)]
    # Set devices current in reverse so device 0 ends up current.
    for i in reversed(range(num)):
        devs[i].set_current()

    if not all(devs[i].can_access_peer(j) for i in range(num) for j in range(num)):
        pytest.skip("Test requires GPUs with peer access")

    if not all(devs[i].properties.memory_pools_supported for i in range(num)):
        pytest.skip("Device does not support mempool operations")

    return devs
29+
30+
31+
@pytest.fixture
def mempool_device_x2():
    """Yield two mempool-capable peer devices, skipping the test otherwise."""
    return _mempool_device_impl(2)
35+
36+
37+
@pytest.fixture
def mempool_device_x3():
    """Yield three mempool-capable peer devices, skipping the test otherwise."""
    return _mempool_device_impl(3)
41+
42+
43+
def test_peer_access_basic(mempool_device_x2):
    """Grant and revoke dmr.peer_accessible_by, checking copies end to end."""
    dev0, dev1 = mempool_device_x2
    zeros = make_scratch_buffer(dev0, 0, NBYTES)
    ones = make_scratch_buffer(dev0, 1, NBYTES)
    stream0 = dev0.create_stream()
    dmr1 = DeviceMemoryResource(dev1)
    remote_buf = dmr1.allocate(NBYTES)

    def assert_no_access():
        # Without peer access, copies in either direction must fail.
        assert 0 not in dmr1.peer_accessible_by
        with pytest.raises(CUDAError, match="CUDA_ERROR_INVALID_VALUE"):
            ones.copy_to(remote_buf, stream=stream0)
        with pytest.raises(CUDAError, match="CUDA_ERROR_INVALID_VALUE"):
            zeros.copy_from(remote_buf, stream=stream0)

    # No access at first.
    assert_no_access()

    # Allow access to device 1's allocations from device 0, then round-trip
    # a pattern through the remote buffer and verify it arrived.
    dmr1.peer_accessible_by = [dev0]
    assert 0 in dmr1.peer_accessible_by
    compare_buffer_to_constant(zeros, 0)
    ones.copy_to(remote_buf, stream=stream0)
    zeros.copy_from(remote_buf, stream=stream0)
    stream0.sync()
    compare_buffer_to_constant(zeros, 1)

    # Revoking access restores the failing behavior.
    dmr1.peer_accessible_by = []
    assert_no_access()
77+
78+
79+
def test_peer_access_property_x2(mempool_device_x2):
    """Test the dmr.peer_accessible_by property itself (not its functionality).

    The peer access list is a sorted tuple and always excludes the self
    device.
    """
    dev0, dev1 = mempool_device_x2
    dmr = DeviceMemoryResource(dev0)

    def check(expected):
        assert isinstance(dmr.peer_accessible_by, tuple)
        assert dmr.peer_accessible_by == expected

    # No access to begin with.
    check(expected=())

    # Assignments: the self device (0) is always dropped, duplicates
    # collapse, and any sequence type is accepted.
    for assigned, expected in [
        ((0,), ()),
        ((1,), (1,)),
        ((0, 1), (1,)),
        ((), ()),
        ([0, 1], (1,)),
        (set(), ()),
        ([1, 1, 1, 1, 1], (1,)),
    ]:
        dmr.peer_accessible_by = assigned
        check(expected=expected)

    with pytest.raises(ValueError, match=r"device_id must be \>\= 0"):
        dmr.peer_accessible_by = [-1]  # device ID out of bounds

    num_devices = len(cuda.core.experimental.system.devices)

    with pytest.raises(ValueError, match=r"device_id must be within \[0, \d+\)"):
        dmr.peer_accessible_by = [num_devices]  # device ID out of bounds
109+
110+
111+
def test_peer_access_transitions(mempool_device_x3):
    """Walk every transition between peer-access states of device 0's pool.

    From device 0's point of view, with devices 1 and 2 as potential peers,
    there are four access states -- (), (1,), (2,), (1, 2) -- and
    4^2 - 4 = 12 non-identity transitions. The setter applies only the
    delta (grants/revocations), so visiting every transition exercises all
    of its transaction paths.
    """
    devs = mempool_device_x3  # Three devices

    # Allocate per-device resources.
    streams = [dev.create_stream() for dev in devs]
    pgens = [PatternGen(devs[i], NBYTES, streams[i]) for i in range(3)]
    dmrs = [DeviceMemoryResource(dev) for dev in devs]
    bufs = [dmr.allocate(NBYTES) for dmr in dmrs]

    def verify_state(state, seed):
        """
        Verify an access state from the POV of device 0. E.g., (1,) means
        device 1 has access but device 2 does not.
        """
        # Populate device 0's buffer with a new pattern.
        devs[0].set_current()
        pgens[0].fill_buffer(bufs[0], seed=seed)
        streams[0].sync()

        for peer in (1, 2):
            devs[peer].set_current()
            if peer in state:
                # Peer device has access: copy out and verify the pattern.
                bufs[peer].copy_from(bufs[0], stream=streams[peer])
                pgens[peer].verify_buffer(bufs[peer], seed=seed)
            else:
                # Peer device has no access: the copy must fail.
                with pytest.raises(CUDAError, match="CUDA_ERROR_INVALID_VALUE"):
                    bufs[peer].copy_from(bufs[0], stream=streams[peer])

    # For each transition, set the access state before and after, checking
    # for the expected peer access capabilities at each stop.
    seed = 0
    states = [(), (1,), (2,), (1, 2)]
    for src in states:
        for dst in states:
            if src == dst:
                continue
            for state in (src, dst):
                dmrs[0].peer_accessible_by = state
                assert dmrs[0].peer_accessible_by == state
                verify_state(state, seed)
                seed += 1

0 commit comments

Comments
 (0)