Skip to content

Commit 2c6e671

Browse files
committed
implement advance_(back_)_by on more iterators
1 parent 6dc08b9 commit 2c6e671

File tree

15 files changed

+376
-3
lines changed

15 files changed

+376
-3
lines changed

library/alloc/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
// that the feature-gate isn't enabled. Ideally, it wouldn't check for the feature gate for docs
112112
// from other crates, but since this can only appear for lang items, it doesn't seem worth fixing.
113113
#![feature(intra_doc_pointers)]
114+
#![feature(iter_advance_by)]
114115
#![feature(iter_zip)]
115116
#![feature(lang_items)]
116117
#![feature(layout_for_ptr)]

library/alloc/src/vec/into_iter.rs

+45
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,28 @@ impl<T, A: Allocator> Iterator for IntoIter<T, A> {
161161
(exact, Some(exact))
162162
}
163163

164+
#[inline]
165+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
166+
let step_size = self.len().min(n);
167+
if mem::size_of::<T>() == 0 {
168+
// SAFETY: due to unchecked casts of unsigned amounts to signed offsets the wraparound
169+
// effectively results in unsigned pointers representing positions 0..usize::MAX,
170+
// which is valid for ZSTs.
171+
self.ptr = unsafe { arith_offset(self.ptr as *const i8, step_size as isize) as *mut T }
172+
} else {
173+
let to_drop = ptr::slice_from_raw_parts_mut(self.ptr as *mut T, step_size);
174+
// SAFETY: the min() above ensures that step_size is in bounds
175+
unsafe {
176+
self.ptr = self.ptr.add(step_size);
177+
ptr::drop_in_place(to_drop);
178+
}
179+
}
180+
if step_size < n {
181+
return Err(step_size);
182+
}
183+
Ok(())
184+
}
185+
164186
#[inline]
165187
fn count(self) -> usize {
166188
self.len()
@@ -203,6 +225,29 @@ impl<T, A: Allocator> DoubleEndedIterator for IntoIter<T, A> {
203225
Some(unsafe { ptr::read(self.end) })
204226
}
205227
}
228+
229+
#[inline]
230+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
231+
let step_size = self.len().min(n);
232+
if mem::size_of::<T>() == 0 {
233+
// SAFETY: same as for advance_by()
234+
self.end = unsafe {
235+
arith_offset(self.end as *const i8, step_size.wrapping_neg() as isize) as *mut T
236+
}
237+
} else {
238+
// SAFETY: same as for advance_by()
239+
self.end = unsafe { self.end.offset(step_size.wrapping_neg() as isize) };
240+
let to_drop = ptr::slice_from_raw_parts_mut(self.end as *mut T, step_size);
241+
// SAFETY: same as for advance_by()
242+
unsafe {
243+
ptr::drop_in_place(to_drop);
244+
}
245+
}
246+
if step_size < n {
247+
return Err(step_size);
248+
}
249+
Ok(())
250+
}
206251
}
207252

208253
#[stable(feature = "rust1", since = "1.0.0")]

library/alloc/tests/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#![feature(binary_heap_retain)]
1919
#![feature(binary_heap_as_slice)]
2020
#![feature(inplace_iteration)]
21+
#![feature(iter_advance_by)]
2122
#![feature(slice_group_by)]
2223
#![feature(slice_partition_dedup)]
2324
#![feature(vec_spare_capacity)]

library/alloc/tests/vec.rs

+18
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,24 @@ fn test_into_iter_leak() {
970970
assert_eq!(unsafe { DROPS }, 3);
971971
}
972972

973+
#[test]
974+
fn test_into_iter_advance_by() {
975+
let mut i = vec![1, 2, 3, 4, 5].into_iter();
976+
i.advance_by(0).unwrap();
977+
i.advance_back_by(0).unwrap();
978+
assert_eq!(i.as_slice(), [1, 2, 3, 4, 5]);
979+
980+
i.advance_by(1).unwrap();
981+
i.advance_back_by(1).unwrap();
982+
assert_eq!(i.as_slice(), [2, 3, 4]);
983+
984+
assert_eq!(i.advance_back_by(usize::MAX), Err(3));
985+
986+
assert_eq!(i.advance_by(usize::MAX), Err(0));
987+
988+
assert_eq!(i.len(), 0);
989+
}
990+
973991
#[test]
974992
fn test_from_iter_specialization() {
975993
let src: Vec<usize> = vec![0usize; 1];

library/core/src/iter/adapters/copied.rs

+10
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ where
7676
self.it.count()
7777
}
7878

79+
#[inline]
80+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
81+
self.it.advance_by(n)
82+
}
83+
7984
#[doc(hidden)]
8085
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> T
8186
where
@@ -112,6 +117,11 @@ where
112117
{
113118
self.it.rfold(init, copy_fold(f))
114119
}
120+
121+
#[inline]
122+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
123+
self.it.advance_back_by(n)
124+
}
115125
}
116126

117127
#[stable(feature = "iter_copied", since = "1.36.0")]

library/core/src/iter/adapters/cycle.rs

+21
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,27 @@ where
7979
}
8080
}
8181

82+
#[inline]
83+
#[rustc_inherit_overflow_checks]
84+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
85+
let mut rem = n;
86+
match self.iter.advance_by(rem) {
87+
ret @ Ok(_) => return ret,
88+
Err(advanced) => rem -= advanced,
89+
}
90+
91+
while rem > 0 {
92+
self.iter = self.orig.clone();
93+
match self.iter.advance_by(rem) {
94+
ret @ Ok(_) => return ret,
95+
Err(0) => return Err(n - rem),
96+
Err(advanced) => rem -= advanced,
97+
}
98+
}
99+
100+
Ok(())
101+
}
102+
82103
// No `fold` override, because `fold` doesn't make much sense for `Cycle`,
83104
// and we can't do anything better than the default.
84105
}

library/core/src/iter/adapters/enumerate.rs

+22
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,21 @@ where
112112
self.iter.fold(init, enumerate(self.count, fold))
113113
}
114114

115+
#[inline]
116+
#[rustc_inherit_overflow_checks]
117+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
118+
match self.iter.advance_by(n) {
119+
ret @ Ok(_) => {
120+
self.count += n;
121+
ret
122+
}
123+
ret @ Err(advanced) => {
124+
self.count += advanced;
125+
ret
126+
}
127+
}
128+
}
129+
115130
#[rustc_inherit_overflow_checks]
116131
#[doc(hidden)]
117132
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> <Self as Iterator>::Item
@@ -191,6 +206,13 @@ where
191206
let count = self.count + self.iter.len();
192207
self.iter.rfold(init, enumerate(count, fold))
193208
}
209+
210+
#[inline]
211+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
212+
// we do not need to update the count since that only tallies the number of items
213+
// consumed from the front. consuming items from the back can never reduce that.
214+
self.iter.advance_back_by(n)
215+
}
194216
}
195217

196218
#[stable(feature = "rust1", since = "1.0.0")]

library/core/src/iter/adapters/flatten.rs

+69
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,40 @@ where
391391

392392
init
393393
}
394+
395+
#[inline]
396+
#[rustc_inherit_overflow_checks]
397+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
398+
let mut rem = n;
399+
loop {
400+
if let Some(ref mut front) = self.frontiter {
401+
match front.advance_by(rem) {
402+
ret @ Ok(_) => return ret,
403+
Err(advanced) => rem -= advanced,
404+
}
405+
}
406+
self.frontiter = match self.iter.next() {
407+
Some(iterable) => Some(iterable.into_iter()),
408+
_ => break,
409+
}
410+
}
411+
412+
self.frontiter = None;
413+
414+
if let Some(ref mut back) = self.backiter {
415+
if let Err(advanced) = back.advance_by(rem) {
416+
rem -= advanced
417+
}
418+
}
419+
420+
if rem > 0 {
421+
return Err(n - rem);
422+
}
423+
424+
self.backiter = None;
425+
426+
Ok(())
427+
}
394428
}
395429

396430
impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>
@@ -486,6 +520,41 @@ where
486520

487521
init
488522
}
523+
524+
#[inline]
525+
#[rustc_inherit_overflow_checks]
526+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
527+
let mut rem = n;
528+
loop {
529+
if let Some(ref mut back) = self.backiter {
530+
match back.advance_back_by(rem) {
531+
ret @ Ok(_) => return ret,
532+
Err(advanced) => rem -= advanced,
533+
}
534+
}
535+
match self.iter.next_back() {
536+
Some(iterable) => self.backiter = Some(iterable.into_iter()),
537+
_ => break,
538+
}
539+
}
540+
541+
self.backiter = None;
542+
543+
if let Some(ref mut front) = self.frontiter {
544+
match front.advance_back_by(rem) {
545+
ret @ Ok(_) => return ret,
546+
Err(advanced) => rem -= advanced,
547+
}
548+
}
549+
550+
if rem > 0 {
551+
return Err(n - rem);
552+
}
553+
554+
self.frontiter = None;
555+
556+
Ok(())
557+
}
489558
}
490559

491560
trait ConstSizeIntoIterator: IntoIterator {

library/core/src/iter/adapters/skip.rs

+21
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ where
114114
}
115115
self.iter.fold(init, fold)
116116
}
117+
118+
#[inline]
119+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
120+
if self.n >= n {
121+
self.n -= n;
122+
return Ok(());
123+
}
124+
let rem = n - self.n;
125+
self.n = 0;
126+
self.iter.advance_by(rem)
127+
}
117128
}
118129

119130
#[stable(feature = "rust1", since = "1.0.0")]
@@ -174,6 +185,16 @@ where
174185

175186
self.try_rfold(init, ok(fold)).unwrap()
176187
}
188+
189+
#[inline]
190+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
191+
let min = crate::cmp::min(self.len(), n);
192+
return match self.iter.advance_back_by(min) {
193+
ret @ Ok(_) if n <= min => ret,
194+
Ok(_) => Err(min),
195+
_ => panic!("ExactSizeIterator contract violation"),
196+
};
197+
}
177198
}
178199

179200
#[stable(feature = "fused", since = "1.26.0")]

library/core/src/iter/adapters/take.rs

+34
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,22 @@ where
111111

112112
self.try_fold(init, ok(fold)).unwrap()
113113
}
114+
115+
#[inline]
116+
#[rustc_inherit_overflow_checks]
117+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
118+
let min = crate::cmp::min(self.n, n);
119+
return match self.iter.advance_by(min) {
120+
Ok(_) => {
121+
self.n -= min;
122+
if min < n { Err(min) } else { Ok(()) }
123+
}
124+
ret @ Err(advanced) => {
125+
self.n -= advanced;
126+
ret
127+
}
128+
};
129+
}
114130
}
115131

116132
#[unstable(issue = "none", feature = "inplace_iteration")]
@@ -197,6 +213,24 @@ where
197213
}
198214
}
199215
}
216+
217+
#[inline]
218+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
219+
let inner_len = self.iter.len();
220+
let len = self.n;
221+
let remainder = len.saturating_sub(n);
222+
let to_advance = inner_len - remainder;
223+
match self.iter.advance_back_by(to_advance) {
224+
Ok(_) => {
225+
self.n = remainder;
226+
if n > len {
227+
return Err(len);
228+
}
229+
return Ok(());
230+
}
231+
_ => panic!("ExactSizeIterator contract violation"),
232+
}
233+
}
200234
}
201235

202236
#[stable(feature = "rust1", since = "1.0.0")]

0 commit comments

Comments
 (0)