Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ keywords = ["dlpack", "deep-learning", "machine-learning"]
[dependencies]
ndarray = { version = "0.16", optional = true }
pyo3 = { version = "0.26", optional = true }
# 2.3 has MSRV 1.7.0
# 2.5 has MSRV 1.81
half = { version = "<2.5", optional = true }

[features]
ndarray = ["dep:ndarray"]
pyo3 = ["dep:pyo3"]
half = ["dep:half"]
Comment on lines 27 to +29
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if this is what prevent automatic feature activation from metatensor. Let me try locally (and I'll send a separate PR in this case)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, apparently this is not required (https://doc.rust-lang.org/cargo/reference/features.html#optional-dependencies) but we have to enable them when using them anyway. Am I remembering how stuff worked previously?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH I haven't done enough package management in Rust to recall :D

85 changes: 85 additions & 0 deletions src/data_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,42 @@ impl_dlpack_pointer_cast!(DLDataTypeCode::kDLInt, i8, i16, i32, i64,);
impl_dlpack_pointer_cast!(DLDataTypeCode::kDLFloat, f32, f64,);
impl_dlpack_pointer_cast!(DLDataTypeCode::kDLBool, bool,);

#[cfg(feature = "half")]
impl_dlpack_pointer_cast!(DLDataTypeCode::kDLFloat, half::f16,);

macro_rules! impl_dlpack_pointer_cast_complex {
($dlpack_code: expr, $($type: ty),+, ) => {
$(impl DLPackPointerCast for [$type; 2] {
fn dlpack_ptr_cast(ptr: *mut std::os::raw::c_void, data_type: DLDataType) -> Result<*mut Self, CastError> {
// For complex types, the DLPack bits represent the total size (real + imag).
// So [f32; 2] corresponds to kDLComplex with 64 bits.
let expected_bits = 8 * ::std::mem::size_of::<Self>();

if (data_type.bits as usize) != expected_bits || data_type.code != $dlpack_code {
return Err(CastError::WrongType { dl_type: data_type, rust_type: stringify!([$type; 2])});
}

if data_type.lanes != 1 {
return Err(CastError::Lanes { expected: 1, given: data_type.lanes as usize });
}

let ptr = ptr.cast::<Self>();
if !ptr_is_aligned(ptr) {
return Err(CastError::BadAlignment {
ptr: ptr as usize,
align: ::std::mem::align_of::<Self>(),
rust_type: stringify!([$type; 2]),
});
}

return Ok(ptr);
}
})+
};
}

impl_dlpack_pointer_cast_complex!(DLDataTypeCode::kDLComplex, f32, f64,);


/// Trait to get the DLPack datatype corresponding to a Rust datatype
#[allow(dead_code)]
Expand All @@ -101,6 +137,26 @@ impl_get_dlpack_data_type!(DLDataTypeCode::kDLInt, i8, i16, i32, i64,);
impl_get_dlpack_data_type!(DLDataTypeCode::kDLFloat, f32, f64,);
impl_get_dlpack_data_type!(DLDataTypeCode::kDLBool, bool,);

#[cfg(feature = "half")]
impl_get_dlpack_data_type!(DLDataTypeCode::kDLFloat, half::f16,);

macro_rules! impl_get_dlpack_data_type_complex {
($dlpack_code: expr, $($type: ty),+, ) => {
$(impl GetDLPackDataType for [$type; 2] {
fn get_dlpack_data_type() -> DLDataType {
DLDataType {
code: $dlpack_code,
// size_of::<[f32; 2]> is 8 bytes, so 8 * 8 = 64 bits
bits: (8 * std::mem::size_of::<Self>()).try_into().expect("failed to convert type size to u8"),
lanes: 1,
}
}
})+
};
}

impl_get_dlpack_data_type_complex!(DLDataTypeCode::kDLComplex, f32, f64,);

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -175,4 +231,33 @@ mod tests {
panic!("Expected a BadAlignment error");
}
}

#[test]
fn test_complex_mapping_consistency() {
use std::mem::size_of;

// complex64 check
assert_eq!(size_of::<[f32; 2]>() * 8, 64);
assert_eq!(<[f32; 2]>::get_dlpack_data_type().bits, 64);

// complex128 check
assert_eq!(size_of::<[f64; 2]>() * 8, 128);
assert_eq!(<[f64; 2]>::get_dlpack_data_type().bits, 128);
}

#[test]
#[cfg(feature = "half")]
fn test_half_precision() {
// Test that half::f16 correctly maps to kDLFloat with 16 bits
let dtype = half::f16::get_dlpack_data_type();
assert_eq!(dtype.code, DLDataTypeCode::kDLFloat);
assert_eq!(dtype.bits, 16);

let value = half::f16::from_f32(1.5);
let mut mock_data = value;
let ptr = &mut mock_data as *mut half::f16 as *mut std::os::raw::c_void;

let result = half::f16::dlpack_ptr_cast(ptr, dtype);
assert!(result.is_ok());
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ mod data_types;
pub use self::data_types::GetDLPackDataType;

pub use self::data_types::CastError;
use self::data_types::DLPackPointerCast;
pub use self::data_types::DLPackPointerCast;

impl sys::DLPackVersion {
/// Returns the DLPack version supported by this library.
Expand Down
Loading