|
| 1 | +// Copyright (c) 2023-2024 CMU Database Group |
| 2 | +// |
| 3 | +// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at |
| 4 | +// https://opensource.org/licenses/MIT. |
| 5 | + |
| 6 | +//! optd's implementation of disjoint sets (a union-find structure). It is `Send` + `Sync` + serializable. |
| 7 | +
|
| 8 | +use std::{collections::HashMap, hash::Hash}; |
/// A disjoint-set (union-find) forest over elements of type `T`.
///
/// Each element is given a dense slot index when it is first inserted with
/// [`DisjointSets::make_set`]. Sets are merged with [`DisjointSets::union`],
/// which links the smaller tree under the larger one, so the height of every
/// tree stays logarithmic in the number of elements.
#[derive(Clone)]
pub struct DisjointSets<T> {
    /// Maps every known element to its slot in `parents` and `sizes`.
    data_idx: HashMap<T, usize>,
    /// `parents[i]` is the parent slot of slot `i`; a root is its own parent.
    parents: Vec<usize>,
    /// `sizes[r]` is the number of elements in the set rooted at `r`.
    /// Entries of non-root slots are stale and must not be read.
    sizes: Vec<usize>,
}

impl<T> Default for DisjointSets<T> {
    /// Creates an empty collection without requiring `T: Default`.
    fn default() -> Self {
        Self {
            data_idx: HashMap::new(),
            parents: Vec::new(),
            sizes: Vec::new(),
        }
    }
}

impl<T> std::fmt::Debug for DisjointSets<T> {
    /// Prints only the type name, so `T` does not have to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("DisjointSets")
    }
}

impl<T: Eq + Hash> DisjointSets<T> {
    /// Creates an empty collection of disjoint sets.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if `data` has been registered with [`Self::make_set`].
    pub fn contains(&self, data: &T) -> bool {
        self.data_idx.contains_key(data)
    }

    /// Registers `data` as a new singleton set.
    ///
    /// Returns `Some(())` on success and `None` if `data` is already present,
    /// in which case the collection is left unchanged.
    #[must_use]
    pub fn make_set(&mut self, data: T) -> Option<()> {
        let idx = self.parents.len();
        // The entry API hashes the key once and lets us move `data` into the
        // map instead of cloning it.
        match self.data_idx.entry(data) {
            std::collections::hash_map::Entry::Occupied(_) => None,
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(idx);
                self.parents.push(idx);
                self.sizes.push(1);
                Some(())
            }
        }
    }

    /// Returns the root of the tree containing `idx`, shortening the path on
    /// the way up by pointing every visited node at its grandparent.
    fn find(&mut self, mut idx: usize) -> usize {
        while self.parents[idx] != idx {
            let grandparent = self.parents[self.parents[idx]];
            self.parents[idx] = grandparent;
            idx = grandparent;
        }
        idx
    }

    /// Returns the root of the tree containing `idx` without modifying the
    /// forest, so it can be used from `&self` methods.
    fn find_const(&self, mut idx: usize) -> usize {
        while self.parents[idx] != idx {
            idx = self.parents[idx];
        }
        idx
    }

    /// Merges the sets containing `data1` and `data2`.
    ///
    /// Returns `None` if either element is unknown; otherwise `Some(())`,
    /// including when both elements were already in the same set.
    #[must_use]
    pub fn union(&mut self, data1: &T, data2: &T) -> Option<()> {
        let idx1 = *self.data_idx.get(data1)?;
        let idx2 = *self.data_idx.get(data2)?;
        let root1 = self.find(idx1);
        let root2 = self.find(idx2);
        if root1 != root2 {
            // Union by size: hang the smaller tree under the larger one so
            // the forest cannot degenerate into a long chain.
            let (small, large) = if self.sizes[root1] <= self.sizes[root2] {
                (root1, root2)
            } else {
                (root2, root1)
            };
            self.parents[small] = large;
            self.sizes[large] += self.sizes[small];
        }
        Some(())
    }

    /// Returns whether `data1` and `data2` belong to the same set, or `None`
    /// if either element is unknown.
    pub fn same_set(&self, data1: &T, data2: &T) -> Option<bool> {
        let idx1 = *self.data_idx.get(data1)?;
        let idx2 = *self.data_idx.get(data2)?;
        Some(self.find_const(idx1) == self.find_const(idx2))
    }

    /// Returns the number of elements in the set containing `data`, or `None`
    /// if `data` is unknown.
    pub fn set_size(&self, data: &T) -> Option<usize> {
        let idx = *self.data_idx.get(data)?;
        // Sizes are maintained on the roots, so no scan over all elements is
        // needed.
        Some(self.sizes[self.find_const(idx)])
    }
}
| 88 | + |
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks through a sequence of unions and checks membership and sizes
    /// after every step.
    #[test]
    fn test_union_find() {
        let mut set = DisjointSets::new();
        set.make_set("a").unwrap();
        set.make_set("b").unwrap();
        set.make_set("c").unwrap();
        set.make_set("d").unwrap();
        set.make_set("e").unwrap();
        assert!(set.same_set(&"a", &"a").unwrap());
        assert!(!set.same_set(&"a", &"b").unwrap());
        assert_eq!(set.set_size(&"a").unwrap(), 1);
        assert_eq!(set.set_size(&"c").unwrap(), 1);
        set.union(&"a", &"b").unwrap();
        assert_eq!(set.set_size(&"a").unwrap(), 2);
        assert_eq!(set.set_size(&"c").unwrap(), 1);
        assert!(set.same_set(&"a", &"b").unwrap());
        assert!(!set.same_set(&"a", &"c").unwrap());
        set.union(&"b", &"c").unwrap();
        assert!(set.same_set(&"a", &"c").unwrap());
        assert!(!set.same_set(&"a", &"d").unwrap());
        assert_eq!(set.set_size(&"a").unwrap(), 3);
        assert_eq!(set.set_size(&"d").unwrap(), 1);
        set.union(&"d", &"e").unwrap();
        assert!(set.same_set(&"d", &"e").unwrap());
        assert!(!set.same_set(&"a", &"d").unwrap());
        assert_eq!(set.set_size(&"a").unwrap(), 3);
        assert_eq!(set.set_size(&"d").unwrap(), 2);
        set.union(&"c", &"e").unwrap();
        assert!(set.same_set(&"a", &"e").unwrap());
        assert_eq!(set.set_size(&"d").unwrap(), 5);
    }

    /// Inserting an element twice must be rejected and must not disturb the
    /// existing set.
    #[test]
    fn test_make_set_rejects_duplicates() {
        let mut set = DisjointSets::new();
        assert!(!set.contains(&"a"));
        assert!(set.make_set("a").is_some());
        assert!(set.contains(&"a"));
        assert!(set.make_set("a").is_none());
        assert_eq!(set.set_size(&"a"), Some(1));
    }

    /// Every query that mentions an unregistered element reports `None`
    /// instead of panicking.
    #[test]
    fn test_unknown_elements_yield_none() {
        let mut set = DisjointSets::new();
        set.make_set("a").unwrap();
        assert!(set.union(&"a", &"z").is_none());
        assert!(set.union(&"z", &"a").is_none());
        assert_eq!(set.same_set(&"a", &"z"), None);
        assert_eq!(set.set_size(&"z"), None);
        // The failed unions must leave the known element untouched.
        assert_eq!(set.set_size(&"a"), Some(1));
    }
}
0 commit comments