Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,20 @@ impl DLPackTensor {

/// Get a mutable DLPack tensor reference from this owned tensor
pub fn as_mut(&mut self) -> DLPackTensorRefMut<'_> {
unsafe {
// only if NOT read only + unique
assert!(self.raw.as_ref().flags & sys::DLPACK_FLAG_BITMASK_IS_COPIED == 0, "Can not create a mutable reference to a borrowed tensor");
assert!(self.raw.as_ref().flags & sys::DLPACK_FLAG_BITMASK_READ_ONLY != 0, "Can not create a mutable reference to a read-only tensor");
if self.is_read_only() {
panic!("Can not get a `DLPackTensorRefMut`: tensor is explicitly marked READ_ONLY");
}

// SAFETY: we are constaining the returned reference lifetime
// FIXME: ideally we should also check the IS_COPIED/IS_OWNED bit to
// ensure that we are the unique owner of the tensor, but some libraries
// (e.g. PyTorch) don't set this bit correctly, so for now we just
// ignore it and let the user deal with potential issues if they mutate
// a non-unique tensor.

unsafe {
// SAFETY: we are constraining the returned reference lifetime
// the caller must ensure that the uniqueness check doesn't apply
// i.e. they're fine mutating an ArcArray with refcount > 1
DLPackTensorRefMut::from_raw(self.raw.as_ref().dl_tensor.clone())
}
}
Expand Down Expand Up @@ -334,6 +342,27 @@ impl DLPackTensor {
}
}

/// Get the raw flags bitfield.
pub fn flags(&self) -> u64 {
unsafe { self.raw.as_ref().flags }
}

/// Check if the tensor is explicitly marked as read-only.
pub fn is_read_only(&self) -> bool {
self.flags() & sys::DLPACK_FLAG_BITMASK_READ_ONLY != 0
}

/// Check if the tensor is unique/owned (IS_COPIED bit).
pub fn is_copied(&self) -> bool {
self.flags() & sys::DLPACK_FLAG_BITMASK_IS_COPIED != 0
}

/// Check if the sub-byte types (fp4, fp6) are padded to the next byte, or
/// packed together.
pub fn is_subbyte_type_padded(&self) -> bool {
self.flags() & sys::DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED != 0
}

/// Get the byte offset of this tensor, i.e. how many bytes should be added
/// to [`DLPackTensor::data_ptr`] and [`DLPackTensor::data_ptr_mut`] to get
/// the first element of the tensor.
Expand Down
30 changes: 29 additions & 1 deletion src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ where
version: sys::DLPackVersion::current(),
manager_ctx: Box::into_raw(ctx).cast(),
deleter: Some(deleter_fn::<Array<T, D>>),
flags: 0,
flags: sys::DLPACK_FLAG_BITMASK_IS_COPIED,
dl_tensor,
};

Expand Down Expand Up @@ -682,4 +682,32 @@ mod tests {
let shape = tensor_ref.shape();
assert_eq!(shape, &[2, 2]);
}

#[test]
fn test_array_conversion_permits_mutation() {
let array = arr2(&[[1.0f32, 2.0], [3.0, 4.0]]);
let mut tensor: DLPackTensor = array.try_into().unwrap();

// This should not panic because flags include IS_COPIED
// and do not include READ_ONLY.
let mut tensor_mut = tensor.as_mut();
let ptr = tensor_mut.data_ptr_mut::<f32>().unwrap();

unsafe {
*ptr = 42.0;
}

let val = tensor.as_ref().data_ptr::<f32>().unwrap();
assert_eq!(unsafe { *val }, 42.0);
}

#[test]
fn test_arc_array_conversion_allows_readonly_access() {
let array = ArcArray2::from_elem((2, 2), 1.0f32);
let tensor: DLPackTensor = (&array).try_into().unwrap();

// Standard immutable access should remain functional.
let tensor_ref = tensor.as_ref();
assert_eq!(tensor_ref.dtype(), f32::get_dlpack_data_type());
}
}
Loading