From 129bc06e8127cf4907b10cad7cd27b094fbd66fe Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Fri, 6 Feb 2026 20:05:29 +0100 Subject: [PATCH] Better handling of flags --- src/lib.rs | 39 ++++++++++++++++++++++++++++++++++----- src/ndarray.rs | 30 +++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index f7ab1c5..ced3ffb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -279,12 +279,20 @@ impl DLPackTensor { /// Get a mutable DLPack tensor reference from this owned tensor pub fn as_mut(&mut self) -> DLPackTensorRefMut<'_> { - unsafe { - // only if NOT read only + unique - assert!(self.raw.as_ref().flags & sys::DLPACK_FLAG_BITMASK_IS_COPIED == 0, "Can not create a mutable reference to a borrowed tensor"); - assert!(self.raw.as_ref().flags & sys::DLPACK_FLAG_BITMASK_READ_ONLY != 0, "Can not create a mutable reference to a read-only tensor"); + if self.is_read_only() { + panic!("Can not get a `DLPackTensorRefMut`: tensor is explicitly marked READ_ONLY"); + } - // SAFETY: we are constaining the returned reference lifetime + // FIXME: ideally we should also check the IS_COPIED/IS_OWNED bit to + // ensure that we are the unique owner of the tensor, but some libraries + // (e.g. PyTorch) don't set this bit correctly, so for now we just + // ignore it and let the user deal with potential issues if they mutate + // a non-unique tensor. + + unsafe { + // SAFETY: we are constraining the returned reference lifetime + // the caller must ensure that the uniqueness check doesn't apply + // i.e. they're fine mutating an ArcArray with refcount > 1 DLPackTensorRefMut::from_raw(self.raw.as_ref().dl_tensor.clone()) } } @@ -334,6 +342,27 @@ impl DLPackTensor { } } + /// Get the raw flags bitfield. + pub fn flags(&self) -> u64 { + unsafe { self.raw.as_ref().flags } + } + + /// Check if the tensor is explicitly marked as read-only. + pub fn is_read_only(&self) -> bool { + self.flags() & sys::DLPACK_FLAG_BITMASK_READ_ONLY != 0 + } + + /// Check if the tensor is unique/owned (IS_COPIED bit). + pub fn is_copied(&self) -> bool { + self.flags() & sys::DLPACK_FLAG_BITMASK_IS_COPIED != 0 + } + + /// Check if the sub-byte types (fp4, fp6) are padded to the next byte, or + /// packed together. + pub fn is_subbyte_type_padded(&self) -> bool { + self.flags() & sys::DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED != 0 + } + /// Get the byte offset of this tensor, i.e. how many bytes should be added /// to [`DLPackTensor::data_ptr`] and [`DLPackTensor::data_ptr_mut`] to get /// the first element of the tensor. diff --git a/src/ndarray.rs b/src/ndarray.rs index 174c335..5c218f2 100644 --- a/src/ndarray.rs +++ b/src/ndarray.rs @@ -417,7 +417,7 @@ where version: sys::DLPackVersion::current(), manager_ctx: Box::into_raw(ctx).cast(), deleter: Some(deleter_fn::>), - flags: 0, + flags: sys::DLPACK_FLAG_BITMASK_IS_COPIED, dl_tensor, }; @@ -682,4 +682,32 @@ mod tests { let shape = tensor_ref.shape(); assert_eq!(shape, &[2, 2]); } + + #[test] + fn test_array_conversion_permits_mutation() { + let array = arr2(&[[1.0f32, 2.0], [3.0, 4.0]]); + let mut tensor: DLPackTensor = array.try_into().unwrap(); + + // This should not panic because flags include IS_COPIED + // and do not include READ_ONLY. + let mut tensor_mut = tensor.as_mut(); + let ptr = tensor_mut.data_ptr_mut::().unwrap(); + + unsafe { + *ptr = 42.0; + } + + let val = tensor.as_ref().data_ptr::().unwrap(); + assert_eq!(unsafe { *val }, 42.0); + } + + #[test] + fn test_arc_array_conversion_allows_readonly_access() { + let array = ArcArray2::from_elem((2, 2), 1.0f32); + let tensor: DLPackTensor = (&array).try_into().unwrap(); + + // Standard immutable access should remain functional. + let tensor_ref = tensor.as_ref(); + assert_eq!(tensor_ref.dtype(), f32::get_dlpack_data_type()); + } }