Skip to content

Commit df6c909

Browse files
committed
Support wrapping ptr in Buffer, create SMV from buffer and layout, dlpack export
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
1 parent 68aaee8 commit df6c909

File tree

6 files changed

+592
-167
lines changed

6 files changed

+592
-167
lines changed

cuda_core/cuda/core/experimental/_dlpack.pyx

Lines changed: 116 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from enum import IntEnum
6+
from ._layout cimport StridedLayout
67

78

89
cdef void pycapsule_deleter(object capsule) noexcept:
@@ -33,6 +34,16 @@ cdef void deleter(DLManagedTensor* tensor) noexcept with gil:
3334
stdlib.free(tensor)
3435

3536

37+
cdef void cleanup(DLManagedTensor* tensor) noexcept with gil:
    # Tear down a DLManagedTensor that was allocated with stdlib.malloc.
    # A NULL pointer is accepted and ignored.
    if tensor == NULL:
        return

    # Release the shape array, if one was attached.
    if tensor.dl_tensor.shape != NULL:
        stdlib.free(tensor.dl_tensor.shape)

    # Drop the reference held on the owning Python object and clear the
    # slot so it cannot be released a second time.
    if tensor.manager_ctx != NULL:
        cpython.Py_DECREF(<object>tensor.manager_ctx)
        tensor.manager_ctx = NULL

    # Finally release the struct itself.
    stdlib.free(tensor)
45+
46+
3647
cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
3748
stdlib.free(tensor.dl_tensor.shape)
3849
if tensor.manager_ctx:
@@ -41,43 +52,17 @@ cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
4152
stdlib.free(tensor)
4253

4354

44-
cpdef object make_py_capsule(object buf, bint versioned):
45-
cdef DLManagedTensor* dlm_tensor
46-
cdef DLManagedTensorVersioned* dlm_tensor_ver
47-
cdef DLTensor* dl_tensor
48-
cdef void* tensor_ptr
49-
cdef const char* capsule_name
55+
cdef void cleanup_versioned(DLManagedTensorVersioned* tensor) noexcept with gil:
    # Tear down a DLManagedTensorVersioned that was allocated with
    # stdlib.malloc. A NULL pointer is accepted and ignored.
    if tensor == NULL:
        return

    # Release the shape array, if one was attached.
    if tensor.dl_tensor.shape != NULL:
        stdlib.free(tensor.dl_tensor.shape)

    # Drop the reference held on the owning Python object and clear the
    # slot so it cannot be released a second time.
    if tensor.manager_ctx != NULL:
        cpython.Py_DECREF(<object>tensor.manager_ctx)
        tensor.manager_ctx = NULL

    # Finally release the struct itself.
    stdlib.free(tensor)
5063

51-
if versioned:
52-
dlm_tensor_ver = <DLManagedTensorVersioned*>(
53-
stdlib.malloc(sizeof(DLManagedTensorVersioned)))
54-
dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
55-
dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
56-
dlm_tensor_ver.manager_ctx = <void*>buf
57-
dlm_tensor_ver.deleter = versioned_deleter
58-
dlm_tensor_ver.flags = 0
59-
dl_tensor = &dlm_tensor_ver.dl_tensor
60-
tensor_ptr = dlm_tensor_ver
61-
capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
62-
else:
63-
dlm_tensor = <DLManagedTensor*>(
64-
stdlib.malloc(sizeof(DLManagedTensor)))
65-
dl_tensor = &dlm_tensor.dl_tensor
66-
dlm_tensor.manager_ctx = <void*>buf
67-
dlm_tensor.deleter = deleter
68-
tensor_ptr = dlm_tensor
69-
capsule_name = DLPACK_TENSOR_UNUSED_NAME
70-
71-
dl_tensor.data = <void*><intptr_t>(int(buf.handle))
72-
dl_tensor.ndim = 1
73-
cdef int64_t* shape_strides = \
74-
<int64_t*>stdlib.malloc(sizeof(int64_t) * 2)
75-
shape_strides[0] = <int64_t>buf.size
76-
shape_strides[1] = 1 # redundant
77-
dl_tensor.shape = shape_strides
78-
dl_tensor.strides = NULL
79-
dl_tensor.byte_offset = 0
8064

65+
cdef inline int _setup_dl_tensor_device(DLTensor* dl_tensor, object buf) except -1:
8166
cdef DLDevice* device = &dl_tensor.device
8267
# buf should be a Buffer instance
8368
if buf.is_device_accessible and not buf.is_host_accessible:
@@ -91,14 +76,103 @@ cpdef object make_py_capsule(object buf, bint versioned):
9176
device.device_id = 0
9277
else: # not buf.is_device_accessible and not buf.is_host_accessible
9378
raise BufferError("invalid buffer")
79+
return 0
80+
81+
82+
cdef inline int _setup_dl_tensor_layout(DLTensor* dl_tensor, object buf, StridedLayout layout) except -1:
    # Populate ndim, shape, strides and byte_offset of ``dl_tensor``.
    #
    # layout is None  -> 1-D tensor with ``buf.size`` elements and no strides.
    # layout provided -> shape (and strides, when the layout has them) are
    #                    copied from ``layout.base``.
    #
    # Shape and strides live in ONE malloc'ed array: shape occupies the first
    # ``ndim`` entries and strides (when present) the following ``ndim``.
    # Freeing ``dl_tensor.shape`` therefore releases both.
    #
    # Returns 0 on success; raises (returning -1) on failure. On failure any
    # array allocated here is either already attached to ``dl_tensor.shape``
    # or was never allocated, so the caller can clean up through ``shape``.
    cdef int64_t* shape_strides = NULL
    cdef int64_t buf_size
    cdef int ndim
    cdef int i

    if layout is None:
        # Read the Python attribute BEFORE allocating: it may raise, and an
        # allocation not yet stored in dl_tensor.shape would be leaked.
        buf_size = <int64_t>buf.size
        shape_strides = <int64_t*>stdlib.malloc(sizeof(int64_t) * 2)
        if shape_strides == NULL:
            raise MemoryError()
        dl_tensor.shape = shape_strides
        dl_tensor.ndim = 1
        shape_strides[0] = buf_size
        shape_strides[1] = 1  # redundant
        dl_tensor.strides = NULL
        dl_tensor.byte_offset = 0
    else:
        ndim = layout.ndim
        shape_strides = <int64_t*>stdlib.malloc(sizeof(int64_t) * ndim * 2)
        # malloc(0) is allowed to return NULL, so only treat NULL as an
        # out-of-memory condition when some storage was actually requested.
        if shape_strides == NULL and ndim > 0:
            raise MemoryError()
        dl_tensor.shape = shape_strides
        dl_tensor.ndim = ndim
        for i in range(ndim):
            shape_strides[i] = layout.base.shape[i]
        if layout.base.strides == NULL:
            dl_tensor.strides = NULL
        else:
            # Strides are stored right after the shape in the same array.
            dl_tensor.strides = shape_strides + ndim
            for i in range(ndim):
                dl_tensor.strides[i] = layout.base.strides[i]
        dl_tensor.byte_offset = 0
    return 0
108+
109+
110+
cdef inline int _setup_dl_tensor_dtype(DLTensor* dl_tensor, object dtype) except -1:
    # Fill in ``dl_tensor.dtype``.
    #
    # dtype is None -> kDLInt, 8 bits, 1 lane.
    # otherwise     -> ``dtype`` is unpacked as (code, bits, lanes).
    #
    # Returns 0 on success; raises (returning -1) if ``dtype`` cannot be
    # unpacked into the three integer fields.
    cdef uint8_t type_code = <uint8_t>kDLInt
    cdef uint8_t num_bits = <uint8_t>8
    cdef uint16_t num_lanes = <uint16_t>1

    # Override the defaults only when the caller supplied a dtype triple.
    if dtype is not None:
        type_code, num_bits, num_lanes = dtype

    dl_tensor.dtype.code = type_code
    dl_tensor.dtype.bits = num_bits
    dl_tensor.dtype.lanes = num_lanes
    return 0
125+
126+
127+
cpdef object make_py_capsule(
    object buf,
    intptr_t data_ptr,
    bint versioned,
    StridedLayout layout=None,
    object dtype=None,
):
    """Wrap ``data_ptr`` in a DLPack capsule that keeps ``buf`` alive.

    Parameters
    ----------
    buf
        Object stored (with a new reference) as the tensor's ``manager_ctx``
        and used to fill in the device information.
    data_ptr
        Address stored in the tensor's ``data`` field.
    versioned
        If true, build a ``DLManagedTensorVersioned``; otherwise build a
        ``DLManagedTensor``.
    layout
        Optional layout providing shape/strides. When ``None`` the tensor is
        1-D with ``buf.size`` elements.
    dtype
        Optional ``(code, bits, lanes)`` triple. When ``None`` the tensor is
        typed as 8-bit ``kDLInt`` with one lane.

    Returns
    -------
    A ``PyCapsule`` holding the managed tensor.

    Raises
    ------
    MemoryError
        If the managed-tensor struct cannot be allocated.
    Any exception raised by the setup helpers or by capsule creation is
    propagated after the partially built tensor has been released.
    """
    cdef DLManagedTensor* dlm_tensor = NULL
    cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
    cdef DLTensor* dl_tensor = NULL
    cdef void* tensor_ptr = NULL
    cdef const char* capsule_name = NULL

    try:
        if versioned:
            dlm_tensor_ver = <DLManagedTensorVersioned*>(
                stdlib.malloc(sizeof(DLManagedTensorVersioned)))
            if dlm_tensor_ver == NULL:
                raise MemoryError()
            # malloc does not zero memory: initialise every field the
            # cleanup path inspects before anything below can raise.
            dlm_tensor_ver.dl_tensor.shape = NULL
            dlm_tensor_ver.manager_ctx = NULL
            dl_tensor = &dlm_tensor_ver.dl_tensor
            dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
            dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
            # The tensor owns a reference to buf until it is deleted.
            cpython.Py_INCREF(buf)
            dlm_tensor_ver.manager_ctx = <void*>buf
            dlm_tensor_ver.deleter = versioned_deleter
            dlm_tensor_ver.flags = 0
            tensor_ptr = dlm_tensor_ver
            capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
        else:
            dlm_tensor = <DLManagedTensor*>(
                stdlib.malloc(sizeof(DLManagedTensor)))
            if dlm_tensor == NULL:
                raise MemoryError()
            # Same initialisation as the versioned branch.
            dlm_tensor.dl_tensor.shape = NULL
            dlm_tensor.manager_ctx = NULL
            dl_tensor = &dlm_tensor.dl_tensor
            # The tensor owns a reference to buf until it is deleted.
            cpython.Py_INCREF(buf)
            dlm_tensor.manager_ctx = <void*>buf
            dlm_tensor.deleter = deleter
            tensor_ptr = dlm_tensor
            capsule_name = DLPACK_TENSOR_UNUSED_NAME

        dl_tensor.data = <void*>data_ptr

        _setup_dl_tensor_device(dl_tensor, buf)
        _setup_dl_tensor_layout(dl_tensor, buf, layout)
        _setup_dl_tensor_dtype(dl_tensor, dtype)

        return cpython.PyCapsule_New(tensor_ptr, capsule_name, pycapsule_deleter)
    except:
        # No capsule took ownership: release whichever struct was built
        # (both helpers ignore NULL), then re-raise the original error.
        cleanup(dlm_tensor)
        cleanup_versioned(dlm_tensor_ver)
        raise
102176

103177

104178
class DLDeviceType(IntEnum):

0 commit comments

Comments
 (0)