diff --git a/library/alloc/src/collections/btree/set.rs b/library/alloc/src/collections/btree/set.rs index 973e7c660670c..98ff028c6f78f 100644 --- a/library/alloc/src/collections/btree/set.rs +++ b/library/alloc/src/collections/btree/set.rs @@ -1545,6 +1545,37 @@ impl Sub<&BTreeSet> for &BTreeSet Sub<&BTreeSet> for BTreeSet { + type Output = BTreeSet; + + /// Returns the difference of `self` and `rhs` as a new `BTreeSet`. + /// + /// # Examples + /// + /// ``` + /// use std::collections::BTreeSet; + /// + /// let a = BTreeSet::from([1, 2, 3]); + /// let b = BTreeSet::from([3, 4, 5]); + /// + /// let result = a - &b; + /// assert_eq!(result, BTreeSet::from([1, 2])); + /// ``` + fn sub(mut self, rhs: &BTreeSet) -> BTreeSet { + // Iterate the smaller set, removing elements that are in `rhs` from `self` + if self.len() <= rhs.len() { + self.retain(|e| !rhs.contains(e)); + } else { + rhs.iter().for_each(|e| { + self.remove(e); + }) + } + + self + } +} + #[stable(feature = "rust1", since = "1.0.0")] impl BitXor<&BTreeSet> for &BTreeSet { type Output = BTreeSet; @@ -1570,6 +1601,37 @@ impl BitXor<&BTreeSet> for &BTreeSet } } +#[stable(feature = "set_owned_ops", since = "CURRENT_RUSTC_VERSION")] +impl BitXor> for BTreeSet { + type Output = BTreeSet; + + /// Returns the symmetric difference of `self` and `rhs` as a new `BTreeSet`. + /// + /// # Examples + /// + /// ``` + /// use std::collections::BTreeSet; + /// + /// let a = BTreeSet::from([1, 2, 3]); + /// let b = BTreeSet::from([2, 3, 4]); + /// + /// let result = a ^ b; + /// assert_eq!(result, BTreeSet::from([1, 4])); + /// ``` + fn bitxor(self, rhs: BTreeSet) -> BTreeSet { + // Iterate through the smaller set + let [mut a, mut b] = minmax_by_key(self, rhs, BTreeSet::len); + + // This is essentially + // a = a - b (retain elements that are *not* in b) + // b = b - a (remove all elements that are in a) + a.retain(|e| !b.remove(e)); + + // Union of the differences + a | b + } +} + #[stable(feature = "rust1", since = "1.0.0")] impl BitAnd<&BTreeSet> for &BTreeSet { type Output = BTreeSet; @@ -1595,6 +1657,29 @@ impl BitAnd<&BTreeSet> for &BTreeSet } } +#[stable(feature = "set_owned_ops", since = "CURRENT_RUSTC_VERSION")] +impl BitAnd<&BTreeSet> for BTreeSet { + type Output = BTreeSet; + + /// Returns the intersection of `self` and `rhs` as a new `BTreeSet`. + /// + /// # Examples + /// + /// ``` + /// use std::collections::BTreeSet; + /// + /// let a = BTreeSet::from([1, 2, 3]); + /// let b = BTreeSet::from([2, 3, 4]); + /// + /// let result = a & &b; + /// assert_eq!(result, BTreeSet::from([2, 3])); + /// ``` + fn bitand(mut self, rhs: &BTreeSet) -> BTreeSet { + self.retain(|e| rhs.contains(e)); + self + } +} + #[stable(feature = "rust1", since = "1.0.0")] impl BitOr<&BTreeSet> for &BTreeSet { type Output = BTreeSet; @@ -1620,6 +1705,33 @@ impl BitOr<&BTreeSet> for &BTreeSet< } } +#[stable(feature = "set_owned_ops", since = "CURRENT_RUSTC_VERSION")] +impl BitOr> for BTreeSet { + type Output = BTreeSet; + + /// Returns the union of `self` and `rhs` as a new `BTreeSet`. + /// + /// # Examples + /// + /// ``` + /// use std::collections::BTreeSet; + /// + /// let a = BTreeSet::from([1, 2, 3]); + /// let b = BTreeSet::from([3, 4, 5]); + /// + /// let result = a | b; + /// assert_eq!(result, BTreeSet::from([1, 2, 3, 4, 5])); + /// ``` + fn bitor(self, rhs: BTreeSet) -> BTreeSet { + // Try to avoid unnecessary moves, by keeping set with the bigger length + let [a, mut b] = minmax_by_key(self, rhs, BTreeSet::len); + + b.extend(a); + + b + } +} + #[stable(feature = "rust1", since = "1.0.0")] impl Debug for BTreeSet { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -2397,5 +2509,9 @@ impl<'a, T: Ord, A: Allocator + Clone> CursorMutKey<'a, T, A> { #[unstable(feature = "btree_cursors", issue = "107540")] pub use super::map::UnorderedKeyError; +fn minmax_by_key(a: T, b: T, k: impl Fn(&T) -> K) -> [T; 2] { + if k(&a) <= k(&b) { [a, b] } else { [b, a] } +} + #[cfg(test)] mod tests; diff --git a/library/std/src/collections/hash/set.rs b/library/std/src/collections/hash/set.rs index d611353b0d3f2..a3b9188261c38 100644 --- a/library/std/src/collections/hash/set.rs +++ b/library/std/src/collections/hash/set.rs @@ -1139,21 +1139,46 @@ where /// let a = HashSet::from([1, 2, 3]); /// let b = HashSet::from([3, 4, 5]); /// - /// let set = &a | &b; - /// - /// let mut i = 0; - /// let expected = [1, 2, 3, 4, 5]; - /// for x in &set { - /// assert!(expected.contains(x)); - /// i += 1; - /// } - /// assert_eq!(i, expected.len()); + /// let result = &a | &b; + /// assert_eq!(result, HashSet::from([1, 2, 3, 4, 5])); /// ``` fn bitor(self, rhs: &HashSet) -> HashSet { self.union(rhs).cloned().collect() } } +#[stable(feature = "set_owned_ops", since = "CURRENT_RUSTC_VERSION")] +impl BitOr> for HashSet +where + T: Eq + Hash, + S: BuildHasher, +{ + type Output = HashSet; + + /// Returns the union of `self` and `rhs` as a new `HashSet`. + /// + /// # Examples + /// + /// ``` + /// use std::collections::HashSet; + /// + /// let a = HashSet::from([1, 2, 3]); + /// let b = HashSet::from([3, 4, 5]); + /// + /// let result = a | b; + /// assert_eq!(result, HashSet::from([1, 2, 3, 4, 5])); + /// ``` + fn bitor(self, rhs: HashSet) -> HashSet { + // Try to avoid allocations by keeping set with the bigger capacity, + // try to avoid unnecessary moves, by keeping set with the bigger length + let [a, mut b] = minmax_by_key(self, rhs, |set| (set.capacity(), set.len())); + + b.extend(a); + + b + } +} + #[stable(feature = "rust1", since = "1.0.0")] impl BitAnd<&HashSet> for &HashSet where @@ -1172,21 +1197,41 @@ where /// let a = HashSet::from([1, 2, 3]); /// let b = HashSet::from([2, 3, 4]); /// - /// let set = &a & &b; - /// - /// let mut i = 0; - /// let expected = [2, 3]; - /// for x in &set { - /// assert!(expected.contains(x)); - /// i += 1; - /// } - /// assert_eq!(i, expected.len()); + /// let result = &a & &b; + /// assert_eq!(result, HashSet::from([2, 3])); /// ``` fn bitand(self, rhs: &HashSet) -> HashSet { self.intersection(rhs).cloned().collect() } } +#[stable(feature = "set_owned_ops", since = "CURRENT_RUSTC_VERSION")] +impl BitAnd<&HashSet> for HashSet +where + T: Eq + Hash, + S: BuildHasher, +{ + type Output = HashSet; + + /// Returns the intersection of `self` and `rhs` as a new `HashSet`. + /// + /// # Examples + /// + /// ``` + /// use std::collections::HashSet; + /// + /// let a = HashSet::from([1, 2, 3]); + /// let b = HashSet::from([2, 3, 4]); + /// + /// let result = a & &b; + /// assert_eq!(result, HashSet::from([2, 3])); + /// ``` + fn bitand(mut self, rhs: &HashSet) -> HashSet { + self.retain(|e| rhs.contains(e)); + self + } +} + #[stable(feature = "rust1", since = "1.0.0")] impl BitXor<&HashSet> for &HashSet where @@ -1205,21 +1250,49 @@ where /// let a = HashSet::from([1, 2, 3]); /// let b = HashSet::from([3, 4, 5]); /// - /// let set = &a ^ &b; - /// - /// let mut i = 0; - /// let expected = [1, 2, 4, 5]; - /// for x in &set { - /// assert!(expected.contains(x)); - /// i += 1; - /// } - /// assert_eq!(i, expected.len()); + /// let result = &a ^ &b; + /// assert_eq!(result, HashSet::from([1, 2, 4, 5])); /// ``` fn bitxor(self, rhs: &HashSet) -> HashSet { self.symmetric_difference(rhs).cloned().collect() } } +#[stable(feature = "set_owned_ops", since = "CURRENT_RUSTC_VERSION")] +impl BitXor> for HashSet +where + T: Eq + Hash, + S: BuildHasher, +{ + type Output = HashSet; + + /// Returns the symmetric difference of `self` and `rhs` as a new `HashSet`. + /// + /// # Examples + /// + /// ``` + /// use std::collections::HashSet; + /// + /// let a = HashSet::from([1, 2, 3]); + /// let b = HashSet::from([3, 4, 5]); + /// + /// let result = a ^ b; + /// assert_eq!(result, HashSet::from([1, 2, 4, 5])); + /// ``` + fn bitxor(self, rhs: HashSet) -> HashSet { + // Iterate through the smaller set + let [mut a, mut b] = minmax_by_key(self, rhs, HashSet::len); + + // This is essentially + // a = a - b (retain elements that are *not* in b) + // b = b - a (remove all elements that are in a) + a.retain(|e| !b.remove(e)); + + // Union of the differences + a | b + } +} + #[stable(feature = "rust1", since = "1.0.0")] impl Sub<&HashSet> for &HashSet where @@ -1238,21 +1311,49 @@ where /// let a = HashSet::from([1, 2, 3]); /// let b = HashSet::from([3, 4, 5]); /// - /// let set = &a - &b; - /// - /// let mut i = 0; - /// let expected = [1, 2]; - /// for x in &set { - /// assert!(expected.contains(x)); - /// i += 1; - /// } - /// assert_eq!(i, expected.len()); + /// let result = &a - &b; + /// assert_eq!(result, HashSet::from([1, 2])); /// ``` fn sub(self, rhs: &HashSet) -> HashSet { self.difference(rhs).cloned().collect() } } +#[stable(feature = "set_owned_ops", since = "CURRENT_RUSTC_VERSION")] +impl Sub<&HashSet> for HashSet +where + T: Eq + Hash, + S: BuildHasher, +{ + type Output = HashSet; + + /// Returns the difference of `self` and `rhs` as a new `HashSet`. + /// + /// # Examples + /// + /// ``` + /// use std::collections::HashSet; + /// + /// let a = HashSet::from([1, 2, 3]); + /// let b = HashSet::from([3, 4, 5]); + /// + /// let result = a - &b; + /// assert_eq!(result, HashSet::from([1, 2])); + /// ``` + fn sub(mut self, rhs: &HashSet) -> HashSet { + // Iterate the smaller set, removing elements that are in `rhs` from `self` + if self.len() <= rhs.len() { + self.retain(|e| !rhs.contains(e)); + } else { + rhs.iter().for_each(|e| { + self.remove(e); + }) + } + + self + } +} + /// An iterator over the items of a `HashSet`. /// /// This `struct` is created by the [`iter`] method on [`HashSet`]. @@ -1913,3 +2014,7 @@ fn assert_covariance() { d } } + +fn minmax_by_key(a: T, b: T, k: impl Fn(&T) -> K) -> [T; 2] { + if k(&a) <= k(&b) { [a, b] } else { [b, a] } +}