33# SPDX-License-Identifier: Apache-2.0
44
55from enum import IntEnum
6- from ._layout cimport StridedLayout
76
87
98cdef void pycapsule_deleter(object capsule) noexcept:
@@ -34,16 +33,6 @@ cdef void deleter(DLManagedTensor* tensor) noexcept with gil:
3433 stdlib.free(tensor)
3534
3635
37- cdef void cleanup(DLManagedTensor* tensor) noexcept with gil:
38- if tensor:
39- if tensor.dl_tensor.shape:
40- stdlib.free(tensor.dl_tensor.shape)
41- if tensor.manager_ctx:
42- cpython.Py_DECREF(< object > tensor.manager_ctx)
43- tensor.manager_ctx = NULL
44- stdlib.free(tensor)
45-
46-
4736cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
4837 stdlib.free(tensor.dl_tensor.shape)
4938 if tensor.manager_ctx:
@@ -52,17 +41,43 @@ cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
5241 stdlib.free(tensor)
5342
5443
55- cdef void cleanup_versioned(DLManagedTensorVersioned* tensor) noexcept with gil:
56- if tensor:
57- if tensor.dl_tensor.shape:
58- stdlib.free(tensor.dl_tensor.shape)
59- if tensor.manager_ctx:
60- cpython.Py_DECREF(< object > tensor.manager_ctx)
61- tensor.manager_ctx = NULL
62- stdlib.free(tensor)
44+ cpdef object make_py_capsule(object buf, bint versioned):
45+ cdef DLManagedTensor* dlm_tensor
46+ cdef DLManagedTensorVersioned* dlm_tensor_ver
47+ cdef DLTensor* dl_tensor
48+ cdef void * tensor_ptr
49+ cdef const char * capsule_name
6350
51+ if versioned:
52+ dlm_tensor_ver = < DLManagedTensorVersioned* > (
53+ stdlib.malloc(sizeof(DLManagedTensorVersioned)))
54+ dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
55+ dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
56+ dlm_tensor_ver.manager_ctx = < void * > buf
57+ dlm_tensor_ver.deleter = versioned_deleter
58+ dlm_tensor_ver.flags = 0
59+ dl_tensor = & dlm_tensor_ver.dl_tensor
60+ tensor_ptr = dlm_tensor_ver
61+ capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
62+ else :
63+ dlm_tensor = < DLManagedTensor* > (
64+ stdlib.malloc(sizeof(DLManagedTensor)))
65+ dl_tensor = & dlm_tensor.dl_tensor
66+ dlm_tensor.manager_ctx = < void * > buf
67+ dlm_tensor.deleter = deleter
68+ tensor_ptr = dlm_tensor
69+ capsule_name = DLPACK_TENSOR_UNUSED_NAME
70+
71+ dl_tensor.data = < void * >< intptr_t> (int (buf.handle))
72+ dl_tensor.ndim = 1
73+ cdef int64_t* shape_strides = \
74+ < int64_t* > stdlib.malloc(sizeof(int64_t) * 2 )
75+ shape_strides[0 ] = < int64_t> buf.size
76+ shape_strides[1 ] = 1 # redundant
77+ dl_tensor.shape = shape_strides
78+ dl_tensor.strides = NULL
79+ dl_tensor.byte_offset = 0
6480
65- cdef inline int _setup_dl_tensor_device(DLTensor* dl_tensor, object buf) except - 1 :
6681 cdef DLDevice* device = & dl_tensor.device
6782 # buf should be a Buffer instance
6883 if buf.is_device_accessible and not buf.is_host_accessible:
@@ -76,103 +91,14 @@ cdef inline int _setup_dl_tensor_device(DLTensor* dl_tensor, object buf) except
7691 device.device_id = 0
7792 else : # not buf.is_device_accessible and not buf.is_host_accessible
7893 raise BufferError(" invalid buffer" )
79- return 0
80-
81-
82- cdef inline int _setup_dl_tensor_layout(DLTensor* dl_tensor, object buf, StridedLayout layout) except - 1 :
83- cdef int64_t* shape_strides = NULL
84- cdef int ndim
85- if layout is None :
86- dl_tensor.ndim = 1
87- shape_strides = < int64_t* > stdlib.malloc(sizeof(int64_t) * 2 )
88- shape_strides[0 ] = < int64_t> buf.size
89- shape_strides[1 ] = 1 # redundant
90- dl_tensor.shape = shape_strides
91- dl_tensor.strides = NULL
92- dl_tensor.byte_offset = 0
93- else :
94- ndim = layout.ndim
95- dl_tensor.ndim = ndim
96- shape_strides = < int64_t* > stdlib.malloc(sizeof(int64_t) * ndim * 2 )
97- dl_tensor.shape = shape_strides
98- for i in range (ndim):
99- shape_strides[i] = layout.base.shape[i]
100- if layout.base.strides == NULL :
101- dl_tensor.strides = NULL
102- else :
103- dl_tensor.strides = shape_strides + ndim
104- for i in range (ndim):
105- dl_tensor.strides[i] = layout.base.strides[i]
106- dl_tensor.byte_offset = 0
107- return 0
108-
109-
110- cdef inline int _setup_dl_tensor_dtype(DLTensor* dl_tensor, object dtype) except - 1 :
111- cdef DLDataType* dl_dtype = & dl_tensor.dtype
112- if dtype is None :
113- dl_dtype.code = < uint8_t> kDLInt
114- dl_dtype.lanes = < uint16_t> 1
115- dl_dtype.bits = < uint8_t> 8
116- return 0
117- cdef uint8_t code
118- cdef uint8_t bits
119- cdef uint16_t lanes
120- code, bits, lanes = dtype
121- dl_dtype.code = code
122- dl_dtype.bits = bits
123- dl_dtype.lanes = lanes
124- return 0
125-
126-
127- cpdef object make_py_capsule(
128- object buf,
129- bint versioned,
130- intptr_t data_ptr,
131- StridedLayout layout = None ,
132- object dtype = None ,
133- ):
134- cdef DLManagedTensor* dlm_tensor = NULL
135- cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
136- cdef DLTensor* dl_tensor = NULL
137- cdef void * tensor_ptr
138- cdef const char * capsule_name
13994
140- try :
141- if versioned:
142- dlm_tensor_ver = < DLManagedTensorVersioned* > (
143- stdlib.malloc(sizeof(DLManagedTensorVersioned)))
144- dlm_tensor_ver.dl_tensor.shape = NULL
145- dl_tensor = & dlm_tensor_ver.dl_tensor
146- dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
147- dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
148- cpython.Py_INCREF(buf)
149- dlm_tensor_ver.manager_ctx = < void * > buf
150- dlm_tensor_ver.deleter = versioned_deleter
151- dlm_tensor_ver.flags = 0
152- tensor_ptr = dlm_tensor_ver
153- capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
154- else :
155- dlm_tensor = < DLManagedTensor* > (
156- stdlib.malloc(sizeof(DLManagedTensor)))
157- dlm_tensor.dl_tensor.shape = NULL
158- dl_tensor = & dlm_tensor.dl_tensor
159- cpython.Py_INCREF(buf)
160- dlm_tensor.manager_ctx = < void * > buf
161- dlm_tensor.deleter = deleter
162- tensor_ptr = dlm_tensor
163- capsule_name = DLPACK_TENSOR_UNUSED_NAME
164-
165- dl_tensor.data = < void * > data_ptr
166-
167- _setup_dl_tensor_device(dl_tensor, buf)
168- _setup_dl_tensor_layout(dl_tensor, buf, layout)
169- _setup_dl_tensor_dtype(dl_tensor, dtype)
170-
171- return cpython.PyCapsule_New(tensor_ptr, capsule_name, pycapsule_deleter)
172- except :
173- cleanup(dlm_tensor)
174- cleanup_versioned(dlm_tensor_ver)
175- raise
95+ cdef DLDataType* dtype = & dl_tensor.dtype
96+ dtype.code = < uint8_t> kDLInt
97+ dtype.lanes = < uint16_t> 1
98+ dtype.bits = < uint8_t> 8
99+
100+ cpython.Py_INCREF(buf)
101+ return cpython.PyCapsule_New(tensor_ptr, capsule_name, pycapsule_deleter)
176102
177103
178104class DLDeviceType (IntEnum ):
0 commit comments