Skip to content

Commit af95302

Browse files
committed
Rollup merge of #47552 - oberien:stepby-nth, r=dtolnay
Specialize StepBy::nth This allows optimizations of implementations of the inner iterator's `.nth` method.
2 parents 86eb725 + 4a0da4c commit af95302

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

src/libcore/iter/mod.rs

+44
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ use fmt;
307307
use iter_private::TrustedRandomAccess;
308308
use ops::Try;
309309
use usize;
310+
use intrinsics;
310311

311312
#[stable(feature = "rust1", since = "1.0.0")]
312313
pub use self::iterator::Iterator;
@@ -694,6 +695,49 @@ impl<I> Iterator for StepBy<I> where I: Iterator {
694695
(f(inner_hint.0), inner_hint.1.map(f))
695696
}
696697
}
698+
699+
#[inline]
700+
fn nth(&mut self, mut n: usize) -> Option<Self::Item> {
701+
if self.first_take {
702+
self.first_take = false;
703+
let first = self.iter.next();
704+
if n == 0 {
705+
return first;
706+
}
707+
n -= 1;
708+
}
709+
// n and self.step are indices, we need to add 1 to get the amount of elements
710+
// When calling `.nth`, we need to subtract 1 again to convert back to an index
711+
// step + 1 can't overflow because `.step_by` sets `self.step` to `step - 1`
712+
let mut step = self.step + 1;
713+
// n + 1 could overflow
714+
// thus, if n is usize::MAX, instead of adding one, we call .nth(step)
715+
if n == usize::MAX {
716+
self.iter.nth(step - 1);
717+
} else {
718+
n += 1;
719+
}
720+
721+
// overflow handling
722+
loop {
723+
let mul = n.checked_mul(step);
724+
if unsafe { intrinsics::likely(mul.is_some()) } {
725+
return self.iter.nth(mul.unwrap() - 1);
726+
}
727+
let div_n = usize::MAX / n;
728+
let div_step = usize::MAX / step;
729+
let nth_n = div_n * n;
730+
let nth_step = div_step * step;
731+
let nth = if nth_n > nth_step {
732+
step -= div_n;
733+
nth_n
734+
} else {
735+
n -= div_step;
736+
nth_step
737+
};
738+
self.iter.nth(nth - 1);
739+
}
740+
}
697741
}
698742

699743
// StepBy can only make the iterator shorter, so the len will still fit.

src/libcore/tests/iter.rs

+62
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,68 @@ fn test_iterator_step_by() {
161161
assert_eq!(it.next(), None);
162162
}
163163

164+
#[test]
165+
fn test_iterator_step_by_nth() {
166+
let mut it = (0..16).step_by(5);
167+
assert_eq!(it.nth(0), Some(0));
168+
assert_eq!(it.nth(0), Some(5));
169+
assert_eq!(it.nth(0), Some(10));
170+
assert_eq!(it.nth(0), Some(15));
171+
assert_eq!(it.nth(0), None);
172+
173+
let it = (0..18).step_by(5);
174+
assert_eq!(it.clone().nth(0), Some(0));
175+
assert_eq!(it.clone().nth(1), Some(5));
176+
assert_eq!(it.clone().nth(2), Some(10));
177+
assert_eq!(it.clone().nth(3), Some(15));
178+
assert_eq!(it.clone().nth(4), None);
179+
assert_eq!(it.clone().nth(42), None);
180+
}
181+
182+
#[test]
183+
fn test_iterator_step_by_nth_overflow() {
184+
#[cfg(target_pointer_width = "8")]
185+
type Bigger = u16;
186+
#[cfg(target_pointer_width = "16")]
187+
type Bigger = u32;
188+
#[cfg(target_pointer_width = "32")]
189+
type Bigger = u64;
190+
#[cfg(target_pointer_width = "64")]
191+
type Bigger = u128;
192+
193+
#[derive(Clone)]
194+
struct Test(Bigger);
195+
impl<'a> Iterator for &'a mut Test {
196+
type Item = i32;
197+
fn next(&mut self) -> Option<Self::Item> { Some(21) }
198+
fn nth(&mut self, n: usize) -> Option<Self::Item> {
199+
self.0 += n as Bigger + 1;
200+
Some(42)
201+
}
202+
}
203+
204+
let mut it = Test(0);
205+
let root = usize::MAX >> (::std::mem::size_of::<usize>() * 8 / 2);
206+
let n = root + 20;
207+
(&mut it).step_by(n).nth(n);
208+
assert_eq!(it.0, n as Bigger * n as Bigger);
209+
210+
// large step
211+
let mut it = Test(0);
212+
(&mut it).step_by(usize::MAX).nth(5);
213+
assert_eq!(it.0, (usize::MAX as Bigger) * 5);
214+
215+
// n + 1 overflows
216+
let mut it = Test(0);
217+
(&mut it).step_by(2).nth(usize::MAX);
218+
assert_eq!(it.0, (usize::MAX as Bigger) * 2);
219+
220+
// n + 1 overflows
221+
let mut it = Test(0);
222+
(&mut it).step_by(1).nth(usize::MAX);
223+
assert_eq!(it.0, (usize::MAX as Bigger) * 1);
224+
}
225+
164226
#[test]
165227
#[should_panic]
166228
fn test_iterator_step_by_zero() {

0 commit comments

Comments
 (0)