diff --git a/Cargo.toml b/Cargo.toml index aad1298..270410f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,11 @@ keywords = ["dlpack", "deep-learning", "machine-learning"] [dependencies] ndarray = { version = "0.16", optional = true } pyo3 = { version = "0.26", optional = true } +# 2.3 has MSRV 1.7.0 +# 2.5 has MSRV 1.81 +half = { version = "<2.5", optional = true } [features] ndarray = ["dep:ndarray"] pyo3 = ["dep:pyo3"] +half = ["dep:half"] diff --git a/src/data_types.rs b/src/data_types.rs index b70955f..283cb6e 100644 --- a/src/data_types.rs +++ b/src/data_types.rs @@ -75,6 +75,42 @@ impl_dlpack_pointer_cast!(DLDataTypeCode::kDLInt, i8, i16, i32, i64,); impl_dlpack_pointer_cast!(DLDataTypeCode::kDLFloat, f32, f64,); impl_dlpack_pointer_cast!(DLDataTypeCode::kDLBool, bool,); +#[cfg(feature = "half")] +impl_dlpack_pointer_cast!(DLDataTypeCode::kDLFloat, half::f16,); + +macro_rules! impl_dlpack_pointer_cast_complex { + ($dlpack_code: expr, $($type: ty),+, ) => { + $(impl DLPackPointerCast for [$type; 2] { + fn dlpack_ptr_cast(ptr: *mut std::os::raw::c_void, data_type: DLDataType) -> Result<*mut Self, CastError> { + // For complex types, the DLPack bits represent the total size (real + imag). + // So [f32; 2] corresponds to kDLComplex with 64 bits. + let expected_bits = 8 * ::std::mem::size_of::(); + + if (data_type.bits as usize) != expected_bits || data_type.code != $dlpack_code { + return Err(CastError::WrongType { dl_type: data_type, rust_type: stringify!([$type; 2])}); + } + + if data_type.lanes != 1 { + return Err(CastError::Lanes { expected: 1, given: data_type.lanes as usize }); + } + + let ptr = ptr.cast::(); + if !ptr_is_aligned(ptr) { + return Err(CastError::BadAlignment { + ptr: ptr as usize, + align: ::std::mem::align_of::(), + rust_type: stringify!([$type; 2]), + }); + } + + return Ok(ptr); + } + })+ + }; +} + +impl_dlpack_pointer_cast_complex!(DLDataTypeCode::kDLComplex, f32, f64,); + /// Trait to get the DLPack datatype corresponding to a Rust datatype #[allow(dead_code)] @@ -101,6 +137,26 @@ impl_get_dlpack_data_type!(DLDataTypeCode::kDLInt, i8, i16, i32, i64,); impl_get_dlpack_data_type!(DLDataTypeCode::kDLFloat, f32, f64,); impl_get_dlpack_data_type!(DLDataTypeCode::kDLBool, bool,); +#[cfg(feature = "half")] +impl_get_dlpack_data_type!(DLDataTypeCode::kDLFloat, half::f16,); + +macro_rules! impl_get_dlpack_data_type_complex { + ($dlpack_code: expr, $($type: ty),+, ) => { + $(impl GetDLPackDataType for [$type; 2] { + fn get_dlpack_data_type() -> DLDataType { + DLDataType { + code: $dlpack_code, + // size_of::<[f32; 2]> is 8 bytes, so 8 * 8 = 64 bits + bits: (8 * std::mem::size_of::()).try_into().expect("failed to convert type size to u8"), + lanes: 1, + } + } + })+ + }; +} + +impl_get_dlpack_data_type_complex!(DLDataTypeCode::kDLComplex, f32, f64,); + #[cfg(test)] mod tests { use super::*; @@ -175,4 +231,33 @@ mod tests { panic!("Expected a BadAlignment error"); } } + + #[test] + fn test_complex_mapping_consistency() { + use std::mem::size_of; + + // complex64 check + assert_eq!(size_of::<[f32; 2]>() * 8, 64); + assert_eq!(<[f32; 2]>::get_dlpack_data_type().bits, 64); + + // complex128 check + assert_eq!(size_of::<[f64; 2]>() * 8, 128); + assert_eq!(<[f64; 2]>::get_dlpack_data_type().bits, 128); + } + + #[test] + #[cfg(feature = "half")] + fn test_half_precision() { + // Test that half::f16 correctly maps to kDLFloat with 16 bits + let dtype = half::f16::get_dlpack_data_type(); + assert_eq!(dtype.code, DLDataTypeCode::kDLFloat); + assert_eq!(dtype.bits, 16); + + let value = half::f16::from_f32(1.5); + let mut mock_data = value; + let ptr = &mut mock_data as *mut half::f16 as *mut std::os::raw::c_void; + + let result = half::f16::dlpack_ptr_cast(ptr, dtype); + assert!(result.is_ok()); + } } diff --git a/src/lib.rs b/src/lib.rs index d61654a..1fb7586 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,7 +31,7 @@ mod data_types; pub use self::data_types::GetDLPackDataType; pub use self::data_types::CastError; -use self::data_types::DLPackPointerCast; +pub use self::data_types::DLPackPointerCast; impl sys::DLPackVersion { /// Returns the DLPack version supported by this library.