@@ -88,20 +88,155 @@ cdef class StridedMemoryView:
8888 cdef DLTensor * dl_tensor
8989
9090 # Memoized properties
91- cdef tuple _shape
92- cdef tuple _strides
93- cdef bint _strides_init # Has the strides tuple been init'ed?
94- cdef object _dtype
95-
96- def __init__ (self , obj = None , stream_ptr = None ):
97- if obj is not None :
98- # populate self's attributes
99- if check_has_dlpack(obj):
100- view_as_dlpack(obj, stream_ptr, self )
101- else :
102- view_as_cai(obj, stream_ptr, self )
91+ cdef:
92+ tuple _shape
93+ tuple _strides
94+ # a `None` value for _strides has defined meaning in dlpack and
95+ # the cuda array interface, meaning C order, contiguous.
96+ #
97+ # this flag helps prevent unnecessary recompuation of _strides
98+ bint _strides_init
99+ object _dtype
100+
101+ def __init__ (
102+ self ,
103+ *,
104+ ptr: intptr_t ,
105+ device_id: int ,
106+ is_device_accessible: bint ,
107+ readonly: bint ,
108+ metadata: object ,
109+ exporting_obj: object ,
110+ dl_tensor: intptr_t = 0 ,
111+ ) -> None:
112+ self.ptr = ptr
113+ self.device_id = device_id
114+ self.is_device_accessible = is_device_accessible
115+ self.readonly = readonly
116+ self.metadata = metadata
117+ self.exporting_obj = exporting_obj
118+ self.dl_tensor = < DLTensor* > dl_tensor
119+ self._shape = None
120+ self._strides = None
121+ self._strides_init = False
122+ self._dtype = None
123+
124+ @classmethod
125+ def from_dlpack(cls , obj: object , stream_ptr: int | None = None ) -> StridedMemoryView:
126+ cdef int dldevice , device_id
127+ cdef bint is_device_accessible , is_readonly
128+ is_device_accessible = False
129+ dldevice , device_id = obj.__dlpack_device__()
130+ if dldevice == _kDLCPU:
131+ assert device_id == 0
132+ device_id = - 1
133+ if stream_ptr is None:
134+                raise BufferError("stream=None is ambiguous with view()")
135+ elif stream_ptr == -1:
136+ stream_ptr = None
137+ elif dldevice == _kDLCUDA:
138+ assert device_id >= 0
139+ is_device_accessible = True
140+ # no need to check other stream values , it's a pass-through
141+ if stream_ptr is None:
142+                raise BufferError("stream=None is ambiguous with view()")
143+ elif dldevice in (_kDLCUDAHost , _kDLCUDAManaged ):
144+ is_device_accessible = True
145+ # just do a pass-through without any checks, as pinned/managed memory can be
146+ # accessed on both host and device
103147 else :
104- pass
148+            raise BufferError("device not supported")
149+
150+ cdef object capsule
151+ try :
152+ capsule = obj.__dlpack__(
153+ stream = int (stream_ptr) if stream_ptr else None ,
154+ max_version = (DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
155+ except TypeError :
156+ capsule = obj.__dlpack__(
157+ stream = int (stream_ptr) if stream_ptr else None )
158+
159+ cdef void * data = NULL
160+ cdef DLTensor* dl_tensor
161+ cdef DLManagedTensorVersioned* dlm_tensor_ver
162+ cdef DLManagedTensor* dlm_tensor
163+ cdef const char * used_name
164+ if cpython.PyCapsule_IsValid(
165+ capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
166+ data = cpython.PyCapsule_GetPointer(
167+ capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
168+ dlm_tensor_ver = < DLManagedTensorVersioned* > data
169+ dl_tensor = & dlm_tensor_ver.dl_tensor
170+ is_readonly = (dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0
171+ used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
172+ else :
173+ assert cpython.PyCapsule_IsValid(
174+ capsule, DLPACK_TENSOR_UNUSED_NAME)
175+ data = cpython.PyCapsule_GetPointer(
176+ capsule, DLPACK_TENSOR_UNUSED_NAME)
177+ dlm_tensor = < DLManagedTensor* > data
178+ dl_tensor = & dlm_tensor.dl_tensor
179+ is_readonly = False
180+ used_name = DLPACK_TENSOR_USED_NAME
181+
182+ cpython.PyCapsule_SetName(capsule, used_name)
183+
184+ return cls (
185+ ptr = < intptr_t> dl_tensor.data,
186+ device_id = int (device_id),
187+ is_device_accessible = is_device_accessible,
188+ readonly = is_readonly,
189+ metadata = capsule,
190+ exporting_obj = obj,
191+ dl_tensor = < intptr_t> dl_tensor,
192+ )
193+
194+ @classmethod
195+ def from_cuda_array_interface (cls , obj: object , stream_ptr: int | None = None ) -> StridedMemoryView:
196+ cdef dict cai_data = obj.__cuda_array_interface__
197+ if cai_data["version"] < 3:
198+ raise BufferError("only CUDA Array Interface v3 or above is supported")
199+ if cai_data.get("mask") is not None:
200+ raise BufferError("mask is not supported")
201+ if stream_ptr is None:
202+            raise BufferError("stream=None is ambiguous with view()")
203+
204+ cdef intptr_t producer_s , consumer_s
205+ stream_ptr = int (stream_ptr)
206+ if stream_ptr != -1:
207+            stream = cai_data.get("stream")
208+ if stream is not None:
209+ producer_s = < intptr_t> (stream)
210+ consumer_s = < intptr_t> (stream_ptr)
211+ assert producer_s > 0
212+ # establish stream order
213+ if producer_s != consumer_s:
214+ e = handle_return(driver.cuEventCreate(
215+ driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
216+ handle_return(driver.cuEventRecord(e , producer_s ))
217+ handle_return(driver.cuStreamWaitEvent(consumer_s , e , 0))
218+ handle_return(driver.cuEventDestroy(e ))
219+
220+        cdef intptr_t ptr = int(cai_data["data"][0])
221+ return cls(
222+ ptr = ptr,
223+ device_id = handle_return(
224+ driver.cuPointerGetAttribute(
225+ driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
226+ ptr
227+ )
228+ ),
229+ is_device_accessible = True ,
230+            readonly = cai_data["data"][1],
231+ metadata = cai_data,
232+ exporting_obj = obj,
233+ )
234+
235+ @classmethod
236+ def from_any_interface(cls , obj: object , stream_ptr: int | None = None ) -> StridedMemoryView:
237+ if check_has_dlpack(obj ):
238+ return cls .from_dlpack(obj, stream_ptr)
239+ return cls .from_cuda_array_interface(obj, stream_ptr)
105240
106241 def __dealloc__ (self ):
107242 if self .dl_tensor == NULL :
@@ -121,7 +256,7 @@ cdef class StridedMemoryView:
121256 dlm_tensor.deleter(dlm_tensor)
122257
123258 @property
124- def shape (self ) -> tuple[int]:
259+ def shape (self ) -> tuple[int , ... ]:
125260 if self._shape is None:
126261 if self.exporting_obj is not None:
127262 if self.dl_tensor != NULL:
@@ -136,7 +271,7 @@ cdef class StridedMemoryView:
136271 return self._shape
137272
138273 @property
139- def strides(self ) -> Optional[tuple[int]]:
274+ def strides(self ) -> Optional[tuple[int , ... ]]:
140275 cdef int itemsize
141276 if self._strides_init is False:
142277 if self.exporting_obj is not None:
@@ -206,8 +341,7 @@ cdef bint check_has_dlpack(obj) except*:
206341
207342
208343cdef class _StridedMemoryViewProxy:
209-
210- cdef:
344+ cdef readonly:
211345 object obj
212346 bint has_dlpack
213347
@@ -217,82 +351,11 @@ cdef class _StridedMemoryViewProxy:
217351
218352 cpdef StridedMemoryView view(self , stream_ptr = None ):
219353 if self .has_dlpack:
220- return view_as_dlpack (self .obj, stream_ptr)
354+ return StridedMemoryView.from_dlpack (self .obj, stream_ptr)
221355 else :
222- return view_as_cai (self .obj, stream_ptr)
356+ return StridedMemoryView.from_cuda_array_interface (self .obj, stream_ptr)
223357
224358
225- cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view = None ):
226- cdef int dldevice, device_id
227- cdef bint is_device_accessible, is_readonly
228- is_device_accessible = False
229- dldevice, device_id = obj.__dlpack_device__()
230- if dldevice == _kDLCPU:
231- assert device_id == 0
232- device_id = - 1
233- if stream_ptr is None :
234- raise BufferError(" stream=None is ambiguous with view()" )
235- elif stream_ptr == - 1 :
236- stream_ptr = None
237- elif dldevice == _kDLCUDA:
238- assert device_id >= 0
239- is_device_accessible = True
240- # no need to check other stream values, it's a pass-through
241- if stream_ptr is None :
242- raise BufferError(" stream=None is ambiguous with view()" )
243- elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
244- is_device_accessible = True
245- # just do a pass-through without any checks, as pinned/managed memory can be
246- # accessed on both host and device
247- else :
248- raise BufferError(" device not supported" )
249-
250- cdef object capsule
251- try :
252- capsule = obj.__dlpack__(
253- stream = int (stream_ptr) if stream_ptr else None ,
254- max_version = (DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
255- except TypeError :
256- capsule = obj.__dlpack__(
257- stream = int (stream_ptr) if stream_ptr else None )
258-
259- cdef void * data = NULL
260- cdef DLTensor* dl_tensor
261- cdef DLManagedTensorVersioned* dlm_tensor_ver
262- cdef DLManagedTensor* dlm_tensor
263- cdef const char * used_name
264- if cpython.PyCapsule_IsValid(
265- capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
266- data = cpython.PyCapsule_GetPointer(
267- capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
268- dlm_tensor_ver = < DLManagedTensorVersioned* > data
269- dl_tensor = & dlm_tensor_ver.dl_tensor
270- is_readonly = bool ((dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0 )
271- used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
272- elif cpython.PyCapsule_IsValid(
273- capsule, DLPACK_TENSOR_UNUSED_NAME):
274- data = cpython.PyCapsule_GetPointer(
275- capsule, DLPACK_TENSOR_UNUSED_NAME)
276- dlm_tensor = < DLManagedTensor* > data
277- dl_tensor = & dlm_tensor.dl_tensor
278- is_readonly = False
279- used_name = DLPACK_TENSOR_USED_NAME
280- else :
281- assert False
282-
283- cpython.PyCapsule_SetName(capsule, used_name)
284-
285- cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
286- buf.dl_tensor = dl_tensor
287- buf.metadata = capsule
288- buf.ptr = < intptr_t> (dl_tensor.data)
289- buf.device_id = device_id
290- buf.is_device_accessible = is_device_accessible
291- buf.readonly = is_readonly
292- buf.exporting_obj = obj
293-
294- return buf
295-
296359
297360cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
298361 cdef int bits = dtype.bits
@@ -354,46 +417,6 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
354417 return numpy.dtype(np_dtype)
355418
356419
357- # Also generate for Python so we can test this code path
358- cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view = None ):
359- cdef dict cai_data = obj.__cuda_array_interface__
360- if cai_data[" version" ] < 3 :
361- raise BufferError(" only CUDA Array Interface v3 or above is supported" )
362- if cai_data.get(" mask" ) is not None :
363- raise BufferError(" mask is not supported" )
364- if stream_ptr is None :
365- raise BufferError(" stream=None is ambiguous with view()" )
366-
367- cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
368- buf.exporting_obj = obj
369- buf.metadata = cai_data
370- buf.dl_tensor = NULL
371- buf.ptr, buf.readonly = cai_data[" data" ]
372- buf.is_device_accessible = True
373- buf.device_id = handle_return(
374- driver.cuPointerGetAttribute(
375- driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
376- buf.ptr))
377-
378- cdef intptr_t producer_s, consumer_s
379- stream_ptr = int (stream_ptr)
380- if stream_ptr != - 1 :
381- stream = cai_data.get(" stream" )
382- if stream is not None :
383- producer_s = < intptr_t> (stream)
384- consumer_s = < intptr_t> (stream_ptr)
385- assert producer_s > 0
386- # establish stream order
387- if producer_s != consumer_s:
388- e = handle_return(driver.cuEventCreate(
389- driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
390- handle_return(driver.cuEventRecord(e, producer_s))
391- handle_return(driver.cuStreamWaitEvent(consumer_s, e, 0 ))
392- handle_return(driver.cuEventDestroy(e))
393-
394- return buf
395-
396-
397420def args_viewable_as_strided_memory (tuple arg_indices ):
398421 """
399422 Decorator to create proxy objects to :obj:`StridedMemoryView` for the
0 commit comments