@@ -31,6 +31,17 @@ cdef class GPUMemoryView:
3131 readonly: bool = None
3232 obj: Any = None
3333
34+ def __init__ (self , obj = None , stream_ptr = None ):
35+ if obj is not None :
36+ # populate self's attributes
37+ if check_has_dlpack(obj):
38+ view_as_dlpack(obj, stream_ptr, self )
39+ else :
40+ view_as_cai(obj, stream_ptr, self )
41+ else :
42+ # default construct
43+ pass
44+
3445 def __repr__ (self ):
3546 return (f" GPUMemoryView(ptr={self.ptr},\n "
3647 + f" shape={self.shape},\n "
@@ -57,22 +68,27 @@ cdef str get_simple_repr(obj):
5768 return obj_repr
5869
5970
71+ cdef bint check_has_dlpack(obj) except * :
72+ cdef bint has_dlpack
73+ if hasattr (obj, " __dlpack__" ) and hasattr (obj, " __dlpack_device__" ):
74+ has_dlpack = True
75+ elif hasattr (obj, " __cuda_array_interface__" ):
76+ has_dlpack = False
77+ else :
78+ raise RuntimeError (
79+ " the input object does not support any data exchange protocol" )
80+ return has_dlpack
81+
82+
6083cdef class _GPUMemoryViewProxy:
6184
6285 cdef:
6386 object obj
6487 bint has_dlpack
6588
6689 def __init__ (self , obj ):
67- if hasattr (obj, " __dlpack__" ) and hasattr (obj, " __dlpack_device__" ):
68- has_dlpack = True
69- elif hasattr (obj, " __cuda_array_interface__" ):
70- has_dlpack = False
71- else :
72- raise RuntimeError (
73- " the input object does not support any data exchange protocol" )
7490 self .obj = obj
75- self .has_dlpack = has_dlpack
91+ self .has_dlpack = check_has_dlpack(obj)
7692
7793 cpdef GPUMemoryView view(self , stream_ptr = None ):
7894 if self .has_dlpack:
@@ -81,7 +97,7 @@ cdef class _GPUMemoryViewProxy:
8197 return view_as_cai(self .obj, stream_ptr)
8298
8399
84- cdef GPUMemoryView view_as_dlpack(obj, stream_ptr):
100+ cdef GPUMemoryView view_as_dlpack(obj, stream_ptr, view = None ):
85101 cdef int dldevice, device_id, i
86102 cdef bint device_accessible, versioned, is_readonly
87103 dldevice, device_id = obj.__dlpack_device__()
@@ -144,7 +160,7 @@ cdef GPUMemoryView view_as_dlpack(obj, stream_ptr):
144160 dl_tensor = & dlm_tensor.dl_tensor
145161 is_readonly = False
146162
147- cdef GPUMemoryView buf = GPUMemoryView()
163+ cdef GPUMemoryView buf = GPUMemoryView() if view is None else view
148164 buf.ptr = < intptr_t> (dl_tensor.data)
149165 buf.shape = tuple (int (dl_tensor.shape[i]) for i in range (dl_tensor.ndim))
150166 if dl_tensor.strides:
@@ -226,7 +242,7 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
226242 return numpy.dtype(np_dtype)
227243
228244
229- cdef GPUMemoryView view_as_cai(obj, stream_ptr):
245+ cdef GPUMemoryView view_as_cai(obj, stream_ptr, view = None ):
230246 cdef dict cai_data = obj.__cuda_array_interface__
231247 if cai_data[" version" ] < 3 :
232248 raise BufferError(" only CUDA Array Interface v3 or above is supported" )
@@ -235,7 +251,7 @@ cdef GPUMemoryView view_as_cai(obj, stream_ptr):
235251 if stream_ptr is None :
236252 raise BufferError(" stream=None is ambiguous with view()" )
237253
238- cdef GPUMemoryView buf = GPUMemoryView()
254+ cdef GPUMemoryView buf = GPUMemoryView() if view is None else view
239255 buf.obj = obj
240256 buf.ptr, buf.readonly = cai_data[" data" ]
241257 buf.shape = cai_data[" shape" ]
0 commit comments