@@ -19,7 +19,7 @@ from cuda.core._utils import handle_return
1919
2020
2121@cython.dataclasses.dataclass
22- cdef class GPUMemoryView :
22+ cdef class StridedMemoryView :
2323
2424 # TODO: switch to use Cython's cdef typing?
2525 ptr: int = None
@@ -43,14 +43,14 @@ cdef class GPUMemoryView:
4343 pass
4444
4545 def __repr__ (self ):
46- return (f" GPUMemoryView (ptr={self.ptr},\n "
47- + f" shape={self.shape},\n "
48- + f" strides={self.strides},\n "
49- + f" dtype={get_simple_repr(self.dtype)},\n "
50- + f" device_id={self.device_id},\n "
51- + f" device_accessible={self.device_accessible},\n "
52- + f" readonly={self.readonly},\n "
53- + f" obj={get_simple_repr(self.obj)})" )
46+ return (f" StridedMemoryView (ptr={self.ptr},\n "
47+ + f" shape={self.shape},\n "
48+ + f" strides={self.strides},\n "
49+ + f" dtype={get_simple_repr(self.dtype)},\n "
50+ + f" device_id={self.device_id},\n "
51+ + f" device_accessible={self.device_accessible},\n "
52+ + f" readonly={self.readonly},\n "
53+ + f" obj={get_simple_repr(self.obj)})" )
5454
5555
5656cdef str get_simple_repr(obj):
@@ -80,7 +80,7 @@ cdef bint check_has_dlpack(obj) except*:
8080 return has_dlpack
8181
8282
83- cdef class _GPUMemoryViewProxy :
83+ cdef class _StridedMemoryViewProxy :
8484
8585 cdef:
8686 object obj
@@ -90,14 +90,14 @@ cdef class _GPUMemoryViewProxy:
9090 self .obj = obj
9191 self .has_dlpack = check_has_dlpack(obj)
9292
93- cpdef GPUMemoryView view(self , stream_ptr = None ):
93+ cpdef StridedMemoryView view(self , stream_ptr = None ):
9494 if self .has_dlpack:
9595 return view_as_dlpack(self .obj, stream_ptr)
9696 else :
9797 return view_as_cai(self .obj, stream_ptr)
9898
9999
100- cdef GPUMemoryView view_as_dlpack(obj, stream_ptr, view = None ):
100+ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view = None ):
101101 cdef int dldevice, device_id, i
102102 cdef bint device_accessible, versioned, is_readonly
103103 dldevice, device_id = obj.__dlpack_device__()
@@ -160,7 +160,7 @@ cdef GPUMemoryView view_as_dlpack(obj, stream_ptr, view=None):
160160 dl_tensor = & dlm_tensor.dl_tensor
161161 is_readonly = False
162162
163- cdef GPUMemoryView buf = GPUMemoryView () if view is None else view
163+ cdef StridedMemoryView buf = StridedMemoryView () if view is None else view
164164 buf.ptr = < intptr_t> (dl_tensor.data)
165165 buf.shape = tuple (int (dl_tensor.shape[i]) for i in range (dl_tensor.ndim))
166166 if dl_tensor.strides:
@@ -242,7 +242,7 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
242242 return numpy.dtype(np_dtype)
243243
244244
245- cdef GPUMemoryView view_as_cai(obj, stream_ptr, view = None ):
245+ cdef StridedMemoryView view_as_cai(obj, stream_ptr, view = None ):
246246 cdef dict cai_data = obj.__cuda_array_interface__
247247 if cai_data[" version" ] < 3 :
248248 raise BufferError(" only CUDA Array Interface v3 or above is supported" )
@@ -251,7 +251,7 @@ cdef GPUMemoryView view_as_cai(obj, stream_ptr, view=None):
251251 if stream_ptr is None :
252252 raise BufferError(" stream=None is ambiguous with view()" )
253253
254- cdef GPUMemoryView buf = GPUMemoryView () if view is None else view
254+ cdef StridedMemoryView buf = StridedMemoryView () if view is None else view
255255 buf.obj = obj
256256 buf.ptr, buf.readonly = cai_data[" data" ]
257257 buf.shape = cai_data[" shape" ]
@@ -291,7 +291,7 @@ def viewable(tuple arg_indices):
291291 args = list (args)
292292 cdef int idx
293293 for idx in arg_indices:
294- args[idx] = _GPUMemoryViewProxy (args[idx])
294+ args[idx] = _StridedMemoryViewProxy (args[idx])
295295 return func(* args, ** kwargs)
296296 return wrapped_func
297297 return wrapped_func_with_indices
0 commit comments