diff --git a/library/core/src/iter/adapters/zip.rs b/library/core/src/iter/adapters/zip.rs index 8153c8cfef133..fea3b854a5814 100644 --- a/library/core/src/iter/adapters/zip.rs +++ b/library/core/src/iter/adapters/zip.rs @@ -2,6 +2,7 @@ use crate::cmp; use crate::fmt::{self, Debug}; use crate::iter::{DoubleEndedIterator, ExactSizeIterator, FusedIterator, Iterator}; use crate::iter::{InPlaceIterable, SourceIter, TrustedLen}; +use crate::ops::{ControlFlow, NeverShortCircuit, Try}; /// An iterator that iterates two other iterators simultaneously. /// @@ -31,6 +32,23 @@ impl Zip { } None } + #[inline] + fn adjust_back(&mut self) + where + A: DoubleEndedIterator + ExactSizeIterator, + B: DoubleEndedIterator + ExactSizeIterator, + { + let a_sz = self.a.len(); + let b_sz = self.b.len(); + if a_sz != b_sz { + // Adjust a, b to equal length + if a_sz > b_sz { + let _ = self.a.advance_back_by(a_sz - b_sz); + } else { + let _ = self.b.advance_back_by(b_sz - a_sz); + } + } + } } /// Converts the arguments to iterators and zips them. @@ -94,6 +112,23 @@ where ZipImpl::nth(self, n) } + #[inline] + fn fold(self, init: T, f: F) -> T + where + F: FnMut(T, Self::Item) -> T, + { + ZipImpl::fold(self, init, f) + } + + #[inline] + fn try_fold(&mut self, init: T, f: F) -> R + where + F: FnMut(T, Self::Item) -> R, + R: Try, + { + ZipImpl::try_fold(self, init, f) + } + #[inline] unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> Self::Item where @@ -115,6 +150,23 @@ where fn next_back(&mut self) -> Option<(A::Item, B::Item)> { ZipImpl::next_back(self) } + + #[inline] + fn rfold(self, init: T, f: F) -> T + where + F: FnMut(T, Self::Item) -> T, + { + ZipImpl::rfold(self, init, f) + } + + #[inline] + fn try_rfold(&mut self, init: T, f: F) -> R + where + F: FnMut(T, Self::Item) -> R, + R: Try, + { + ZipImpl::try_rfold(self, init, f) + } } // Zip specialization trait @@ -129,12 +181,71 @@ trait ZipImpl { where A: DoubleEndedIterator + ExactSizeIterator, B: DoubleEndedIterator + ExactSizeIterator; + fn fold(self, init: T, f: F) -> T + where + F: FnMut(T, Self::Item) -> T; + fn try_fold(&mut self, init: T, f: F) -> R + where + F: FnMut(T, Self::Item) -> R, + R: Try; + fn rfold(self, init: T, f: F) -> T + where + A: DoubleEndedIterator + ExactSizeIterator, + B: DoubleEndedIterator + ExactSizeIterator, + F: FnMut(T, Self::Item) -> T; + fn try_rfold(&mut self, init: T, f: F) -> R + where + A: DoubleEndedIterator + ExactSizeIterator, + B: DoubleEndedIterator + ExactSizeIterator, + F: FnMut(T, Self::Item) -> R, + R: Try; + // This has the same safety requirements as `Iterator::__iterator_get_unchecked` unsafe fn get_unchecked(&mut self, idx: usize) -> ::Item where Self: Iterator + TrustedRandomAccessNoCoerce; } +#[inline] +fn check_rfold T>( + mut b: B, + mut f: F, +) -> impl FnMut(T, AItem) -> T { + move |acc, x| match b.next_back() { + Some(y) => f(acc, (x, y)), + None => panic!("Iterator expected Some(item), found None."), + } +} + +#[inline] +fn check_try_fold<'b, AItem, BItem, T, R: Try>( + b: &'b mut impl Iterator, + mut f: impl 'b + FnMut(T, (AItem, BItem)) -> R, +) -> impl 'b + FnMut(T, AItem) -> ControlFlow { + move |acc, x| match b.next() { + Some(y) => ControlFlow::from_try(f(acc, (x, y))), + None => ControlFlow::Break(R::from_output(acc)), + } +} + +#[inline] +fn check_try_rfold< + 'b, + AItem, + B: DoubleEndedIterator, + T, + R: Try, + F: 'b + FnMut(T, (AItem, B::Item)) -> R, +>( + b: &'b mut B, + mut f: F, +) -> impl 'b + FnMut(T, AItem) -> ControlFlow { + move |acc, x| match b.next_back() { + Some(y) => ControlFlow::from_try(f(acc, (x, y))), + None => ControlFlow::Break(R::from_output(acc)), + } +} + // Work around limitations of specialization, requiring `default` impls to be repeated // in intermediary impls. macro_rules! zip_impl_general_defaults { @@ -171,26 +282,53 @@ macro_rules! zip_impl_general_defaults { // and doesn’t call `next_back` too often, so this implementation is safe in // the `TrustedRandomAccessNoCoerce` specialization - let a_sz = self.a.len(); - let b_sz = self.b.len(); - if a_sz != b_sz { - // Adjust a, b to equal length - if a_sz > b_sz { - for _ in 0..a_sz - b_sz { - self.a.next_back(); - } - } else { - for _ in 0..b_sz - a_sz { - self.b.next_back(); - } - } - } + self.adjust_back(); match (self.a.next_back(), self.b.next_back()) { (Some(x), Some(y)) => Some((x, y)), (None, None) => None, _ => unreachable!(), } } + + #[inline] + default fn fold(mut self, init: T, f: F) -> T + where + F: FnMut(T, Self::Item) -> T, + { + ZipImpl::try_fold(&mut self, init, NeverShortCircuit::wrap_mut_2(f)).0 + } + + #[inline] + default fn try_fold(&mut self, init: T, f: F) -> R + where + F: FnMut(T, Self::Item) -> R, + R: Try, + { + self.a.try_fold(init, check_try_fold(&mut self.b, f)).into_try() + } + + #[inline] + default fn rfold(mut self, init: T, f: F) -> T + where + A: DoubleEndedIterator + ExactSizeIterator, + B: DoubleEndedIterator + ExactSizeIterator, + F: FnMut(T, Self::Item) -> T, + { + self.adjust_back(); + self.a.rfold(init, check_rfold(self.b, f)) + } + + #[inline] + default fn try_rfold(&mut self, init: T, f: F) -> R + where + A: DoubleEndedIterator + ExactSizeIterator, + B: DoubleEndedIterator + ExactSizeIterator, + F: FnMut(T, Self::Item) -> R, + R: Try, + { + self.adjust_back(); + self.a.try_rfold(init, check_try_rfold(&mut self.b, f)).into_try() + } }; } @@ -253,6 +391,40 @@ where } } +/// Adjusts a, b to equal length. Makes sure that only the first call +/// of `next_back` does this, otherwise we will break the restriction +/// on calls to `zipped.next_back()` after calling `get_unchecked()`. +#[inline] +fn adjust_back_trusted_random_access< + A: TrustedRandomAccess + DoubleEndedIterator + ExactSizeIterator, + B: TrustedRandomAccess + DoubleEndedIterator + ExactSizeIterator, +>( + zipped: &mut Zip, +) { + if A::MAY_HAVE_SIDE_EFFECT || B::MAY_HAVE_SIDE_EFFECT { + let sz_a = zipped.a.size(); + let sz_b = zipped.b.size(); + if sz_a != sz_b { + let sz_a = zipped.a.size(); + if A::MAY_HAVE_SIDE_EFFECT && sz_a > zipped.len { + for _ in 0..sz_a - zipped.len { + // since next_back() may panic we increment the counters beforehand + // to keep Zip's state in sync with the underlying iterator source + zipped.a_len -= 1; + zipped.a.next_back(); + } + debug_assert_eq!(zipped.a_len, zipped.len); + } + let sz_b = zipped.b.size(); + if B::MAY_HAVE_SIDE_EFFECT && sz_b > zipped.len { + for _ in 0..sz_b - zipped.len { + zipped.b.next_back(); + } + } + } + } +} + #[doc(hidden)] impl ZipImpl for Zip where @@ -332,31 +504,7 @@ where A: DoubleEndedIterator + ExactSizeIterator, B: DoubleEndedIterator + ExactSizeIterator, { - if A::MAY_HAVE_SIDE_EFFECT || B::MAY_HAVE_SIDE_EFFECT { - let sz_a = self.a.size(); - let sz_b = self.b.size(); - // Adjust a, b to equal length, make sure that only the first call - // of `next_back` does this, otherwise we will break the restriction - // on calls to `self.next_back()` after calling `get_unchecked()`. - if sz_a != sz_b { - let sz_a = self.a.size(); - if A::MAY_HAVE_SIDE_EFFECT && sz_a > self.len { - for _ in 0..sz_a - self.len { - // since next_back() may panic we increment the counters beforehand - // to keep Zip's state in sync with the underlying iterator source - self.a_len -= 1; - self.a.next_back(); - } - debug_assert_eq!(self.a_len, self.len); - } - let sz_b = self.b.size(); - if B::MAY_HAVE_SIDE_EFFECT && sz_b > self.len { - for _ in 0..sz_b - self.len { - self.b.next_back(); - } - } - } - } + adjust_back_trusted_random_access(self); if self.index < self.len { // since get_unchecked executes code which can panic we increment the counters beforehand // so that the same index won't be accessed twice, as required by TrustedRandomAccess @@ -372,6 +520,108 @@ where None } } + + #[inline] + fn fold(mut self, init: T, mut f: F) -> T + where + F: FnMut(T, Self::Item) -> T, + { + // we don't need to adjust `self.{index, len}` since we have ownership of the iterator + let mut accum = init; + let index = self.index; + let len = self.len; + for i in index..len { + // SAFETY: `i` is smaller than `self.len`, thus smaller than `self.a.len()` and `self.b.len()` + accum = unsafe { + f(accum, (self.a.__iterator_get_unchecked(i), self.b.__iterator_get_unchecked(i))) + }; + } + if A::MAY_HAVE_SIDE_EFFECT && len < self.a_len { + // SAFETY: `i` is smaller than `self.a_len`, which is equal to `self.a.len()` + unsafe { self.a.__iterator_get_unchecked(len) }; + } + accum + } + + #[inline] + fn try_fold(&mut self, init: T, mut f: F) -> R + where + F: FnMut(T, Self::Item) -> R, + R: Try, + { + let index = self.index; + let len = self.len; + + let mut accum = init; + + for i in index..len { + // adjust `self.index` beforehand in case once of the iterators panics. + self.index = i + 1; + // SAFETY: `i` is smaller than `self.len`, thus smaller than `self.a.len()` and `self.b.len()` + accum = unsafe { + f(accum, (self.a.__iterator_get_unchecked(i), self.b.__iterator_get_unchecked(i)))? + }; + } + + if A::MAY_HAVE_SIDE_EFFECT && len < self.a_len { + self.index = len + 1; + self.len = len + 1; + // SAFETY: `i` is smaller than `self.a_len`, which is equal to `self.a.len()` + unsafe { self.a.__iterator_get_unchecked(len) }; + } + + try { accum } + } + + #[inline] + fn rfold(mut self, init: T, mut f: F) -> T + where + A: DoubleEndedIterator + ExactSizeIterator, + B: DoubleEndedIterator + ExactSizeIterator, + F: FnMut(T, Self::Item) -> T, + { + // we don't need to adjust `self.{len, a_len}` since we have ownership of the iterator + adjust_back_trusted_random_access(&mut self); + + let mut accum = init; + let index = self.index; + let len = self.len; + for i in 0..len - index { + let i = len - i - 1; + // SAFETY: `i` is smaller than `self.len`, thus smaller than `self.a.len()` and `self.b.len()` + accum = unsafe { + f(accum, (self.a.__iterator_get_unchecked(i), self.b.__iterator_get_unchecked(i))) + }; + } + accum + } + + #[inline] + fn try_rfold(&mut self, init: T, mut f: F) -> R + where + A: DoubleEndedIterator + ExactSizeIterator, + B: DoubleEndedIterator + ExactSizeIterator, + F: FnMut(T, Self::Item) -> R, + R: Try, + { + adjust_back_trusted_random_access(self); + + let mut accum = init; + let index = self.index; + let len = self.len; + for i in 0..len - index { + // inner `i` goes backwards from `len` (exclusive) to `index` (inclusive) + let i = len - i - 1; + // adjust `self.{len, a_len}` beforehand in case once of the iterators panics. + self.len = i; + self.a_len -= 1; + // SAFETY: `i` is smaller than `self.len`, thus smaller than `self.a.len()` and `self.b.len()` + accum = unsafe { + f(accum, (self.a.__iterator_get_unchecked(i), self.b.__iterator_get_unchecked(i)))? + }; + } + try { accum } + } } #[stable(feature = "rust1", since = "1.0.0")] diff --git a/library/core/tests/iter/adapters/zip.rs b/library/core/tests/iter/adapters/zip.rs index 585cfbb90e40c..5d48d65e45178 100644 --- a/library/core/tests/iter/adapters/zip.rs +++ b/library/core/tests/iter/adapters/zip.rs @@ -232,6 +232,159 @@ fn test_zip_trusted_random_access_composition() { assert_eq!(z2.next().unwrap(), ((1, 1), 1)); } +#[test] +fn test_zip_trusted_random_access_fold_rfold() { + let a = [0, 1, 2, 3, 4]; + let b = [5, 6, 7, 8, 9, 10]; + + let sum = a.iter().copied().sum::() + b.iter().copied().take(a.len()).sum::(); + let zip = || a.iter().copied().zip(b.iter().copied()); + let fwd_sum = zip().fold(0, |a, (b, c)| a + b + c); + let bwd_sum = zip().rfold(0, |a, (b, c)| a + b + c); + + assert_eq!(fwd_sum, sum); + assert_eq!(bwd_sum, sum); +} + +#[test] +fn test_zip_trusted_random_access_try_fold_try_rfold() { + let a = [0, 1, 2, 3, 4]; + let b = [5, 6, 7, 8, 9, 10]; + + let sum = a.iter().copied().sum::() + b.iter().copied().take(a.len()).sum::(); + let zip = || a.iter().copied().zip(b.iter().copied()); + let mut zip_fwd = zip(); + let mut zip_bwd = zip(); + + let fwd_sum: Result = zip_fwd.try_fold(0, |a, (b, c)| Ok(a + b + c)); + let bwd_sum: Result = zip_bwd.try_rfold(0, |a, (b, c)| Ok(a + b + c)); + + assert_eq!(fwd_sum, Ok(sum)); + assert_eq!(bwd_sum, Ok(sum)); + assert_eq!(zip_fwd.next(), None); + assert_eq!(zip_fwd.next_back(), None); + assert_eq!(zip_bwd.next(), None); + assert_eq!(zip_bwd.next_back(), None); +} + +#[test] +fn test_zip_trusted_random_access_try_fold_try_rfold_resumable() { + let a = [0, 1, 2, 3, 4]; + let b = [5, 6, 7, 8, 9, 10]; + + fn sum_countdown(mut count: usize) -> impl FnMut(i32, (i32, i32)) -> Result { + move |a: i32, (b, c): (i32, i32)| { + if count == 0 { + Err(a) + } else { + count -= 1; + Ok(a + b + c) + } + } + } + + let zip = || a.iter().copied().zip(b.iter().copied()); + let mut zip_fwd = zip(); + let mut zip_bwd = zip(); + + let fwd_sum = zip_fwd.try_fold(0, sum_countdown(2)); + let bwd_sum = zip_bwd.try_rfold(0, sum_countdown(2)); + + assert_eq!(fwd_sum, Err(0 + 1 + 5 + 6)); + assert_eq!(bwd_sum, Err(4 + 3 + 9 + 8)); + { + let mut zip_fwd = zip_fwd.clone(); + let mut zip_bwd = zip_bwd.clone(); + assert_eq!(zip_fwd.next(), Some((3, 8))); + assert_eq!(zip_fwd.next(), Some((4, 9))); + assert_eq!(zip_fwd.next(), None); + + assert_eq!(zip_bwd.next(), Some((0, 5))); + assert_eq!(zip_bwd.next(), Some((1, 6))); + assert_eq!(zip_bwd.next(), None); + } + + assert_eq!(zip_fwd.next_back(), Some((4, 9))); + assert_eq!(zip_fwd.next_back(), Some((3, 8))); + assert_eq!(zip_fwd.next_back(), None); + + assert_eq!(zip_bwd.next_back(), Some((1, 6))); + assert_eq!(zip_bwd.next_back(), Some((0, 5))); + assert_eq!(zip_bwd.next_back(), None); +} + +#[test] +#[cfg(panic = "unwind")] +fn test_zip_trusted_random_access_try_fold_try_rfold_panic() { + use std::panic::catch_unwind; + use std::panic::AssertUnwindSafe; + + let a = [0, 1, 2]; + let b = [3, 4, 5, 6]; + + fn sum_countdown(mut count: usize) -> impl FnMut(i32, (i32, i32)) -> Result { + move |a: i32, (b, c): (i32, i32)| { + if count == 0 { + panic!("bomb") + } else { + count -= 1; + Ok(a + b + c) + } + } + } + + let zip = || a.iter().copied().zip(b.iter().copied()); + let mut zip_fwd = zip(); + let mut zip_bwd = zip(); + + let _ = catch_unwind(AssertUnwindSafe(|| { + let _ = zip_fwd.try_fold(0, sum_countdown(1)); + })); + let _ = catch_unwind(AssertUnwindSafe(|| { + let _ = zip_bwd.try_rfold(0, sum_countdown(1)); + })); + + { + let mut zip_fwd = zip_fwd.clone(); + let mut zip_bwd = zip_bwd.clone(); + match zip_fwd.next() { + Some((a, b)) => { + assert!(a > 1); + assert!(b > 4); + assert_eq!(b - a, 3); + } + None => (), + }; + + match zip_bwd.next_back() { + Some((a, b)) => { + assert!(a < 1); + assert!(b < 4); + assert_eq!(b - a, 3); + } + None => (), + }; + } + + match zip_fwd.next_back() { + Some((a, b)) => { + assert!(a > 1); + assert!(b > 4); + assert_eq!(b - a, 3); + } + None => (), + }; + + match zip_bwd.next() { + Some((a, b)) => { + assert!(a < 1); + assert!(b < 4); + assert_eq!(b - a, 3); + } + None => (), + }; +} + #[test] #[cfg(panic = "unwind")] fn test_zip_trusted_random_access_next_back_drop() {