Skip to content

Commit 6321d93

Browse files
authored
Merge pull request #8 from cuviper/disjoint
Implement get_disjoint_mut
2 parents 4aac93f + 3d1d290 commit 6321d93

File tree

4 files changed

+359
-1
lines changed

4 files changed

+359
-1
lines changed

src/lib.rs

+30
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,33 @@ impl core::fmt::Display for TryReserveError {
266266
#[cfg(feature = "std")]
267267
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
268268
impl std::error::Error for TryReserveError {}
269+
270+
// NOTE: This is copied from the slice module in the std lib.
271+
/// The error type returned by [`get_disjoint_indices_mut`][`RingMap::get_disjoint_indices_mut`].
272+
///
273+
/// It indicates one of two possible errors:
274+
/// - An index is out-of-bounds.
275+
/// - The same index appeared multiple times in the array.
276+
// (or different but overlapping indices when ranges are provided)
277+
#[derive(Debug, Clone, PartialEq, Eq)]
278+
pub enum GetDisjointMutError {
279+
/// An index provided was out-of-bounds for the slice.
280+
IndexOutOfBounds,
281+
/// Two indices provided were overlapping.
282+
OverlappingIndices,
283+
}
284+
285+
impl core::fmt::Display for GetDisjointMutError {
286+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
287+
let msg = match self {
288+
GetDisjointMutError::IndexOutOfBounds => "an index is out of bounds",
289+
GetDisjointMutError::OverlappingIndices => "there were overlapping indices",
290+
};
291+
292+
core::fmt::Display::fmt(msg, f)
293+
}
294+
}
295+
296+
#[cfg(feature = "std")]
297+
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
298+
impl std::error::Error for GetDisjointMutError {}

src/map.rs

+48-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use std::collections::hash_map::RandomState;
4141

4242
use self::core::RingMapCore;
4343
use crate::util::third;
44-
use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError};
44+
use crate::{Bucket, Entries, Equivalent, GetDisjointMutError, HashValue, TryReserveError};
4545

4646
/// A hash table where the iteration order of the key-value pairs is independent
4747
/// of the hash values of the keys.
@@ -825,6 +825,33 @@ where
825825
}
826826
}
827827

828+
/// Return the values for `N` keys. If any key is duplicated, this function will panic.
829+
///
830+
/// # Examples
831+
///
832+
/// ```
833+
/// let mut map = ringmap::RingMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
834+
/// assert_eq!(map.get_disjoint_mut([&2, &1]), [Some(&mut 'c'), Some(&mut 'a')]);
835+
/// ```
836+
pub fn get_disjoint_mut<Q, const N: usize>(&mut self, keys: [&Q; N]) -> [Option<&mut V>; N]
837+
where
838+
Q: ?Sized + Hash + Equivalent<K>,
839+
{
840+
let indices = keys.map(|key| self.get_index_of(key));
841+
let (head, tail) = self.as_mut_slices();
842+
match Slice::get_disjoint_opt_mut(head, tail, indices) {
843+
Err(GetDisjointMutError::IndexOutOfBounds) => {
844+
unreachable!(
845+
"Internal error: indices should never be OOB as we got them from get_index_of"
846+
);
847+
}
848+
Err(GetDisjointMutError::OverlappingIndices) => {
849+
panic!("duplicate keys found");
850+
}
851+
Ok(key_values) => key_values.map(|kv_opt| kv_opt.map(|kv| kv.1)),
852+
}
853+
}
854+
828855
/// Remove the key-value pair equivalent to `key` and return its value.
829856
///
830857
/// Like [`VecDeque::remove`], the pair is removed by shifting all of the
@@ -1286,6 +1313,26 @@ impl<K, V, S> RingMap<K, V, S> {
12861313
Some(IndexedEntry::new(&mut self.core, index))
12871314
}
12881315

1316+
/// Get an array of `N` key-value pairs by `N` indices
1317+
///
1318+
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
1319+
///
1320+
/// # Examples
1321+
///
1322+
/// ```
1323+
/// let mut map = ringmap::RingMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
1324+
/// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Ok([(&2, &mut 'c'), (&1, &mut 'a')]));
1325+
/// ```
1326+
pub fn get_disjoint_indices_mut<const N: usize>(
1327+
&mut self,
1328+
indices: [usize; N],
1329+
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
1330+
let indices = indices.map(Some);
1331+
let (head, tail) = self.as_mut_slices();
1332+
let key_values = Slice::get_disjoint_opt_mut(head, tail, indices)?;
1333+
Ok(key_values.map(Option::unwrap))
1334+
}
1335+
12891336
/// Get the first key-value pair
12901337
///
12911338
/// Computes in **O(1)** time.

src/map/slice.rs

+56
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use super::{Bucket, IntoIter, IntoKeys, IntoValues, Iter, IterMut, Keys, Values, ValuesMut};
22
use crate::util::{slice_eq, try_simplify_range};
3+
use crate::GetDisjointMutError;
34

45
use alloc::boxed::Box;
56
use alloc::collections::VecDeque;
@@ -264,6 +265,61 @@ impl<K, V> Slice<K, V> {
264265
self.entries
265266
.partition_point(move |a| pred(&a.key, &a.value))
266267
}
268+
269+
/// Get an array of `N` key-value pairs by `N` indices
270+
///
271+
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
272+
pub fn get_disjoint_mut<const N: usize>(
273+
&mut self,
274+
indices: [usize; N],
275+
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
276+
let indices = indices.map(Some);
277+
let empty_tail = Self::new_mut();
278+
let key_values = Self::get_disjoint_opt_mut(self, empty_tail, indices)?;
279+
Ok(key_values.map(Option::unwrap))
280+
}
281+
282+
#[allow(unsafe_code)]
283+
pub(crate) fn get_disjoint_opt_mut<'a, const N: usize>(
284+
head: &mut Self,
285+
tail: &mut Self,
286+
indices: [Option<usize>; N],
287+
) -> Result<[Option<(&'a K, &'a mut V)>; N], GetDisjointMutError> {
288+
let mid = head.len();
289+
let len = mid + tail.len();
290+
291+
// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data.
292+
for i in 0..N {
293+
if let Some(idx) = indices[i] {
294+
if idx >= len {
295+
return Err(GetDisjointMutError::IndexOutOfBounds);
296+
} else if indices[..i].contains(&Some(idx)) {
297+
return Err(GetDisjointMutError::OverlappingIndices);
298+
}
299+
}
300+
}
301+
302+
let head_ptr = head.entries.as_mut_ptr();
303+
let tail_ptr = tail.entries.as_mut_ptr();
304+
let out = indices.map(|idx_opt| {
305+
match idx_opt {
306+
Some(idx) => {
307+
// SAFETY: The base pointers are valid as they come from slices and the reference is always
308+
// in-bounds & unique as we've already checked the indices above.
309+
unsafe {
310+
let ptr = match idx.checked_sub(mid) {
311+
None => head_ptr.add(idx),
312+
Some(tidx) => tail_ptr.add(tidx),
313+
};
314+
Some((*ptr).ref_mut())
315+
}
316+
}
317+
None => None,
318+
}
319+
});
320+
321+
Ok(out)
322+
}
267323
}
268324

269325
impl<'a, K, V> IntoIterator for &'a Slice<K, V> {

0 commit comments

Comments
 (0)