Skip to content

Commit f223c17

Browse files
authored
feat: add from_* style constructor classmethods to StridedMemoryView and make constructor amenable to future from_*-style constructors (#1224)
1 parent d108351 commit f223c17

File tree

3 files changed

+158
-136
lines changed

3 files changed

+158
-136
lines changed

cuda_core/cuda/core/experimental/_memoryview.pyx

Lines changed: 153 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -88,20 +88,155 @@ cdef class StridedMemoryView:
8888
cdef DLTensor *dl_tensor
8989

9090
# Memoized properties
91-
cdef tuple _shape
92-
cdef tuple _strides
93-
cdef bint _strides_init # Has the strides tuple been init'ed?
94-
cdef object _dtype
95-
96-
def __init__(self, obj=None, stream_ptr=None):
97-
if obj is not None:
98-
# populate self's attributes
99-
if check_has_dlpack(obj):
100-
view_as_dlpack(obj, stream_ptr, self)
101-
else:
102-
view_as_cai(obj, stream_ptr, self)
91+
cdef:
92+
tuple _shape
93+
tuple _strides
94+
# a `None` value for _strides has defined meaning in dlpack and
95+
# the cuda array interface, meaning C order, contiguous.
96+
#
97+
# this flag helps prevent unnecessary recomputation of _strides
98+
bint _strides_init
99+
object _dtype
100+
101+
def __init__(
102+
self,
103+
*,
104+
ptr: intptr_t,
105+
device_id: int,
106+
is_device_accessible: bint,
107+
readonly: bint,
108+
metadata: object,
109+
exporting_obj: object,
110+
dl_tensor: intptr_t = 0,
111+
) -> None:
112+
self.ptr = ptr
113+
self.device_id = device_id
114+
self.is_device_accessible = is_device_accessible
115+
self.readonly = readonly
116+
self.metadata = metadata
117+
self.exporting_obj = exporting_obj
118+
self.dl_tensor = <DLTensor*>dl_tensor
119+
self._shape = None
120+
self._strides = None
121+
self._strides_init = False
122+
self._dtype = None
123+
124+
@classmethod
125+
def from_dlpack(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
126+
cdef int dldevice, device_id
127+
cdef bint is_device_accessible, is_readonly
128+
is_device_accessible = False
129+
dldevice, device_id = obj.__dlpack_device__()
130+
if dldevice == _kDLCPU:
131+
assert device_id == 0
132+
device_id = -1
133+
if stream_ptr is None:
134+
raise BufferError("stream=None is ambiguous with view()")
135+
elif stream_ptr == -1:
136+
stream_ptr = None
137+
elif dldevice == _kDLCUDA:
138+
assert device_id >= 0
139+
is_device_accessible = True
140+
# no need to check other stream values, it's a pass-through
141+
if stream_ptr is None:
142+
raise BufferError("stream=None is ambiguous with view()")
143+
elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
144+
is_device_accessible = True
145+
# just do a pass-through without any checks, as pinned/managed memory can be
146+
# accessed on both host and device
103147
else:
104-
pass
148+
raise BufferError("device not supported")
149+
150+
cdef object capsule
151+
try:
152+
capsule = obj.__dlpack__(
153+
stream=int(stream_ptr) if stream_ptr else None,
154+
max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
155+
except TypeError:
156+
capsule = obj.__dlpack__(
157+
stream=int(stream_ptr) if stream_ptr else None)
158+
159+
cdef void* data = NULL
160+
cdef DLTensor* dl_tensor
161+
cdef DLManagedTensorVersioned* dlm_tensor_ver
162+
cdef DLManagedTensor* dlm_tensor
163+
cdef const char *used_name
164+
if cpython.PyCapsule_IsValid(
165+
capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
166+
data = cpython.PyCapsule_GetPointer(
167+
capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
168+
dlm_tensor_ver = <DLManagedTensorVersioned*>data
169+
dl_tensor = &dlm_tensor_ver.dl_tensor
170+
is_readonly = (dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0
171+
used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
172+
else:
173+
assert cpython.PyCapsule_IsValid(
174+
capsule, DLPACK_TENSOR_UNUSED_NAME)
175+
data = cpython.PyCapsule_GetPointer(
176+
capsule, DLPACK_TENSOR_UNUSED_NAME)
177+
dlm_tensor = <DLManagedTensor*>data
178+
dl_tensor = &dlm_tensor.dl_tensor
179+
is_readonly = False
180+
used_name = DLPACK_TENSOR_USED_NAME
181+
182+
cpython.PyCapsule_SetName(capsule, used_name)
183+
184+
return cls(
185+
ptr=<intptr_t>dl_tensor.data,
186+
device_id=int(device_id),
187+
is_device_accessible=is_device_accessible,
188+
readonly=is_readonly,
189+
metadata=capsule,
190+
exporting_obj=obj,
191+
dl_tensor=<intptr_t>dl_tensor,
192+
)
193+
194+
@classmethod
195+
def from_cuda_array_interface(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
196+
cdef dict cai_data = obj.__cuda_array_interface__
197+
if cai_data["version"] < 3:
198+
raise BufferError("only CUDA Array Interface v3 or above is supported")
199+
if cai_data.get("mask") is not None:
200+
raise BufferError("mask is not supported")
201+
if stream_ptr is None:
202+
raise BufferError("stream=None is ambiguous with view()")
203+
204+
cdef intptr_t producer_s, consumer_s
205+
stream_ptr = int(stream_ptr)
206+
if stream_ptr != -1:
207+
stream = cai_data.get("stream")
208+
if stream is not None:
209+
producer_s = <intptr_t>(stream)
210+
consumer_s = <intptr_t>(stream_ptr)
211+
assert producer_s > 0
212+
# establish stream order
213+
if producer_s != consumer_s:
214+
e = handle_return(driver.cuEventCreate(
215+
driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
216+
handle_return(driver.cuEventRecord(e, producer_s))
217+
handle_return(driver.cuStreamWaitEvent(consumer_s, e, 0))
218+
handle_return(driver.cuEventDestroy(e))
219+
220+
cdef intptr_t ptr = int(cai_data["data"][0])
221+
return cls(
222+
ptr=ptr,
223+
device_id=handle_return(
224+
driver.cuPointerGetAttribute(
225+
driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
226+
ptr
227+
)
228+
),
229+
is_device_accessible=True,
230+
readonly=cai_data["data"][1],
231+
metadata=cai_data,
232+
exporting_obj=obj,
233+
)
234+
235+
@classmethod
236+
def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
237+
if check_has_dlpack(obj):
238+
return cls.from_dlpack(obj, stream_ptr)
239+
return cls.from_cuda_array_interface(obj, stream_ptr)
105240

106241
def __dealloc__(self):
107242
if self.dl_tensor == NULL:
@@ -121,7 +256,7 @@ cdef class StridedMemoryView:
121256
dlm_tensor.deleter(dlm_tensor)
122257

123258
@property
124-
def shape(self) -> tuple[int]:
259+
def shape(self) -> tuple[int, ...]:
125260
if self._shape is None:
126261
if self.exporting_obj is not None:
127262
if self.dl_tensor != NULL:
@@ -136,7 +271,7 @@ cdef class StridedMemoryView:
136271
return self._shape
137272

138273
@property
139-
def strides(self) -> Optional[tuple[int]]:
274+
def strides(self) -> Optional[tuple[int, ...]]:
140275
cdef int itemsize
141276
if self._strides_init is False:
142277
if self.exporting_obj is not None:
@@ -206,8 +341,7 @@ cdef bint check_has_dlpack(obj) except*:
206341

207342

208343
cdef class _StridedMemoryViewProxy:
209-
210-
cdef:
344+
cdef readonly:
211345
object obj
212346
bint has_dlpack
213347

@@ -217,82 +351,11 @@ cdef class _StridedMemoryViewProxy:
217351

218352
cpdef StridedMemoryView view(self, stream_ptr=None):
219353
if self.has_dlpack:
220-
return view_as_dlpack(self.obj, stream_ptr)
354+
return StridedMemoryView.from_dlpack(self.obj, stream_ptr)
221355
else:
222-
return view_as_cai(self.obj, stream_ptr)
356+
return StridedMemoryView.from_cuda_array_interface(self.obj, stream_ptr)
223357

224358

225-
cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
226-
cdef int dldevice, device_id
227-
cdef bint is_device_accessible, is_readonly
228-
is_device_accessible = False
229-
dldevice, device_id = obj.__dlpack_device__()
230-
if dldevice == _kDLCPU:
231-
assert device_id == 0
232-
device_id = -1
233-
if stream_ptr is None:
234-
raise BufferError("stream=None is ambiguous with view()")
235-
elif stream_ptr == -1:
236-
stream_ptr = None
237-
elif dldevice == _kDLCUDA:
238-
assert device_id >= 0
239-
is_device_accessible = True
240-
# no need to check other stream values, it's a pass-through
241-
if stream_ptr is None:
242-
raise BufferError("stream=None is ambiguous with view()")
243-
elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
244-
is_device_accessible = True
245-
# just do a pass-through without any checks, as pinned/managed memory can be
246-
# accessed on both host and device
247-
else:
248-
raise BufferError("device not supported")
249-
250-
cdef object capsule
251-
try:
252-
capsule = obj.__dlpack__(
253-
stream=int(stream_ptr) if stream_ptr else None,
254-
max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
255-
except TypeError:
256-
capsule = obj.__dlpack__(
257-
stream=int(stream_ptr) if stream_ptr else None)
258-
259-
cdef void* data = NULL
260-
cdef DLTensor* dl_tensor
261-
cdef DLManagedTensorVersioned* dlm_tensor_ver
262-
cdef DLManagedTensor* dlm_tensor
263-
cdef const char *used_name
264-
if cpython.PyCapsule_IsValid(
265-
capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
266-
data = cpython.PyCapsule_GetPointer(
267-
capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
268-
dlm_tensor_ver = <DLManagedTensorVersioned*>data
269-
dl_tensor = &dlm_tensor_ver.dl_tensor
270-
is_readonly = bool((dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0)
271-
used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
272-
elif cpython.PyCapsule_IsValid(
273-
capsule, DLPACK_TENSOR_UNUSED_NAME):
274-
data = cpython.PyCapsule_GetPointer(
275-
capsule, DLPACK_TENSOR_UNUSED_NAME)
276-
dlm_tensor = <DLManagedTensor*>data
277-
dl_tensor = &dlm_tensor.dl_tensor
278-
is_readonly = False
279-
used_name = DLPACK_TENSOR_USED_NAME
280-
else:
281-
assert False
282-
283-
cpython.PyCapsule_SetName(capsule, used_name)
284-
285-
cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
286-
buf.dl_tensor = dl_tensor
287-
buf.metadata = capsule
288-
buf.ptr = <intptr_t>(dl_tensor.data)
289-
buf.device_id = device_id
290-
buf.is_device_accessible = is_device_accessible
291-
buf.readonly = is_readonly
292-
buf.exporting_obj = obj
293-
294-
return buf
295-
296359

297360
cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
298361
cdef int bits = dtype.bits
@@ -354,46 +417,6 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
354417
return numpy.dtype(np_dtype)
355418

356419

357-
# Also generate for Python so we can test this code path
358-
cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
359-
cdef dict cai_data = obj.__cuda_array_interface__
360-
if cai_data["version"] < 3:
361-
raise BufferError("only CUDA Array Interface v3 or above is supported")
362-
if cai_data.get("mask") is not None:
363-
raise BufferError("mask is not supported")
364-
if stream_ptr is None:
365-
raise BufferError("stream=None is ambiguous with view()")
366-
367-
cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
368-
buf.exporting_obj = obj
369-
buf.metadata = cai_data
370-
buf.dl_tensor = NULL
371-
buf.ptr, buf.readonly = cai_data["data"]
372-
buf.is_device_accessible = True
373-
buf.device_id = handle_return(
374-
driver.cuPointerGetAttribute(
375-
driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
376-
buf.ptr))
377-
378-
cdef intptr_t producer_s, consumer_s
379-
stream_ptr = int(stream_ptr)
380-
if stream_ptr != -1:
381-
stream = cai_data.get("stream")
382-
if stream is not None:
383-
producer_s = <intptr_t>(stream)
384-
consumer_s = <intptr_t>(stream_ptr)
385-
assert producer_s > 0
386-
# establish stream order
387-
if producer_s != consumer_s:
388-
e = handle_return(driver.cuEventCreate(
389-
driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
390-
handle_return(driver.cuEventRecord(e, producer_s))
391-
handle_return(driver.cuStreamWaitEvent(consumer_s, e, 0))
392-
handle_return(driver.cuEventDestroy(e))
393-
394-
return buf
395-
396-
397420
def args_viewable_as_strided_memory(tuple arg_indices):
398421
"""
399422
Decorator to create proxy objects to :obj:`StridedMemoryView` for the

cuda_core/tests/test_memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,15 +609,15 @@ def test_strided_memory_view_leak():
609609
arr = np.zeros(1048576, dtype=np.uint8)
610610
before = sys.getrefcount(arr)
611611
for idx in range(10):
612-
StridedMemoryView(arr, stream_ptr=-1)
612+
StridedMemoryView.from_any_interface(arr, stream_ptr=-1)
613613
after = sys.getrefcount(arr)
614614
assert before == after
615615

616616

617617
def test_strided_memory_view_refcnt():
618618
# Use Fortran ordering so strides is used
619619
a = np.zeros((64, 4), dtype=np.uint8, order="F")
620-
av = StridedMemoryView(a, stream_ptr=-1)
620+
av = StridedMemoryView.from_any_interface(a, stream_ptr=-1)
621621
# segfaults if refcnt is wrong
622622
assert av.shape[0] == 64
623623
assert sys.getrefcount(av.shape) >= 2

cuda_core/tests/test_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import numpy as np
1515
import pytest
1616
from cuda.core.experimental import Device
17-
from cuda.core.experimental._memoryview import view_as_cai
1817
from cuda.core.experimental.utils import StridedMemoryView, args_viewable_as_strided_memory
1918

2019

@@ -78,7 +77,7 @@ def my_func(arr):
7877

7978
def test_strided_memory_view_cpu(self, in_arr):
8079
# stream_ptr=-1 means "the consumer does not care"
81-
view = StridedMemoryView(in_arr, stream_ptr=-1)
80+
view = StridedMemoryView.from_any_interface(in_arr, stream_ptr=-1)
8281
self._check_view(view, in_arr)
8382

8483
def _check_view(self, view, in_arr):
@@ -147,7 +146,7 @@ def test_strided_memory_view_cpu(self, in_arr, use_stream):
147146
# This is the consumer stream
148147
s = dev.create_stream() if use_stream else None
149148

150-
view = StridedMemoryView(in_arr, stream_ptr=s.handle if s else -1)
149+
view = StridedMemoryView.from_any_interface(in_arr, stream_ptr=s.handle if s else -1)
151150
self._check_view(view, in_arr, dev)
152151

153152
def _check_view(self, view, in_arr, dev):
@@ -179,7 +178,7 @@ def test_cuda_array_interface_gpu(self, in_arr, use_stream):
179178
# The usual path in `StridedMemoryView` prefers the DLPack interface
180179
# over __cuda_array_interface__, so we call `from_cuda_array_interface` directly
181180
# here so we can test the CAI code path.
182-
view = view_as_cai(in_arr, stream_ptr=s.handle if s else -1)
181+
view = StridedMemoryView.from_cuda_array_interface(in_arr, stream_ptr=s.handle if s else -1)
183182
self._check_view(view, in_arr, dev)
184183

185184
def _check_view(self, view, in_arr, dev):

0 commit comments

Comments
 (0)