diff --git a/src/guest_memory.rs b/src/guest_memory.rs index 251735ab..904cdc9a 100644 --- a/src/guest_memory.rs +++ b/src/guest_memory.rs @@ -654,15 +654,13 @@ pub trait GuestMemory { Ok(0) => return Ok(total), // made some progress Ok(len) => { - total = match total.checked_add(len) { - Some(x) if x < count => x, - Some(x) if x == count => return Ok(x), - _ => return Err(Error::CallbackOutOfRange), - }; - cur = match cur.overflowing_add(len as GuestUsize) { - (x @ GuestAddress(0), _) | (x, false) => x, - (_, true) => return Err(Error::GuestAddressOverflow), - }; + total = total.checked_add(len).ok_or(Error::CallbackOutOfRange)?; + if total == count { + return Ok(total); + } + cur = cur + .checked_add(len as GuestUsize) + .ok_or(Error::GuestAddressOverflow)?; } // error happened e => return e, @@ -1304,6 +1302,25 @@ mod tests { non_atomic_access_helper::() } + #[cfg(feature = "backend-mmap")] + #[test] + // This test makes sure that computation for regions that end at u64::MAX do not cause + // overflows. + fn test_region_at_usize_max() { + let region_len: usize = 0x1000; + let addr = GuestAddress(u64::MAX - region_len as u64); + let mem = GuestMemoryMmap::from_ranges(&[(addr, region_len)]).unwrap(); + let mut image = vec![0; region_len + 1]; + + // This access could cause an overflow if we wouldn't have a check for memory region length. + let invalid_access_size = region_len + 1; + let count = mem + .write_volatile_to(addr, &mut image.as_mut_slice(), invalid_access_size) + .unwrap(); + // We can only write maximum region_len bytes. + assert_eq!(count, region_len); + } + #[cfg(feature = "backend-mmap")] #[test] fn test_zero_length_accesses() {