Skip to content

Commit dc27268

Browse files
committed
Revert dlpack changes
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
1 parent 0f3bc89 commit dc27268

File tree

2 files changed

+43
-117
lines changed

2 files changed

+43
-117
lines changed

cuda_core/cuda/core/experimental/_dlpack.pyx

Lines changed: 42 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from enum import IntEnum
6-
from ._layout cimport StridedLayout
76

87

98
cdef void pycapsule_deleter(object capsule) noexcept:
@@ -34,16 +33,6 @@ cdef void deleter(DLManagedTensor* tensor) noexcept with gil:
3433
stdlib.free(tensor)
3534

3635

37-
cdef void cleanup(DLManagedTensor* tensor) noexcept with gil:
38-
if tensor:
39-
if tensor.dl_tensor.shape:
40-
stdlib.free(tensor.dl_tensor.shape)
41-
if tensor.manager_ctx:
42-
cpython.Py_DECREF(<object>tensor.manager_ctx)
43-
tensor.manager_ctx = NULL
44-
stdlib.free(tensor)
45-
46-
4736
cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
4837
stdlib.free(tensor.dl_tensor.shape)
4938
if tensor.manager_ctx:
@@ -52,17 +41,43 @@ cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
5241
stdlib.free(tensor)
5342

5443

55-
cdef void cleanup_versioned(DLManagedTensorVersioned* tensor) noexcept with gil:
56-
if tensor:
57-
if tensor.dl_tensor.shape:
58-
stdlib.free(tensor.dl_tensor.shape)
59-
if tensor.manager_ctx:
60-
cpython.Py_DECREF(<object>tensor.manager_ctx)
61-
tensor.manager_ctx = NULL
62-
stdlib.free(tensor)
44+
cpdef object make_py_capsule(object buf, bint versioned):
45+
cdef DLManagedTensor* dlm_tensor
46+
cdef DLManagedTensorVersioned* dlm_tensor_ver
47+
cdef DLTensor* dl_tensor
48+
cdef void* tensor_ptr
49+
cdef const char* capsule_name
6350

51+
if versioned:
52+
dlm_tensor_ver = <DLManagedTensorVersioned*>(
53+
stdlib.malloc(sizeof(DLManagedTensorVersioned)))
54+
dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
55+
dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
56+
dlm_tensor_ver.manager_ctx = <void*>buf
57+
dlm_tensor_ver.deleter = versioned_deleter
58+
dlm_tensor_ver.flags = 0
59+
dl_tensor = &dlm_tensor_ver.dl_tensor
60+
tensor_ptr = dlm_tensor_ver
61+
capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
62+
else:
63+
dlm_tensor = <DLManagedTensor*>(
64+
stdlib.malloc(sizeof(DLManagedTensor)))
65+
dl_tensor = &dlm_tensor.dl_tensor
66+
dlm_tensor.manager_ctx = <void*>buf
67+
dlm_tensor.deleter = deleter
68+
tensor_ptr = dlm_tensor
69+
capsule_name = DLPACK_TENSOR_UNUSED_NAME
70+
71+
dl_tensor.data = <void*><intptr_t>(int(buf.handle))
72+
dl_tensor.ndim = 1
73+
cdef int64_t* shape_strides = \
74+
<int64_t*>stdlib.malloc(sizeof(int64_t) * 2)
75+
shape_strides[0] = <int64_t>buf.size
76+
shape_strides[1] = 1 # redundant
77+
dl_tensor.shape = shape_strides
78+
dl_tensor.strides = NULL
79+
dl_tensor.byte_offset = 0
6480

65-
cdef inline int _setup_dl_tensor_device(DLTensor* dl_tensor, object buf) except -1:
6681
cdef DLDevice* device = &dl_tensor.device
6782
# buf should be a Buffer instance
6883
if buf.is_device_accessible and not buf.is_host_accessible:
@@ -76,103 +91,14 @@ cdef inline int _setup_dl_tensor_device(DLTensor* dl_tensor, object buf) except
7691
device.device_id = 0
7792
else: # not buf.is_device_accessible and not buf.is_host_accessible
7893
raise BufferError("invalid buffer")
79-
return 0
80-
81-
82-
cdef inline int _setup_dl_tensor_layout(DLTensor* dl_tensor, object buf, StridedLayout layout) except -1:
83-
cdef int64_t* shape_strides = NULL
84-
cdef int ndim
85-
if layout is None:
86-
dl_tensor.ndim = 1
87-
shape_strides = <int64_t*>stdlib.malloc(sizeof(int64_t) * 2)
88-
shape_strides[0] = <int64_t>buf.size
89-
shape_strides[1] = 1 # redundant
90-
dl_tensor.shape = shape_strides
91-
dl_tensor.strides = NULL
92-
dl_tensor.byte_offset = 0
93-
else:
94-
ndim = layout.ndim
95-
dl_tensor.ndim = ndim
96-
shape_strides = <int64_t*>stdlib.malloc(sizeof(int64_t) * ndim * 2)
97-
dl_tensor.shape = shape_strides
98-
for i in range(ndim):
99-
shape_strides[i] = layout.base.shape[i]
100-
if layout.base.strides == NULL:
101-
dl_tensor.strides = NULL
102-
else:
103-
dl_tensor.strides = shape_strides + ndim
104-
for i in range(ndim):
105-
dl_tensor.strides[i] = layout.base.strides[i]
106-
dl_tensor.byte_offset = 0
107-
return 0
108-
109-
110-
cdef inline int _setup_dl_tensor_dtype(DLTensor* dl_tensor, object dtype) except -1:
111-
cdef DLDataType* dl_dtype = &dl_tensor.dtype
112-
if dtype is None:
113-
dl_dtype.code = <uint8_t>kDLInt
114-
dl_dtype.lanes = <uint16_t>1
115-
dl_dtype.bits = <uint8_t>8
116-
return 0
117-
cdef uint8_t code
118-
cdef uint8_t bits
119-
cdef uint16_t lanes
120-
code, bits, lanes = dtype
121-
dl_dtype.code = code
122-
dl_dtype.bits = bits
123-
dl_dtype.lanes = lanes
124-
return 0
125-
126-
127-
cpdef object make_py_capsule(
128-
object buf,
129-
bint versioned,
130-
intptr_t data_ptr,
131-
StridedLayout layout=None,
132-
object dtype=None,
133-
):
134-
cdef DLManagedTensor* dlm_tensor = NULL
135-
cdef DLManagedTensorVersioned* dlm_tensor_ver = NULL
136-
cdef DLTensor* dl_tensor = NULL
137-
cdef void* tensor_ptr
138-
cdef const char* capsule_name
13994

140-
try:
141-
if versioned:
142-
dlm_tensor_ver = <DLManagedTensorVersioned*>(
143-
stdlib.malloc(sizeof(DLManagedTensorVersioned)))
144-
dlm_tensor_ver.dl_tensor.shape = NULL
145-
dl_tensor = &dlm_tensor_ver.dl_tensor
146-
dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
147-
dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
148-
cpython.Py_INCREF(buf)
149-
dlm_tensor_ver.manager_ctx = <void*>buf
150-
dlm_tensor_ver.deleter = versioned_deleter
151-
dlm_tensor_ver.flags = 0
152-
tensor_ptr = dlm_tensor_ver
153-
capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
154-
else:
155-
dlm_tensor = <DLManagedTensor*>(
156-
stdlib.malloc(sizeof(DLManagedTensor)))
157-
dlm_tensor.dl_tensor.shape = NULL
158-
dl_tensor = &dlm_tensor.dl_tensor
159-
cpython.Py_INCREF(buf)
160-
dlm_tensor.manager_ctx = <void*>buf
161-
dlm_tensor.deleter = deleter
162-
tensor_ptr = dlm_tensor
163-
capsule_name = DLPACK_TENSOR_UNUSED_NAME
164-
165-
dl_tensor.data = <void*>data_ptr
166-
167-
_setup_dl_tensor_device(dl_tensor, buf)
168-
_setup_dl_tensor_layout(dl_tensor, buf, layout)
169-
_setup_dl_tensor_dtype(dl_tensor, dtype)
170-
171-
return cpython.PyCapsule_New(tensor_ptr, capsule_name, pycapsule_deleter)
172-
except:
173-
cleanup(dlm_tensor)
174-
cleanup_versioned(dlm_tensor_ver)
175-
raise
95+
cdef DLDataType* dtype = &dl_tensor.dtype
96+
dtype.code = <uint8_t>kDLInt
97+
dtype.lanes = <uint16_t>1
98+
dtype.bits = <uint8_t>8
99+
100+
cpython.Py_INCREF(buf)
101+
return cpython.PyCapsule_New(tensor_ptr, capsule_name, pycapsule_deleter)
176102

177103

178104
class DLDeviceType(IntEnum):

cuda_core/cuda/core/experimental/_memory/_buffer.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ cdef class Buffer:
203203
if not isinstance(max_version, tuple) or len(max_version) != 2:
204204
raise BufferError(f"Expected max_version tuple[int, int], got {max_version}")
205205
versioned = max_version >= (1, 0)
206-
capsule = make_py_capsule(self, versioned, int(self.handle))
206+
capsule = make_py_capsule(self, versioned)
207207
return capsule
208208

209209
def __dlpack_device__(self) -> tuple[int, int]:

0 commit comments

Comments (0)