Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@

- \[[#339](https://github.com/rust-vmm/vm-memory/pull/339)\] Fix `Bytes::read()` and `Bytes::write()` not to ignore `try_access()`'s `count` parameter

### Deprecated

- \[[#349](https://github.com/rust-vmm/vm-memory/pull/349)\] Deprecate `GuestMemory::try_access()`. Use `GuestMemory::get_slices()` instead.

## \[v0.16.1\]

### Added
Expand Down
2 changes: 1 addition & 1 deletion coverage_config_x86_64.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"coverage_score": 91.78,
"coverage_score": 90.82,
"exclude_path": "mmap_windows.rs",
"crate_features": "backend-mmap,backend-atomic,backend-bitmap"
}
12 changes: 3 additions & 9 deletions src/bitmap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,10 @@ pub(crate) mod tests {

// Finally, let's invoke the generic tests for `Bytes`.
let check_range_closure = |m: &M, start: usize, len: usize, clean: bool| -> bool {
let mut check_result = true;
m.try_access(len, GuestAddress(start as u64), |_, size, reg_addr, reg| {
if !check_range(&reg.bitmap(), reg_addr.0 as usize, size, clean) {
check_result = false;
}
Ok(size)
m.get_slices(GuestAddress(start as u64), len).all(|r| {
let slice = r.unwrap();
check_range(slice.bitmap(), 0, slice.len(), clean)
})
.unwrap();

check_result
};

test_bytes(
Expand Down
75 changes: 46 additions & 29 deletions src/guest_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,9 @@ pub trait GuestMemory {

/// Check whether the range [base, base + len) is valid.
fn check_range(&self, base: GuestAddress, len: usize) -> bool {
match self.try_access(len, base, |_, count, _, _| -> Result<usize> { Ok(count) }) {
Ok(count) => count == len,
_ => false,
}
// get_slices() ensures that if no error happens, the cumulative length of all slices
// equal `len`.
self.get_slices(base, len).all(|r| r.is_ok())
}

/// Returns the address plus the offset if it is present within the memory of the guest.
Expand All @@ -376,6 +375,10 @@ pub trait GuestMemory {
/// - the error code returned by the callback 'f'
/// - the size of the already handled data when encountering the first hole
/// - the size of the already handled data when the whole range has been handled
#[deprecated(
since = "0.17.0",
note = "supplemented by external iterator `get_slices()`"
)]
fn try_access<F>(&self, count: usize, addr: GuestAddress, mut f: F) -> Result<usize>
where
F: FnMut(usize, usize, MemoryRegionAddress, &Self::R) -> Result<usize>,
Expand Down Expand Up @@ -517,15 +520,32 @@ impl<'a, M: GuestMemory + ?Sized> GuestMemorySliceIterator<'a, M> {
};

let cap = region.len() - start.raw_value();
let len = std::cmp::min(cap, self.count as GuestUsize);
let len = std::cmp::min(cap as usize, self.count);

self.count -= len as usize;
self.count -= len;
self.addr = match self.addr.overflowing_add(len as GuestUsize) {
(x @ GuestAddress(0), _) | (x, false) => x,
(_, true) => return Some(Err(Error::GuestAddressOverflow)),
};

Some(region.get_slice(start, len as usize))
Some(region.get_slice(start, len).inspect(|s| {
assert_eq!(
s.len(),
len,
"get_slice() returned a slice with wrong length"
)
}))
}

/// Adapts this [`GuestMemorySliceIterator`] to return `None` (e.g. gracefully terminate)
/// when it encounters an error after successfully producing at least one slice.
/// Return an error if requesting the first slice returns an error.
pub fn stop_on_error(self) -> Result<impl Iterator<Item = VolatileSlice<'a, MS<'a, M>>>> {
let mut peek = self.peekable();
if let Some(err) = peek.next_if(Result::is_err) {
return Err(err.unwrap_err());
}
Ok(peek.filter_map(Result::ok))
}
}

Expand Down Expand Up @@ -556,23 +576,15 @@ impl<T: GuestMemory + ?Sized> Bytes<GuestAddress> for T {
type E = Error;

fn write(&self, buf: &[u8], addr: GuestAddress) -> Result<usize> {
self.try_access(
buf.len(),
addr,
|offset, count, caddr, region| -> Result<usize> {
region.write(&buf[offset..(offset + count)], caddr)
},
)
self.get_slices(addr, buf.len())
.stop_on_error()?
.try_fold(0, |acc, slice| Ok(acc + slice.write(&buf[acc..], 0)?))
}

fn read(&self, buf: &mut [u8], addr: GuestAddress) -> Result<usize> {
self.try_access(
buf.len(),
addr,
|offset, count, caddr, region| -> Result<usize> {
region.read(&mut buf[offset..(offset + count)], caddr)
},
)
self.get_slices(addr, buf.len())
.stop_on_error()?
.try_fold(0, |acc, slice| Ok(acc + slice.read(&mut buf[acc..], 0)?))
}

/// # Examples
Expand Down Expand Up @@ -636,9 +648,11 @@ impl<T: GuestMemory + ?Sized> Bytes<GuestAddress> for T {
where
F: ReadVolatile,
{
self.try_access(count, addr, |_, len, caddr, region| -> Result<usize> {
region.read_volatile_from(caddr, src, len)
})
self.get_slices(addr, count)
.stop_on_error()?
.try_fold(0, |acc, slice| {
Ok(acc + slice.read_volatile_from(0, src, slice.len())?)
})
}

fn read_exact_volatile_from<F>(
Expand All @@ -664,11 +678,14 @@ impl<T: GuestMemory + ?Sized> Bytes<GuestAddress> for T {
where
F: WriteVolatile,
{
self.try_access(count, addr, |_, len, caddr, region| -> Result<usize> {
// For a non-RAM region, reading could have side effects, so we
// must use write_all().
region.write_all_volatile_to(caddr, dst, len).map(|()| len)
})
self.get_slices(addr, count)
.stop_on_error()?
.try_fold(0, |acc, slice| {
// For a non-RAM region, reading could have side effects, so we
// must use write_all().
slice.write_all_volatile_to(0, dst, slice.len())?;
Ok(acc + slice.len())
})
}

fn write_all_volatile_to<F>(&self, addr: GuestAddress, dst: &mut F, count: usize) -> Result<()>
Expand Down
24 changes: 24 additions & 0 deletions src/mmap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,29 @@ mod tests {
}
}

#[test]
fn test_check_range() {
let start_addr1 = GuestAddress(0);
let start_addr2 = GuestAddress(0x800);
let start_addr3 = GuestAddress(0xc00);
let guest_mem = GuestMemoryMmap::from_ranges(&[
(start_addr1, 0x400),
(start_addr2, 0x400),
(start_addr3, 0x400),
])
.unwrap();

assert!(guest_mem.check_range(start_addr1, 0x0));
assert!(guest_mem.check_range(start_addr1, 0x200));
assert!(guest_mem.check_range(start_addr1, 0x400));
assert!(!guest_mem.check_range(start_addr1, 0xa00));
assert!(guest_mem.check_range(start_addr2, 0x7ff));
assert!(guest_mem.check_range(start_addr2, 0x800));
assert!(!guest_mem.check_range(start_addr2, 0x801));
assert!(!guest_mem.check_range(start_addr2, 0xc00));
assert!(!guest_mem.check_range(start_addr1, usize::MAX));
}

#[test]
fn test_deref() {
let f = TempFile::new().unwrap().into_file();
Expand Down Expand Up @@ -432,6 +455,7 @@ mod tests {

#[test]
#[cfg(feature = "rawfd")]
#[cfg(not(miri))]
fn read_to_and_write_from_mem() {
use std::mem;

Expand Down
23 changes: 0 additions & 23 deletions src/region.rs
Original file line number Diff line number Diff line change
Expand Up @@ -773,27 +773,4 @@ pub(crate) mod tests {
Some(GuestAddress(0x400 - 1))
);
}

#[test]
fn test_check_range() {
let start_addr1 = GuestAddress(0);
let start_addr2 = GuestAddress(0x800);
let start_addr3 = GuestAddress(0xc00);
let guest_mem = new_guest_memory_collection_from_regions(&[
(start_addr1, 0x400),
(start_addr2, 0x400),
(start_addr3, 0x400),
])
.unwrap();

assert!(guest_mem.check_range(start_addr1, 0x0));
assert!(guest_mem.check_range(start_addr1, 0x200));
assert!(guest_mem.check_range(start_addr1, 0x400));
assert!(!guest_mem.check_range(start_addr1, 0xa00));
assert!(guest_mem.check_range(start_addr2, 0x7ff));
assert!(guest_mem.check_range(start_addr2, 0x800));
assert!(!guest_mem.check_range(start_addr2, 0x801));
assert!(!guest_mem.check_range(start_addr2, 0xc00));
assert!(!guest_mem.check_range(start_addr1, usize::MAX));
}
}