Skip to content

Commit b7373aa

Browse files
committed
refactor advance_by and advance_back_by. Add back cfg for with_critical_section
1 parent 00e4802 commit b7373aa

File tree

1 file changed

+37
-11
lines changed

1 file changed

+37
-11
lines changed

src/types/list.rs

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -640,13 +640,14 @@ impl<'py> BoundListIterator<'py> {
640640
list.get_item(target_index).expect("get-item failed")
641641
}
642642
};
643-
*length = Length(target_index);
643+
length.0 = target_index;
644644
Some(item)
645645
} else {
646646
None
647647
}
648648
}
649649

650+
#[cfg(not(Py_LIMITED_API))]
650651
fn with_critical_section<R>(
651652
&mut self,
652653
f: impl FnOnce(&mut Index, &mut Length, &Bound<'py, PyList>) -> R,
@@ -818,12 +819,25 @@ impl<'py> Iterator for BoundListIterator<'py> {
818819
#[cfg(feature = "nightly")]
819820
fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
820821
self.with_critical_section(|index, length, list| {
821-
for i in 0..n {
822-
if unsafe { Self::next_unchecked(index, length, list).is_none() } {
823-
return Err(unsafe { NonZero::new_unchecked(n - i) });
822+
let max_len = length.0.min(list.len());
823+
let currently_at = index.0;
824+
if currently_at >= max_len {
825+
if n == 0 {
826+
return Ok(());
827+
} else {
828+
return Err(unsafe { NonZero::new_unchecked(n) });
824829
}
825830
}
826-
Ok(())
831+
832+
let items_left = max_len - currently_at;
833+
if n <= items_left {
834+
index.0 += n;
835+
Ok(())
836+
} else {
837+
index.0 = max_len;
838+
let remainder = n - items_left;
839+
Err(unsafe { NonZero::new_unchecked(remainder) })
840+
}
827841
})
828842
}
829843
}
@@ -891,12 +905,25 @@ impl DoubleEndedIterator for BoundListIterator<'_> {
891905
#[cfg(feature = "nightly")]
892906
fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
893907
self.with_critical_section(|index, length, list| {
894-
for i in 0..n {
895-
if unsafe { Self::next_back_unchecked(index, length, list).is_none() } {
896-
return Err(unsafe { NonZero::new_unchecked(n - i) });
908+
let max_len = length.0.min(list.len());
909+
let currently_at = index.0;
910+
if currently_at >= max_len {
911+
if n == 0 {
912+
return Ok(());
913+
} else {
914+
return Err(unsafe { NonZero::new_unchecked(n) });
897915
}
898916
}
899-
Ok(())
917+
918+
let items_left = max_len - currently_at;
919+
if n <= items_left {
920+
length.0 = max_len - n;
921+
Ok(())
922+
} else {
923+
length.0 = max_len;
924+
let remainder = n - items_left;
925+
Err(unsafe { NonZero::new_unchecked(remainder) })
926+
}
900927
})
901928
}
902929
}
@@ -1637,8 +1664,7 @@ mod tests {
16371664
assert_eq!(iter.next().unwrap().extract::<i32>().unwrap(), 10);
16381665

16391666
let mut iter = list.iter();
1640-
println!("iter.nth_back(1) = {:?}", iter.nth_back(1));
1641-
// assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 9);
1667+
assert_eq!(iter.nth_back(1).unwrap().extract::<i32>().unwrap(), 9);
16421668
assert_eq!(iter.nth(2).unwrap().extract::<i32>().unwrap(), 8);
16431669
assert!(iter.next().is_none());
16441670
});

0 commit comments

Comments
 (0)