33# SPDX-License-Identifier: Apache-2.0
44
55from enum import IntEnum
6+ from ._layout cimport StridedLayout
67
78
89cdef void pycapsule_deleter(object capsule) noexcept:
@@ -33,6 +34,16 @@ cdef void deleter(DLManagedTensor* tensor) noexcept with gil:
3334 stdlib.free(tensor)
3435
3536
cdef void cleanup(DLManagedTensor* tensor) noexcept with gil:
    # Tear down a DLManagedTensor that never made it into a capsule.
    #
    # Frees the shape buffer, drops the reference held through
    # ``manager_ctx`` and finally frees the struct itself. A NULL ``tensor``
    # is accepted and ignored.
    cdef void* owner
    cdef int64_t* shape_buf

    if tensor == NULL:
        return

    shape_buf = tensor.dl_tensor.shape
    if shape_buf != NULL:
        stdlib.free(shape_buf)

    owner = tensor.manager_ctx
    if owner != NULL:
        # Give back the reference taken when the owner was stored.
        cpython.Py_DECREF(<object>owner)
        tensor.manager_ctx = NULL

    stdlib.free(tensor)
45+
46+
3647cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
3748 stdlib.free(tensor.dl_tensor.shape)
3849 if tensor.manager_ctx:
@@ -41,43 +52,17 @@ cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
4152 stdlib.free(tensor)
4253
4354
44- cpdef object make_py_capsule(object buf, bint versioned):
45- cdef DLManagedTensor* dlm_tensor
46- cdef DLManagedTensorVersioned* dlm_tensor_ver
47- cdef DLTensor* dl_tensor
48- cdef void * tensor_ptr
49- cdef const char * capsule_name
cdef void cleanup_versioned(DLManagedTensorVersioned* tensor) noexcept with gil:
    # Tear down a DLManagedTensorVersioned that never made it into a capsule.
    #
    # Frees the shape buffer, drops the reference held through
    # ``manager_ctx`` and finally frees the struct itself. A NULL ``tensor``
    # is accepted and ignored.
    cdef void* owner
    cdef int64_t* shape_buf

    if tensor == NULL:
        return

    shape_buf = tensor.dl_tensor.shape
    if shape_buf != NULL:
        stdlib.free(shape_buf)

    owner = tensor.manager_ctx
    if owner != NULL:
        # Give back the reference taken when the owner was stored.
        cpython.Py_DECREF(<object>owner)
        tensor.manager_ctx = NULL

    stdlib.free(tensor)
5063
51- if versioned:
52- dlm_tensor_ver = < DLManagedTensorVersioned* > (
53- stdlib.malloc(sizeof(DLManagedTensorVersioned)))
54- dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
55- dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
56- dlm_tensor_ver.manager_ctx = < void * > buf
57- dlm_tensor_ver.deleter = versioned_deleter
58- dlm_tensor_ver.flags = 0
59- dl_tensor = & dlm_tensor_ver.dl_tensor
60- tensor_ptr = dlm_tensor_ver
61- capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
62- else :
63- dlm_tensor = < DLManagedTensor* > (
64- stdlib.malloc(sizeof(DLManagedTensor)))
65- dl_tensor = & dlm_tensor.dl_tensor
66- dlm_tensor.manager_ctx = < void * > buf
67- dlm_tensor.deleter = deleter
68- tensor_ptr = dlm_tensor
69- capsule_name = DLPACK_TENSOR_UNUSED_NAME
70-
71- dl_tensor.data = < void * >< intptr_t> (int (buf.handle))
72- dl_tensor.ndim = 1
73- cdef int64_t* shape_strides = \
74- < int64_t* > stdlib.malloc(sizeof(int64_t) * 2 )
75- shape_strides[0 ] = < int64_t> buf.size
76- shape_strides[1 ] = 1 # redundant
77- dl_tensor.shape = shape_strides
78- dl_tensor.strides = NULL
79- dl_tensor.byte_offset = 0
8064
65+ cdef inline int _setup_dl_tensor_device(DLTensor* dl_tensor, object buf) except - 1 :
8166 cdef DLDevice* device = & dl_tensor.device
8267 # buf should be a Buffer instance
8368 if buf.is_device_accessible and not buf.is_host_accessible:
@@ -91,14 +76,103 @@ cpdef object make_py_capsule(object buf, bint versioned):
9176 device.device_id = 0
9277 else : # not buf.is_device_accessible and not buf.is_host_accessible
9378 raise BufferError(" invalid buffer" )
79+ return 0
80+
81+
cdef inline int _setup_dl_tensor_layout(DLTensor* dl_tensor, object buf, StridedLayout layout) except -1:
    # Populate ``ndim``, ``shape``, ``strides`` and ``byte_offset`` of
    # ``dl_tensor``.
    #
    # With ``layout is None`` the tensor is described as one-dimensional with
    # ``buf.size`` elements and no explicit strides. Otherwise shape (and
    # strides, when the layout has them) are copied from ``layout.base``.
    #
    # Shape and strides live in ONE heap allocation: ``dl_tensor.shape`` points
    # at its start and ``dl_tensor.strides`` (when set) at ``shape + ndim``, so
    # freeing ``dl_tensor.shape`` releases both.
    #
    # Returns 0 on success. Raises MemoryError if the allocation fails, and
    # propagates whatever reading/converting ``buf.size`` raises.
    cdef int64_t* shape_strides = NULL
    cdef int64_t n_elements
    cdef size_t n_slots
    cdef int ndim
    cdef int i

    dl_tensor.byte_offset = 0
    dl_tensor.strides = NULL

    if layout is None:
        # Read the Python attribute BEFORE allocating: it can raise, and an
        # exception after malloc but before the pointer is stored in
        # ``dl_tensor.shape`` would leak the buffer.
        n_elements = <int64_t>buf.size
        shape_strides = <int64_t*>stdlib.malloc(sizeof(int64_t) * 2)
        if shape_strides == NULL:
            raise MemoryError()
        dl_tensor.shape = shape_strides
        dl_tensor.ndim = 1
        shape_strides[0] = n_elements
        shape_strides[1] = 1  # redundant
        return 0

    ndim = layout.ndim
    dl_tensor.ndim = ndim
    # malloc(0) may legally return NULL; always request at least one slot so a
    # NULL result unambiguously means "out of memory".
    n_slots = <size_t>ndim * 2 if ndim > 0 else 1
    shape_strides = <int64_t*>stdlib.malloc(sizeof(int64_t) * n_slots)
    if shape_strides == NULL:
        raise MemoryError()
    # Hand ownership to the tensor immediately so cleanup code can free it.
    dl_tensor.shape = shape_strides
    for i in range(ndim):
        shape_strides[i] = layout.base.shape[i]
    if layout.base.strides != NULL:
        # Strides occupy the second half of the shared allocation.
        dl_tensor.strides = shape_strides + ndim
        for i in range(ndim):
            dl_tensor.strides[i] = layout.base.strides[i]
    return 0
108+
109+
cdef inline int _setup_dl_tensor_dtype(DLTensor* dl_tensor, object dtype) except -1:
    # Write the element type of ``dl_tensor``.
    #
    # ``dtype`` is either None, in which case the tensor is typed as
    # kDLInt / 8 bits / 1 lane, or an iterable unpacking to
    # ``(code, bits, lanes)``. Unpacking or integer-conversion errors
    # propagate to the caller before anything is written.
    #
    # Returns 0 on success.
    cdef DLDataType* target = &dl_tensor.dtype
    cdef uint8_t type_code = <uint8_t>kDLInt
    cdef uint8_t n_bits = <uint8_t>8
    cdef uint16_t n_lanes = <uint16_t>1

    if dtype is not None:
        type_code, n_bits, n_lanes = dtype

    target.code = type_code
    target.lanes = n_lanes
    target.bits = n_bits
    return 0
125+
126+
cpdef object make_py_capsule(
    object buf,
    intptr_t data_ptr,
    bint versioned,
    StridedLayout layout=None,
    object dtype=None,
):
    """Build a DLPack capsule describing memory owned by ``buf``.

    Parameters
    ----------
    buf
        Owner of the memory. A strong reference to it is stored in the
        managed tensor's ``manager_ctx`` so it outlives the capsule.
    data_ptr
        Address stored in the tensor's ``data`` field.
    versioned
        If true, a ``DLManagedTensorVersioned`` is produced (stamped with
        ``DLPACK_MAJOR_VERSION`` / ``DLPACK_MINOR_VERSION``); otherwise a
        plain ``DLManagedTensor``.
    layout
        Optional shape/strides description. When omitted, the tensor is
        one-dimensional with ``buf.size`` elements.
    dtype
        Optional ``(code, bits, lanes)`` triple. When omitted, the element
        type is 8-bit ``kDLInt`` with one lane.

    Returns
    -------
    A ``PyCapsule`` wrapping the managed tensor.

    Raises
    ------
    MemoryError
        If a native allocation fails.
    BufferError
        If ``buf`` is neither host- nor device-accessible.

    Any exception raised while filling in the tensor releases the partially
    built struct and the reference on ``buf`` before propagating.
    """
    cdef DLManagedTensor* dlm_tensor = NULL
    cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
    cdef DLTensor* dl_tensor = NULL
    cdef void* tensor_ptr
    cdef const char* capsule_name

    try:
        if versioned:
            dlm_tensor_ver = <DLManagedTensorVersioned*>(
                stdlib.malloc(sizeof(DLManagedTensorVersioned)))
            if dlm_tensor_ver == NULL:
                # Nothing has been acquired yet; the cleanup calls below are
                # no-ops on NULL pointers.
                raise MemoryError()
            # Put the fields inspected by cleanup into a known state before
            # anything that can raise.
            dlm_tensor_ver.dl_tensor.shape = NULL
            dlm_tensor_ver.manager_ctx = NULL
            dl_tensor = &dlm_tensor_ver.dl_tensor
            dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
            dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
            cpython.Py_INCREF(buf)
            dlm_tensor_ver.manager_ctx = <void*>buf
            dlm_tensor_ver.deleter = versioned_deleter
            dlm_tensor_ver.flags = 0
            tensor_ptr = dlm_tensor_ver
            capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
        else:
            dlm_tensor = <DLManagedTensor*>(
                stdlib.malloc(sizeof(DLManagedTensor)))
            if dlm_tensor == NULL:
                raise MemoryError()
            dlm_tensor.dl_tensor.shape = NULL
            dlm_tensor.manager_ctx = NULL
            dl_tensor = &dlm_tensor.dl_tensor
            cpython.Py_INCREF(buf)
            dlm_tensor.manager_ctx = <void*>buf
            dlm_tensor.deleter = deleter
            tensor_ptr = dlm_tensor
            capsule_name = DLPACK_TENSOR_UNUSED_NAME

        dl_tensor.data = <void*>data_ptr

        _setup_dl_tensor_device(dl_tensor, buf)
        _setup_dl_tensor_layout(dl_tensor, buf, layout)
        _setup_dl_tensor_dtype(dl_tensor, dtype)

        return cpython.PyCapsule_New(tensor_ptr, capsule_name, pycapsule_deleter)
    except:
        # At most one of the two pointers is non-NULL; each cleanup routine
        # ignores NULL, so both can be called unconditionally.
        cleanup(dlm_tensor)
        cleanup_versioned(dlm_tensor_ver)
        raise
102176
103177
104178class DLDeviceType (IntEnum ):
0 commit comments