Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@


#![allow(clippy::needless_return, clippy::redundant_field_names)]
#![forbid(clippy::as_ptr_cast_mut, clippy::ptr_cast_constness)]

use std::{ffi::c_void, ptr::NonNull};

Expand Down Expand Up @@ -350,7 +351,7 @@ impl DLPackTensor {
/// Consumes the `DLPackTensor`, returning the underlying raw pointer.
///
/// # Safety
///
///
/// The caller is responsible for managing the memory and calling the deleter
/// when the tensor is no longer needed.
pub fn into_raw(self) -> NonNull<sys::DLManagedTensorVersioned>{
Expand Down
33 changes: 17 additions & 16 deletions src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
//! let tensor_ref: DLPackTensorRef = (&array).try_into().unwrap();
//! ```

use std::ffi::c_void;
use ndarray::{Array, ArcArray, Dimension, ShapeBuilder};

use crate::data_types::{CastError, DLPackPointerCast, GetDLPackDataType};
Expand Down Expand Up @@ -376,7 +375,7 @@ struct ManagerContext<T> {

unsafe extern "C" fn deleter_fn<T>(manager: *mut sys::DLManagedTensorVersioned) {
// Reconstruct the box and drop it, freeing the memory.
let ctx = (*manager).manager_ctx as *mut ManagerContext<T>;
let ctx = (*manager).manager_ctx.cast::<ManagerContext<T>>();
let _ = Box::from_raw(ctx);
}

Expand All @@ -398,7 +397,11 @@ where
});

let dl_tensor = sys::DLTensor {
data: ctx._array.as_ptr() as *mut _,
// Casting to a mut pointer is not necessarily safe, but is required
// by DLPack. The data can be mutated through this pointer, we
// should try to find a way to make this work in Rust type system in
// the future.
data: ctx._array.as_ptr().cast_mut().cast(),
device: sys::DLDevice {
device_type: sys::DLDeviceType::kDLCPU,
device_id: 0,
Expand All @@ -412,7 +415,7 @@ where

let managed_tensor = sys::DLManagedTensorVersioned {
version: sys::DLPackVersion::current(),
manager_ctx: Box::into_raw(ctx) as *mut _,
manager_ctx: Box::into_raw(ctx).cast(),
deleter: Some(deleter_fn::<Array<T, D>>),
flags: 0,
dl_tensor,
Expand Down Expand Up @@ -446,26 +449,24 @@ where
strides,
});

let data_ptr = ctx._array.as_ptr() as *mut c_void;
let shape_ptr = ctx.shape.as_mut_ptr();
let strides_ptr = ctx.strides.as_mut_ptr();

let dl_tensor = sys::DLTensor {
data: data_ptr,
// Same as above, casting to a mut pointer is not necessarily safe.
data: ctx._array.as_ptr().cast_mut().cast(),
device: sys::DLDevice {
device_type: sys::DLDeviceType::kDLCPU,
device_id: 0,
},
ndim,
dtype: T::get_dlpack_data_type(),
shape: shape_ptr,
strides: strides_ptr,
shape: ctx.shape.as_mut_ptr(),
strides: ctx.strides.as_mut_ptr(),
byte_offset: 0,
};

let managed_tensor = sys::DLManagedTensorVersioned {
version: sys::DLPackVersion::current(),
manager_ctx: Box::into_raw(ctx) as *mut _,
manager_ctx: Box::into_raw(ctx).cast(),
deleter: Some(deleter_fn::<ArcArray<T, D>>),
flags: 0,
dl_tensor,
Expand All @@ -486,12 +487,12 @@ mod tests {

#[test]
fn test_dlpack_to_ndarray() {
let mut data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let mut shape = vec![2i64, 3];
let mut strides = vec![3i64, 1];

let dl_tensor = DLTensor {
data: data.as_mut_ptr() as *mut _,
data: data.as_ptr().cast_mut().cast(),
device: DLDevice {
device_type: DLDeviceType::kDLCPU,
device_id: 0,
Expand All @@ -518,7 +519,7 @@ mod tests {
let mut strides = vec![1i64, 2];

let dl_tensor = DLTensor {
data: data.as_mut_ptr() as *mut _,
data: data.as_mut_ptr().cast(),
device: DLDevice {
device_type: DLDeviceType::kDLCPU,
device_id: 0,
Expand All @@ -544,7 +545,7 @@ mod tests {
let mut shape = vec![1i64];

let dl_tensor = DLTensor {
data: data.as_mut_ptr() as *mut _,
data: data.as_mut_ptr().cast(),
device: DLDevice {
device_type: DLDeviceType::kDLCUDA,
device_id: 0,
Expand Down Expand Up @@ -587,7 +588,7 @@ mod tests {
let mut strides = vec![3i64, 1];

let dl_tensor = DLTensor {
data: data.as_mut_ptr() as *mut _,
data: data.as_mut_ptr().cast(),
device: DLDevice {
device_type: DLDeviceType::kDLCPU,
device_id: 0,
Expand Down
81 changes: 41 additions & 40 deletions src/sys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ pub struct DLManagedTensorVersioned {
///----------------------------------------------------------------------
/// DLPack `__dlpack_c_exchange_api__` fast exchange protocol definitions
///----------------------------------------------------------------------

///
/// Request a producer library to create a new tensor.
///
/// Create a new `DLManagedTensorVersioned` within the context of the producer
Expand All @@ -384,17 +384,17 @@ pub struct DLManagedTensorVersioned {
/// # Arguments
///
/// * `prototype` - The prototype DLTensor. Only the dtype, ndim, shape,
/// and device fields are used.
/// and device fields are used.
/// * `out` - The output DLManagedTensorVersioned.
/// * `error_ctx` - Context for `SetError`.
/// * `SetError` - The function to set the error.
///
/// # Returns
///
///
/// The owning DLManagedTensorVersioned* or NULL on failure.
/// SetError is called exactly when NULL is returned (the implementer
/// must ensure this).
///
///
/// NOTE: - As a C function, must not throw C++ exceptions.
/// - Error propagation via SetError to avoid any direct need
/// of Python API. Due to this `SetError` may have to ensure the GIL is
Expand All @@ -419,17 +419,17 @@ pub type DLPackManagedTensorAllocator = Option<unsafe extern "C" fn(
/// This function is exposed by the framework through the DLPackExchangeAPI.
///
/// # Arguments
///
///
/// * `py_object` - The Python object to convert. Must have the same type
/// as the one the `DLPackExchangeAPI` was discovered from.
/// as the one the `DLPackExchangeAPI` was discovered from.
/// * `out` - The output DLManagedTensorVersioned.
///
///
/// # Returns
///
///
/// The owning DLManagedTensorVersioned* or NULL on failure with a
/// Python exception set. If the data cannot be described using DLPack
/// this should be a BufferError if possible.
///
///
/// NOTE: - As a C function, must not throw C++ exceptions.
///
/// See also:
Expand All @@ -455,13 +455,13 @@ pub type DLPackManagedTensorFromPyObjectNoSync = Option<unsafe extern "C" fn(
/// This function is exposed by the framework through the DLPackExchangeAPI.
///
/// # Arguments
///
///
/// * `py_object` - The Python object to convert. Must have the same type
/// as the one the `DLPackExchangeAPI` was discovered from.
/// as the one the `DLPackExchangeAPI` was discovered from.
/// * `out` - The output DLTensor, whose space is pre-allocated on stack.
///
/// # Returns
///
///
/// 0 on success, -1 on failure with a Python exception set.
///
/// NOTE: - As a C function, must not throw C++ exceptions.
Expand All @@ -485,15 +485,15 @@ pub type DLPackDLTensorFromPyObjectNoSync = Option<unsafe extern "C" fn(
/// always set out_current_stream[0] to NULL.
///
/// # Arguments
///
///
/// * `device_type` - The device type.
/// * `device_id` - The device id.
/// * `out_current_stream` - The output current work stream.
///
/// # Returns
///
///
/// 0 on success, -1 on failure with a Python exception set.
///
///
/// NOTE: - As a C function, must not throw C++ exceptions.
///
/// See also:
Expand All @@ -514,15 +514,15 @@ pub type DLPackCurrentWorkStream = Option<unsafe extern "C" fn(
/// This function is exposed by the framework through the DLPackExchangeAPI.
///
/// # Arguments
///
///
/// * `tensor` - The DLManagedTensorVersioned to convert; the ownership of the
/// tensor is stolen.
/// tensor is stolen.
/// * `out_py_object` - The output Python object.
///
///
/// # Returns
///
///
/// 0 on success, -1 on failure with a Python exception set.
///
///
/// See also:
/// DLPackExchangeAPI
pub type DLPackManagedTensorToPyObjectNoSync = Option<unsafe extern "C" fn(
Expand Down Expand Up @@ -552,9 +552,8 @@ pub struct DLPackExchangeAPIHeader {
///
/// Additionally to `__dlpack__()` we define a C function table sharable by
///
/// Python implementations via `__dlpack_c_exchange_api__`.
/// This attribute must be set on the type as a Python PyCapsule
/// with name "dlpack_exchange_api".
/// Python implementations via `__dlpack_c_exchange_api__`. This attribute must
/// be set on the type as a Python PyCapsule with name "dlpack_exchange_api".
///
/// A consumer library may use a pattern such as:
///
Expand Down Expand Up @@ -600,28 +599,30 @@ pub struct DLPackExchangeAPIHeader {
/// Guidelines for leveraging DLPackExchangeAPI:
///
/// There are generally two kinds of consumer needs for DLPack exchange:
/// - N0: library support, where consumer.kernel(x, y, z) would like to run a kernel
/// with the data from x, y, z. The consumer is also expected to run the kernel with the same
/// stream context as the producer. For example, when x, y, z is torch.Tensor,
/// consumer should query exchange_api->current_work_stream to get the
/// current stream and launch the kernel with the same stream.
/// This setup is necessary for no synchronization in kernel launch and maximum compatibility
/// with CUDA graph capture in the producer.
/// This is the desirable behavior for library extension support for frameworks like PyTorch.
/// - N0: library support, where consumer.kernel(x, y, z) would like to run a
/// kernel with the data from x, y, z. The consumer is also expected to run
/// the kernel with the same stream context as the producer. For example, when
/// x, y, z is torch.Tensor, consumer should query
/// exchange_api->current_work_stream to get the current stream and launch the
/// kernel with the same stream. This setup is necessary for no
/// synchronization in kernel launch and maximum compatibility with CUDA graph
/// capture in the producer. This is the desirable behavior for library
/// extension support for frameworks like PyTorch.
/// - N1: data ingestion and retention
///
/// Note that obj.__dlpack__() API should provide useful ways for N1.
/// The primary focus of the current DLPackExchangeAPI is to enable faster exchange N0
/// with the support of the function pointer current_work_stream.
/// Note that obj.__dlpack__() API should provide useful ways for N1. The
/// primary focus of the current DLPackExchangeAPI is to enable faster exchange
/// N0 with the support of the function pointer current_work_stream.
///
/// Array/Tensor libraries should statically create and initialize this structure
/// then return a pointer to DLPackExchangeAPI as an int value in Tensor/Array.
/// The DLPackExchangeAPI* must stay alive throughout the lifetime of the process.
/// Array/Tensor libraries should statically create and initialize this
/// structure then return a pointer to DLPackExchangeAPI as an int value in
/// Tensor/Array. The DLPackExchangeAPI* must stay alive throughout the lifetime
/// of the process.
///
/// One simple way to do so is to create a static instance of DLPackExchangeAPI
/// within the framework and return a pointer to it. The following code
/// shows an example to do so in C++. It should also be reasonably easy
/// to do so in other languages.
/// within the framework and return a pointer to it. The following code shows an
/// example to do so in C++. It should also be reasonably easy to do so in other
/// languages.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct DLPackExchangeAPI {
Expand Down
Loading