@@ -88,155 +88,20 @@ cdef class StridedMemoryView:
8888 cdef DLTensor * dl_tensor
8989
9090 # Memoized properties
91- cdef:
92- tuple _shape
93- tuple _strides
94- # a `None` value for _strides has defined meaning in dlpack and
95- # the cuda array interface, meaning C order, contiguous.
96- #
97- # this flag helps prevent unnecessary recomputation of _strides
98- bint _strides_init
99- object _dtype
100-
101- def __init__ (
102- self ,
103- *,
104- ptr: intptr_t ,
105- device_id: int ,
106- is_device_accessible: bint ,
107- readonly: bint ,
108- metadata: object ,
109- exporting_obj: object ,
110- dl_tensor: intptr_t = 0 ,
111- ) -> None:
112- self.ptr = ptr
113- self.device_id = device_id
114- self.is_device_accessible = is_device_accessible
115- self.readonly = readonly
116- self.metadata = metadata
117- self.exporting_obj = exporting_obj
118- self.dl_tensor = < DLTensor* > dl_tensor
119- self._shape = None
120- self._strides = None
121- self._strides_init = False
122- self._dtype = None
123-
124- @classmethod
125- def from_dlpack(cls , obj: object , stream_ptr: int | None = None ) -> StridedMemoryView:
126- cdef int dldevice , device_id
127- cdef bint is_device_accessible , is_readonly
128- is_device_accessible = False
129- dldevice , device_id = obj.__dlpack_device__()
130- if dldevice == _kDLCPU:
131- assert device_id == 0
132- device_id = - 1
133- if stream_ptr is None:
134- raise BufferError("stream = None is ambiguous with view()" )
135- elif stream_ptr == -1:
136- stream_ptr = None
137- elif dldevice == _kDLCUDA:
138- assert device_id >= 0
139- is_device_accessible = True
140- # no need to check other stream values , it's a pass-through
141- if stream_ptr is None:
142- raise BufferError("stream = None is ambiguous with view()" )
143- elif dldevice in (_kDLCUDAHost , _kDLCUDAManaged ):
144- is_device_accessible = True
145- # just do a pass-through without any checks, as pinned/managed memory can be
146- # accessed on both host and device
147- else :
148- raise BufferError(" device not supported" )
149-
150- cdef object capsule
151- try :
152- capsule = obj.__dlpack__(
153- stream = int (stream_ptr) if stream_ptr else None ,
154- max_version = (DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
155- except TypeError :
156- capsule = obj.__dlpack__(
157- stream = int (stream_ptr) if stream_ptr else None )
158-
159- cdef void * data = NULL
160- cdef DLTensor* dl_tensor
161- cdef DLManagedTensorVersioned* dlm_tensor_ver
162- cdef DLManagedTensor* dlm_tensor
163- cdef const char * used_name
164- if cpython.PyCapsule_IsValid(
165- capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
166- data = cpython.PyCapsule_GetPointer(
167- capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
168- dlm_tensor_ver = < DLManagedTensorVersioned* > data
169- dl_tensor = & dlm_tensor_ver.dl_tensor
170- is_readonly = (dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0
171- used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
91+ cdef tuple _shape
92+ cdef tuple _strides
93+ cdef bint _strides_init # Has the strides tuple been init'ed?
94+ cdef object _dtype
95+
96+ def __init__ (self , obj = None , stream_ptr = None ):
97+ if obj is not None :
98+ # populate self's attributes
99+ if check_has_dlpack(obj):
100+ view_as_dlpack(obj, stream_ptr, self )
101+ else :
102+ view_as_cai(obj, stream_ptr, self )
172103 else :
173- assert cpython.PyCapsule_IsValid(
174- capsule, DLPACK_TENSOR_UNUSED_NAME)
175- data = cpython.PyCapsule_GetPointer(
176- capsule, DLPACK_TENSOR_UNUSED_NAME)
177- dlm_tensor = < DLManagedTensor* > data
178- dl_tensor = & dlm_tensor.dl_tensor
179- is_readonly = False
180- used_name = DLPACK_TENSOR_USED_NAME
181-
182- cpython.PyCapsule_SetName(capsule, used_name)
183-
184- return cls (
185- ptr = < intptr_t> dl_tensor.data,
186- device_id = int (device_id),
187- is_device_accessible = is_device_accessible,
188- readonly = is_readonly,
189- metadata = capsule,
190- exporting_obj = obj,
191- dl_tensor = < intptr_t> dl_tensor,
192- )
193-
194- @classmethod
195- def from_cuda_array_interface (cls , obj: object , stream_ptr: int | None = None ) -> StridedMemoryView:
196- cdef dict cai_data = obj.__cuda_array_interface__
197- if cai_data["version"] < 3:
198- raise BufferError("only CUDA Array Interface v3 or above is supported")
199- if cai_data.get("mask") is not None:
200- raise BufferError("mask is not supported")
201- if stream_ptr is None:
202- raise BufferError("stream = None is ambiguous with view()" )
203-
204- cdef intptr_t producer_s , consumer_s
205- stream_ptr = int (stream_ptr)
206- if stream_ptr != -1:
207- stream = cai_data.get(" stream" )
208- if stream is not None:
209- producer_s = < intptr_t> (stream)
210- consumer_s = < intptr_t> (stream_ptr)
211- assert producer_s > 0
212- # establish stream order
213- if producer_s != consumer_s:
214- e = handle_return(driver.cuEventCreate(
215- driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
216- handle_return(driver.cuEventRecord(e , producer_s ))
217- handle_return(driver.cuStreamWaitEvent(consumer_s , e , 0))
218- handle_return(driver.cuEventDestroy(e ))
219-
220- cdef intptr_t ptr = int (cai_data[" data" ][0 ])
221- return cls(
222- ptr = ptr,
223- device_id = handle_return(
224- driver.cuPointerGetAttribute(
225- driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
226- ptr
227- )
228- ),
229- is_device_accessible = True ,
230- readonly = cai_data[" data" ][1 ],
231- metadata = cai_data,
232- exporting_obj = obj,
233- )
234-
235- @classmethod
236- def from_any_interface(cls , obj: object , stream_ptr: int | None = None ) -> StridedMemoryView:
237- if check_has_dlpack(obj ):
238- return cls .from_dlpack(obj, stream_ptr)
239- return cls .from_cuda_array_interface(obj, stream_ptr)
104+ pass
240105
241106 def __dealloc__ (self ):
242107 if self .dl_tensor == NULL :
@@ -256,7 +121,7 @@ cdef class StridedMemoryView:
256121 dlm_tensor.deleter(dlm_tensor)
257122
258123 @property
259- def shape (self ) -> tuple[int , ... ]:
124+ def shape (self ) -> tuple[int]:
260125 if self._shape is None:
261126 if self.exporting_obj is not None:
262127 if self.dl_tensor != NULL:
@@ -271,7 +136,7 @@ cdef class StridedMemoryView:
271136 return self._shape
272137
273138 @property
274- def strides(self ) -> Optional[tuple[int , ... ]]:
139+ def strides(self ) -> Optional[tuple[int]]:
275140 cdef int itemsize
276141 if self._strides_init is False:
277142 if self.exporting_obj is not None:
@@ -341,7 +206,8 @@ cdef bint check_has_dlpack(obj) except*:
341206
342207
343208cdef class _StridedMemoryViewProxy:
344- cdef readonly:
209+
210+ cdef:
345211 object obj
346212 bint has_dlpack
347213
@@ -351,11 +217,82 @@ cdef class _StridedMemoryViewProxy:
351217
352218 cpdef StridedMemoryView view(self , stream_ptr = None ):
353219 if self .has_dlpack:
354- return StridedMemoryView.from_dlpack (self .obj, stream_ptr)
220+ return view_as_dlpack (self .obj, stream_ptr)
355221 else :
356- return StridedMemoryView.from_cuda_array_interface (self .obj, stream_ptr)
222+ return view_as_cai (self .obj, stream_ptr)
357223
358224
cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
    """Fill a StridedMemoryView from an object exporting the DLPack protocol.

    Parameters
    ----------
    obj
        Object providing ``__dlpack_device__()`` and ``__dlpack__()``.
    stream_ptr
        Consumer stream handle. ``None`` is rejected for CPU and CUDA
        devices; ``-1`` on a CPU device is translated to ``None`` before
        being handed to ``obj.__dlpack__``.
    view
        Optional existing StridedMemoryView to populate in place. When
        ``None`` a new instance is created.

    Returns
    -------
    StridedMemoryView
        ``view`` (or the newly created instance) with ``dl_tensor``,
        ``metadata``, ``ptr``, ``device_id``, ``is_device_accessible``,
        ``readonly`` and ``exporting_obj`` set.

    Raises
    ------
    BufferError
        If the stream is ``None`` for a CPU/CUDA device, the device type is
        not one of CPU / CUDA / CUDA host / CUDA managed, or the capsule
        returned by ``obj.__dlpack__`` carries neither of the expected
        "unused" DLPack names.
    """
    cdef int dldevice, device_id
    cdef bint is_device_accessible, is_readonly
    cdef object capsule
    cdef void* data = NULL
    cdef DLTensor* dl_tensor
    cdef DLManagedTensorVersioned* dlm_tensor_ver
    cdef DLManagedTensor* dlm_tensor
    cdef const char* used_name
    cdef StridedMemoryView buf

    is_device_accessible = False
    dldevice, device_id = obj.__dlpack_device__()
    if dldevice == _kDLCPU:
        assert device_id == 0
        # host memory is reported with device_id -1 on the view
        device_id = -1
        if stream_ptr is None:
            raise BufferError(" stream=None is ambiguous with view()")
        elif stream_ptr == -1:
            stream_ptr = None
    elif dldevice == _kDLCUDA:
        assert device_id >= 0
        is_device_accessible = True
        # no need to check other stream values, it's a pass-through
        if stream_ptr is None:
            raise BufferError(" stream=None is ambiguous with view()")
    elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
        is_device_accessible = True
        # just do a pass-through without any checks, as pinned/managed memory can be
        # accessed on both host and device
    else:
        raise BufferError(" device not supported")

    # Any falsy stream_ptr (None or 0) is forwarded to the producer as None.
    consumer_stream = int(stream_ptr) if stream_ptr else None
    try:
        capsule = obj.__dlpack__(
            stream=consumer_stream,
            max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
    except TypeError:
        # Retry without max_version for producers whose __dlpack__ does not
        # accept that keyword.
        capsule = obj.__dlpack__(stream=consumer_stream)

    if cpython.PyCapsule_IsValid(
            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
        data = cpython.PyCapsule_GetPointer(
            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
        dlm_tensor_ver = <DLManagedTensorVersioned*>data
        dl_tensor = &dlm_tensor_ver.dl_tensor
        is_readonly = (dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0
        used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
    elif cpython.PyCapsule_IsValid(
            capsule, DLPACK_TENSOR_UNUSED_NAME):
        data = cpython.PyCapsule_GetPointer(
            capsule, DLPACK_TENSOR_UNUSED_NAME)
        dlm_tensor = <DLManagedTensor*>data
        dl_tensor = &dlm_tensor.dl_tensor
        # the unversioned struct is always treated as writable here
        is_readonly = False
        used_name = DLPACK_TENSOR_USED_NAME
    else:
        # Raise explicitly instead of `assert False`: assertions can be
        # disabled, which would let uninitialised used_name / dl_tensor
        # pointers reach the calls below.
        raise BufferError("unsupported or already-consumed DLPack capsule")

    # Rename the capsule to its "used" name now that its pointer is taken.
    cpython.PyCapsule_SetName(capsule, used_name)

    buf = StridedMemoryView() if view is None else view
    buf.dl_tensor = dl_tensor
    # buf stores the capsule as metadata alongside the raw dl_tensor pointer
    buf.metadata = capsule
    buf.ptr = <intptr_t>(dl_tensor.data)
    buf.device_id = device_id
    buf.is_device_accessible = is_device_accessible
    buf.readonly = is_readonly
    buf.exporting_obj = obj

    return buf

295+
359296
360297cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
361298 cdef int bits = dtype.bits
@@ -417,6 +354,46 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
417354 return numpy.dtype(np_dtype)
418355
419356
# Also generate for Python so we can test this code path
cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
    """Fill a StridedMemoryView from a ``__cuda_array_interface__`` exporter.

    Parameters
    ----------
    obj
        Object exposing a ``__cuda_array_interface__`` dict.
    stream_ptr
        Consumer stream handle; must not be ``None``. ``-1`` skips stream
        ordering entirely.
    view
        Optional existing StridedMemoryView to populate in place. When
        ``None`` a new instance is created.

    Returns
    -------
    StridedMemoryView
        ``view`` (or the newly created instance) with ``exporting_obj``,
        ``metadata``, ``dl_tensor`` (NULL), ``ptr``, ``readonly``,
        ``is_device_accessible`` and ``device_id`` set.

    Raises
    ------
    BufferError
        If the interface version is below 3, a mask is present, or
        ``stream_ptr`` is ``None``.
    """
    cdef dict cai_data = obj.__cuda_array_interface__
    cdef StridedMemoryView buf
    cdef intptr_t producer_s, consumer_s

    # Keys follow the CUDA Array Interface spec exactly
    # ("version", "mask", "data", "stream").
    if cai_data["version"] < 3:
        raise BufferError(" only CUDA Array Interface v3 or above is supported")
    if cai_data.get("mask") is not None:
        raise BufferError(" mask is not supported")
    if stream_ptr is None:
        raise BufferError(" stream=None is ambiguous with view()")

    buf = StridedMemoryView() if view is None else view
    buf.exporting_obj = obj
    buf.metadata = cai_data
    buf.dl_tensor = NULL
    # "data" is a (pointer, read_only) pair
    buf.ptr, buf.readonly = cai_data["data"]
    buf.is_device_accessible = True
    buf.device_id = handle_return(
        driver.cuPointerGetAttribute(
            driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
            buf.ptr))

    stream_ptr = int(stream_ptr)
    if stream_ptr != -1:
        stream = cai_data.get("stream")
        if stream is not None:
            producer_s = <intptr_t>(stream)
            consumer_s = <intptr_t>(stream_ptr)
            assert producer_s > 0
            # establish stream order
            if producer_s != consumer_s:
                e = handle_return(driver.cuEventCreate(
                    driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
                try:
                    handle_return(driver.cuEventRecord(e, producer_s))
                    handle_return(driver.cuStreamWaitEvent(consumer_s, e, 0))
                finally:
                    # Destroy the event even if recording/waiting fails
                    # (presumably handle_return raises on a driver error;
                    # verify against its definition).
                    handle_return(driver.cuEventDestroy(e))

    return buf

395+
396+
420397def args_viewable_as_strided_memory (tuple arg_indices ):
421398 """
422399 Decorator to create proxy objects to :obj:`StridedMemoryView` for the
0 commit comments