@@ -9,8 +9,14 @@ from ._dlpack cimport *
99import functools
1010from typing import Any, Optional
1111
12+ from cuda import cuda
1213import numpy
1314
15+ from cuda.py._utils import handle_return
16+
17+
18+ # TODO(leofang): support NumPy structured dtypes
19+
1420
1521@cython.dataclasses.dataclass
1622cdef class GPUMemoryView:
@@ -37,6 +43,7 @@ cdef class GPUMemoryView:
3743
3844
3945cdef str get_simple_repr(obj):
46+ # TODO: better handling in np.dtype objects
4047 cdef object obj_class
4148 cdef str obj_repr
4249 if isinstance (obj, type ):
@@ -71,8 +78,7 @@ cdef class _GPUMemoryViewProxy:
7178 if self .has_dlpack:
7279 return view_as_dlpack(self .obj, stream_ptr)
7380 else :
74- # TODO: Support CAI
75- raise NotImplementedError (" TODO" )
81+ return view_as_cai(self .obj, stream_ptr)
7682
7783
7884cdef GPUMemoryView view_as_dlpack(obj, stream_ptr):
@@ -216,7 +222,49 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
216222 else :
217223 raise TypeError (' Unsupported dtype. dtype code: {}' .format(dtype.code))
218224
219- return np_dtype
225+ # We want the dtype object not just the type object
226+ return numpy.dtype(np_dtype)
227+
228+
cdef GPUMemoryView view_as_cai(obj, stream_ptr):
    """Create a GPUMemoryView from an object exposing ``__cuda_array_interface__``.

    Parameters
    ----------
    obj
        Any object whose ``__cuda_array_interface__`` attribute is a dict
        following the CUDA Array Interface (v3 or above).
    stream_ptr
        Address of the consumer stream. Must not be ``None``. If the producer
        advertises a different stream in the ``"stream"`` entry, the consumer
        stream is made to wait on the producer stream before the view is
        returned.

    Returns
    -------
    GPUMemoryView
        A view holding a reference to ``obj`` together with its pointer,
        shape, strides (in element counts), dtype, read-only flag and the
        ordinal of the device that owns the pointer.

    Raises
    ------
    BufferError
        If the interface version is below 3, a mask is present,
        ``stream_ptr`` is ``None``, or the producer stream is not a
        positive value.
    """
    cdef dict cai_data = obj.__cuda_array_interface__
    if cai_data["version"] < 3:
        raise BufferError("only CUDA Array Interface v3 or above is supported")
    if cai_data.get("mask") is not None:
        raise BufferError("mask is not supported")
    if stream_ptr is None:
        raise BufferError("stream=None is ambiguous with view()")

    cdef GPUMemoryView buf = GPUMemoryView()
    buf.obj = obj
    buf.ptr, buf.readonly = cai_data["data"]
    buf.shape = cai_data["shape"]
    # TODO: this only works for built-in numeric types
    buf.dtype = numpy.dtype(cai_data["typestr"])
    buf.strides = cai_data.get("strides")
    if buf.strides is not None:
        # The interface value is divided by the item size so that the view
        # stores strides as element counts.
        buf.strides = tuple(s // buf.dtype.itemsize for s in buf.strides)
    buf.device_accessible = True
    buf.device_id = handle_return(
        cuda.cuPointerGetAttribute(
            cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
            buf.ptr))

    cdef intptr_t producer_s, consumer_s
    stream = cai_data.get("stream")
    if stream is not None:
        producer_s = <intptr_t>(stream)
        consumer_s = <intptr_t>(stream_ptr)
        # Validate with a real exception rather than ``assert`` so the check
        # is not dropped when assertions are disabled.
        if producer_s <= 0:
            raise BufferError(
                f"invalid producer stream in __cuda_array_interface__: {producer_s}")
        # Establish stream order: the consumer must not touch the data before
        # the work already queued on the producer stream has completed.
        if producer_s != consumer_s:
            e = handle_return(cuda.cuEventCreate(
                cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING))
            try:
                handle_return(cuda.cuEventRecord(e, producer_s))
                handle_return(cuda.cuStreamWaitEvent(consumer_s, e, 0))
            finally:
                # The event is only needed to enqueue the wait; destroy it so
                # it is not leaked on every cross-stream view.
                handle_return(cuda.cuEventDestroy(e))

    return buf
220268
221269
222270def viewable (tuple arg_indices ):
0 commit comments