Skip to content

Commit 11017c9

Browse files
committed
Address review feedback
1 parent c6f35fd commit 11017c9

File tree

4 files changed

+205
-66
lines changed

4 files changed

+205
-66
lines changed

src/lib.rs

+30
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,33 @@ impl core::fmt::Display for TryReserveError {
268268
#[cfg(feature = "std")]
269269
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
270270
impl std::error::Error for TryReserveError {}
271+
272+
// NOTE: This is copied from the slice module in the std lib.
273+
/// The error type returned by [`get_disjoint_indices_mut`][`IndexMap::get_disjoint_indices_mut`].
274+
///
275+
/// It indicates one of two possible errors:
276+
/// - An index is out-of-bounds.
277+
/// - The same index appeared multiple times in the array
278+
/// (or different but overlapping indices when ranges are provided).
279+
#[derive(Debug, Clone, PartialEq, Eq)]
280+
pub enum GetDisjointMutError {
281+
/// An index provided was out-of-bounds for the slice.
282+
IndexOutOfBounds,
283+
/// Two indices provided were overlapping.
284+
OverlappingIndices,
285+
}
286+
287+
impl core::fmt::Display for GetDisjointMutError {
288+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
289+
let msg = match self {
290+
GetDisjointMutError::IndexOutOfBounds => "an index is out of bounds",
291+
GetDisjointMutError::OverlappingIndices => "there were overlapping indices",
292+
};
293+
294+
core::fmt::Display::fmt(msg, f)
295+
}
296+
}
297+
298+
#[cfg(feature = "std")]
299+
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
300+
impl std::error::Error for GetDisjointMutError {}

src/map.rs

+18-43
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use std::collections::hash_map::RandomState;
3838

3939
use self::core::IndexMapCore;
4040
use crate::util::{third, try_simplify_range};
41-
use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError};
41+
use crate::{Bucket, Entries, Equivalent, GetDisjointMutError, HashValue, TryReserveError};
4242

4343
/// A hash table where the iteration order of the key-value pairs is independent
4444
/// of the hash values of the keys.
@@ -790,35 +790,31 @@ where
790790
}
791791
}
792792

793-
/// Return the values for `N` keys. If any key is missing a value, or there
794-
/// are duplicate keys, `None` is returned.
793+
/// Return the values for `N` keys. If any key is duplicated, this function will panic.
795794
///
796795
/// # Examples
797796
///
798797
/// ```
799798
/// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
800-
/// assert_eq!(map.get_disjoint_mut([&2, &1]), Some([&mut 'c', &mut 'a']));
799+
/// assert_eq!(map.get_disjoint_mut([&2, &1]), [Some(&mut 'c'), Some(&mut 'a')]);
801800
/// ```
802-
pub fn get_disjoint_mut<Q, const N: usize>(&mut self, keys: [&Q; N]) -> Option<[&mut V; N]>
801+
#[allow(unsafe_code)]
802+
pub fn get_disjoint_mut<Q, const N: usize>(&mut self, keys: [&Q; N]) -> [Option<&mut V>; N]
803803
where
804804
Q: Hash + Equivalent<K> + ?Sized,
805805
{
806-
let len = self.len();
807806
let indices = keys.map(|key| self.get_index_of(key));
808-
809-
// Handle out-of-bounds indices with panic as this is an internal error in get_index_of.
810-
for idx in indices {
811-
let idx = idx?;
812-
debug_assert!(
813-
idx < len,
814-
"Index is out of range! Got '{}' but length is '{}'",
815-
idx,
816-
len
817-
);
807+
match self.as_mut_slice().get_disjoint_opt_mut(indices) {
808+
Err(GetDisjointMutError::IndexOutOfBounds) => {
809+
unreachable!(
810+
"Internal error: indices should never be OOB as we got them from get_index_of"
811+
);
812+
}
813+
Err(GetDisjointMutError::OverlappingIndices) => {
814+
panic!("duplicate keys found");
815+
}
816+
Ok(key_values) => key_values.map(|kv_opt| kv_opt.map(|kv| kv.1)),
818817
}
819-
let indices = indices.map(Option::unwrap);
820-
let entries = self.get_disjoint_indices_mut(indices)?;
821-
Some(entries.map(|(_key, value)| value))
822818
}
823819

824820
/// Remove the key-value pair equivalent to `key` and return
@@ -1231,38 +1227,17 @@ impl<K, V, S> IndexMap<K, V, S> {
12311227
///
12321228
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
12331229
///
1234-
/// Computes in **O(1)** time.
1235-
///
12361230
/// # Examples
12371231
///
12381232
/// ```
12391233
/// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
1240-
/// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Some([(&2, &mut 'c'), (&1, &mut 'a')]));
1234+
/// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Ok([(&2, &mut 'c'), (&1, &mut 'a')]));
12411235
/// ```
12421236
pub fn get_disjoint_indices_mut<const N: usize>(
12431237
&mut self,
12441238
indices: [usize; N],
1245-
) -> Option<[(&K, &mut V); N]> {
1246-
// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data.
1247-
let len = self.len();
1248-
for i in 0..N {
1249-
let idx = indices[i];
1250-
if idx >= len || indices[i + 1..N].contains(&idx) {
1251-
return None;
1252-
}
1253-
}
1254-
1255-
let entries_ptr = self.as_entries_mut().as_mut_ptr();
1256-
let out = indices.map(|i| {
1257-
// SAFETY: The base pointer is valid as it comes from a slice and the deref is always
1258-
// in-bounds as we've already checked the indices above.
1259-
#[allow(unsafe_code)]
1260-
unsafe {
1261-
(*(entries_ptr.add(i))).ref_mut()
1262-
}
1263-
});
1264-
1265-
Some(out)
1239+
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
1240+
self.as_mut_slice().get_disjoint_mut(indices)
12661241
}
12671242

12681243
/// Returns a slice of key-value pairs in the given range of indices.

src/map/slice.rs

+47
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use super::{
33
ValuesMut,
44
};
55
use crate::util::try_simplify_range;
6+
use crate::GetDisjointMutError;
67

78
use alloc::boxed::Box;
89
use alloc::vec::Vec;
@@ -270,6 +271,52 @@ impl<K, V> Slice<K, V> {
270271
self.entries
271272
.partition_point(move |a| pred(&a.key, &a.value))
272273
}
274+
275+
/// Get an array of `N` key-value pairs by `N` indices
276+
///
277+
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
278+
pub fn get_disjoint_mut<const N: usize>(
279+
&mut self,
280+
indices: [usize; N],
281+
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
282+
let indices = indices.map(Some);
283+
let key_values = self.get_disjoint_opt_mut(indices)?;
284+
Ok(key_values.map(Option::unwrap))
285+
}
286+
287+
#[allow(unsafe_code)]
288+
pub(crate) fn get_disjoint_opt_mut<const N: usize>(
289+
&mut self,
290+
indices: [Option<usize>; N],
291+
) -> Result<[Option<(&K, &mut V)>; N], GetDisjointMutError> {
292+
// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data.
293+
let len = self.len();
294+
for i in 0..N {
295+
let Some(idx) = indices[i] else {
296+
continue;
297+
};
298+
if idx >= len {
299+
return Err(GetDisjointMutError::IndexOutOfBounds);
300+
} else if indices[i + 1..N].contains(&Some(idx)) {
301+
return Err(GetDisjointMutError::OverlappingIndices);
302+
}
303+
}
304+
305+
let entries_ptr = self.entries.as_mut_ptr();
306+
let out = indices.map(|idx_opt| {
307+
match idx_opt {
308+
Some(idx) => {
309+
// SAFETY: The base pointer is valid as it comes from a slice and the reference is always
310+
// in-bounds & unique as we've already checked the indices above.
311+
let kv = unsafe { (*(entries_ptr.add(idx))).ref_mut() };
312+
Some(kv)
313+
}
314+
None => None,
315+
}
316+
});
317+
318+
Ok(out)
319+
}
273320
}
274321

275322
impl<'a, K, V> IntoIterator for &'a Slice<K, V> {

src/map/tests.rs

+110-23
Original file line numberDiff line numberDiff line change
@@ -832,28 +832,31 @@ move_index_oob!(test_move_index_out_of_bounds_max_0, usize::MAX, 0);
832832
#[test]
833833
fn disjoint_mut_empty_map() {
834834
let mut map: IndexMap<u32, u32> = IndexMap::default();
835-
assert!(map.get_disjoint_mut([&0, &1, &2, &3]).is_none());
835+
assert_eq!(
836+
map.get_disjoint_mut([&0, &1, &2, &3]),
837+
[None, None, None, None]
838+
);
836839
}
837840

838841
#[test]
839842
fn disjoint_mut_empty_param() {
840843
let mut map: IndexMap<u32, u32> = IndexMap::default();
841844
map.insert(1, 10);
842-
assert!(map.get_disjoint_mut([] as [&u32; 0]).is_some());
845+
assert_eq!(map.get_disjoint_mut([] as [&u32; 0]), []);
843846
}
844847

845848
#[test]
846849
fn disjoint_mut_single_fail() {
847850
let mut map: IndexMap<u32, u32> = IndexMap::default();
848851
map.insert(1, 10);
849-
assert!(map.get_disjoint_mut([&0]).is_none());
852+
assert_eq!(map.get_disjoint_mut([&0]), [None]);
850853
}
851854

852855
#[test]
853856
fn disjoint_mut_single_success() {
854857
let mut map: IndexMap<u32, u32> = IndexMap::default();
855858
map.insert(1, 10);
856-
assert_eq!(map.get_disjoint_mut([&1]), Some([&mut 10]));
859+
assert_eq!(map.get_disjoint_mut([&1]), [Some(&mut 10)]);
857860
}
858861

859862
#[test]
@@ -863,11 +866,22 @@ fn disjoint_mut_multi_success() {
863866
map.insert(2, 200);
864867
map.insert(3, 300);
865868
map.insert(4, 400);
866-
assert_eq!(map.get_disjoint_mut([&1, &2]), Some([&mut 100, &mut 200]));
867-
assert_eq!(map.get_disjoint_mut([&1, &3]), Some([&mut 100, &mut 300]));
869+
assert_eq!(
870+
map.get_disjoint_mut([&1, &2]),
871+
[Some(&mut 100), Some(&mut 200)]
872+
);
873+
assert_eq!(
874+
map.get_disjoint_mut([&1, &3]),
875+
[Some(&mut 100), Some(&mut 300)]
876+
);
868877
assert_eq!(
869878
map.get_disjoint_mut([&3, &1, &4, &2]),
870-
Some([&mut 300, &mut 100, &mut 400, &mut 200])
879+
[
880+
Some(&mut 300),
881+
Some(&mut 100),
882+
Some(&mut 400),
883+
Some(&mut 200)
884+
]
871885
);
872886
}
873887

@@ -878,44 +892,117 @@ fn disjoint_mut_multi_success_unsized_key() {
878892
map.insert("2", 200);
879893
map.insert("3", 300);
880894
map.insert("4", 400);
881-
assert_eq!(map.get_disjoint_mut(["1", "2"]), Some([&mut 100, &mut 200]));
882-
assert_eq!(map.get_disjoint_mut(["1", "3"]), Some([&mut 100, &mut 300]));
895+
896+
assert_eq!(
897+
map.get_disjoint_mut(["1", "2"]),
898+
[Some(&mut 100), Some(&mut 200)]
899+
);
900+
assert_eq!(
901+
map.get_disjoint_mut(["1", "3"]),
902+
[Some(&mut 100), Some(&mut 300)]
903+
);
883904
assert_eq!(
884905
map.get_disjoint_mut(["3", "1", "4", "2"]),
885-
Some([&mut 300, &mut 100, &mut 400, &mut 200])
906+
[
907+
Some(&mut 300),
908+
Some(&mut 100),
909+
Some(&mut 400),
910+
Some(&mut 200)
911+
]
912+
);
913+
}
914+
915+
#[test]
916+
fn disjoint_mut_multi_success_borrow_key() {
917+
let mut map: IndexMap<String, u32> = IndexMap::default();
918+
map.insert("1".into(), 100);
919+
map.insert("2".into(), 200);
920+
map.insert("3".into(), 300);
921+
map.insert("4".into(), 400);
922+
923+
assert_eq!(
924+
map.get_disjoint_mut(["1", "2"]),
925+
[Some(&mut 100), Some(&mut 200)]
926+
);
927+
assert_eq!(
928+
map.get_disjoint_mut(["1", "3"]),
929+
[Some(&mut 100), Some(&mut 300)]
930+
);
931+
assert_eq!(
932+
map.get_disjoint_mut(["3", "1", "4", "2"]),
933+
[
934+
Some(&mut 300),
935+
Some(&mut 100),
936+
Some(&mut 400),
937+
Some(&mut 200)
938+
]
886939
);
887940
}
888941

889942
#[test]
890943
fn disjoint_mut_multi_fail_missing() {
944+
let mut map: IndexMap<u32, u32> = IndexMap::default();
945+
map.insert(1, 100);
946+
map.insert(2, 200);
947+
map.insert(3, 300);
948+
map.insert(4, 400);
949+
950+
assert_eq!(map.get_disjoint_mut([&1, &5]), [Some(&mut 100), None]);
951+
assert_eq!(map.get_disjoint_mut([&5, &6]), [None, None]);
952+
assert_eq!(
953+
map.get_disjoint_mut([&1, &5, &4]),
954+
[Some(&mut 100), None, Some(&mut 400)]
955+
);
956+
}
957+
958+
#[test]
959+
#[should_panic]
960+
fn disjoint_mut_multi_fail_duplicate_panic() {
961+
let mut map: IndexMap<u32, u32> = IndexMap::default();
962+
map.insert(1, 100);
963+
map.get_disjoint_mut([&1, &2, &1]);
964+
}
965+
966+
#[test]
967+
fn disjoint_indices_mut_fail_oob() {
968+
let mut map: IndexMap<u32, u32> = IndexMap::default();
969+
map.insert(1, 10);
970+
map.insert(321, 20);
971+
assert_eq!(
972+
map.get_disjoint_indices_mut([1, 3]),
973+
Err(crate::GetDisjointMutError::IndexOutOfBounds)
974+
);
975+
}
976+
977+
#[test]
978+
fn disjoint_indices_mut_empty() {
891979
let mut map: IndexMap<u32, u32> = IndexMap::default();
892980
map.insert(1, 10);
893-
map.insert(1123, 100);
894981
map.insert(321, 20);
895-
map.insert(1337, 30);
896-
assert_eq!(map.get_disjoint_mut([&121, &1123]), None);
897-
assert_eq!(map.get_disjoint_mut([&1, &1337, &56]), None);
898-
assert_eq!(map.get_disjoint_mut([&1337, &123, &321, &1, &1123]), None);
982+
assert_eq!(map.get_disjoint_indices_mut([]), Ok([]));
899983
}
900984

901985
#[test]
902-
fn disjoint_mut_multi_fail_duplicate() {
986+
fn disjoint_indices_mut_success() {
903987
let mut map: IndexMap<u32, u32> = IndexMap::default();
904988
map.insert(1, 10);
905-
map.insert(1123, 100);
906989
map.insert(321, 20);
907-
map.insert(1337, 30);
908-
assert_eq!(map.get_disjoint_mut([&1, &1]), None);
990+
assert_eq!(map.get_disjoint_indices_mut([0]), Ok([(&1, &mut 10)]));
991+
992+
assert_eq!(map.get_disjoint_indices_mut([1]), Ok([(&321, &mut 20)]));
909993
assert_eq!(
910-
map.get_disjoint_mut([&1337, &123, &321, &1337, &1, &1123]),
911-
None
994+
map.get_disjoint_indices_mut([0, 1]),
995+
Ok([(&1, &mut 10), (&321, &mut 20)])
912996
);
913997
}
914998

915999
#[test]
916-
fn many_index_mut_fail_oob() {
1000+
fn disjoint_indices_mut_fail_duplicate() {
9171001
let mut map: IndexMap<u32, u32> = IndexMap::default();
9181002
map.insert(1, 10);
9191003
map.insert(321, 20);
920-
assert_eq!(map.get_disjoint_indices_mut([1, 3]), None);
1004+
assert_eq!(
1005+
map.get_disjoint_indices_mut([1, 2, 1]),
1006+
Err(crate::GetDisjointMutError::OverlappingIndices)
1007+
);
9211008
}

0 commit comments

Comments
 (0)