diff --git a/src/types/list.rs b/src/types/list.rs index e5ac643ca5f..f7b4597c11e 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -483,17 +483,27 @@ impl<'py> PyListMethods<'py> for Bound<'py, PyList> { } } +// New types for type checking when using BoundListIterator associated methods, like +// BoundListIterator::next_unchecked. +struct Index(usize); +struct Length(usize); + /// Used by `PyList::iter()`. pub struct BoundListIterator<'py> { list: Bound<'py, PyList>, - inner: ListIterImpl, + index: Index, + length: Length, } -enum ListIterImpl { - ListIter { index: usize, length: usize }, -} +impl<'py> BoundListIterator<'py> { + fn new(list: Bound<'py, PyList>) -> Self { + Self { + index: Index(0), + length: Length(list.len()), + list, + } + } -impl ListIterImpl { #[inline] /// # Safety /// @@ -501,41 +511,38 @@ impl ListIterImpl { /// access to the list by holding a lock or by holding the innermost /// critical section on the list. #[cfg(all(not(Py_LIMITED_API), not(PyPy)))] - unsafe fn next_unchecked<'py>( - &mut self, + unsafe fn next_unchecked( + index: &mut Index, + length: &mut Length, list: &Bound<'py, PyList>, ) -> Option> { - match self { - Self::ListIter { index, length, .. } => { - let length = (*length).min(list.len()); - let my_index = *index; - - if *index < length { - let item = unsafe { list.get_item_unchecked(my_index) }; - *index += 1; - Some(item) - } else { - None - } - } + let length = length.0.min(list.len()); + let my_index = index.0; + + if index.0 < length { + let item = unsafe { list.get_item_unchecked(my_index) }; + index.0 += 1; + Some(item) + } else { + None } } #[cfg(any(Py_LIMITED_API, PyPy))] - fn next<'py>(&mut self, list: &Bound<'py, PyList>) -> Option> { - match self { - Self::ListIter { index, length, .. } => { - let length = (*length).min(list.len()); - let my_index = *index; - - if *index < length { - let item = list.get_item(my_index).expect("get-item failed"); - *index += 1; - Some(item) - } else { - None - } - } + fn next( + index: &mut Index, + length: &mut Length, + list: &Bound<'py, PyList>, + ) -> Option> { + let length = length.0.min(list.len()); + let my_index = index.0; + + if index.0 < length { + let item = list.get_item(my_index).expect("get-item failed"); + index.0 += 1; + Some(item) + } else { + None } } @@ -546,68 +553,37 @@ impl ListIterImpl { /// critical section on the list. #[inline] #[cfg(all(not(Py_LIMITED_API), not(PyPy)))] - unsafe fn next_back_unchecked<'py>( - &mut self, + unsafe fn next_back_unchecked( + index: &mut Index, + length: &mut Length, list: &Bound<'py, PyList>, ) -> Option> { - match self { - Self::ListIter { index, length, .. } => { - let current_length = (*length).min(list.len()); - - if *index < current_length { - let item = unsafe { list.get_item_unchecked(current_length - 1) }; - *length = current_length - 1; - Some(item) - } else { - None - } - } - } - } + let current_length = length.0.min(list.len()); - #[inline] - #[cfg(any(Py_LIMITED_API, PyPy))] - fn next_back<'py>(&mut self, list: &Bound<'py, PyList>) -> Option> { - match self { - Self::ListIter { index, length, .. } => { - let current_length = (*length).min(list.len()); - - if *index < current_length { - let item = list.get_item(current_length - 1).expect("get-item failed"); - *length = current_length - 1; - Some(item) - } else { - None - } - } - } - } - - #[inline] - fn len(&self) -> usize { - match self { - Self::ListIter { index, length, .. } => length.saturating_sub(*index), + if index.0 < current_length { + let item = unsafe { list.get_item_unchecked(current_length - 1) }; + length.0 = current_length - 1; + Some(item) + } else { + None } } - #[cfg(Py_GIL_DISABLED)] #[inline] - fn with_critical_section(&mut self, list: &Bound<'_, PyList>, f: F) -> R - where - F: FnOnce(&mut Self) -> R, - { - match self { - Self::ListIter { .. } => crate::sync::with_critical_section(list, || f(self)), - } - } -} + #[cfg(any(Py_LIMITED_API, PyPy))] + fn next_back( + index: &mut Index, + length: &mut Length, + list: &Bound<'py, PyList>, + ) -> Option> { + let current_length = (length.0).min(list.len()); -impl<'py> BoundListIterator<'py> { - fn new(list: Bound<'py, PyList>) -> Self { - let length: usize = list.len(); - BoundListIterator { - list, - inner: ListIterImpl::ListIter { index: 0, length }, + if index.0 < current_length { + let item = list.get_item(current_length - 1).expect("get-item failed"); + length.0 = current_length - 1; + Some(item) + } else { + None } } } @@ -617,20 +593,34 @@ impl<'py> Iterator for BoundListIterator<'py> { #[inline] fn next(&mut self) -> Option { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + #[cfg(Py_GIL_DISABLED)] { - self.inner - .with_critical_section(&self.list, |inner| unsafe { - inner.next_unchecked(&self.list) - }) + crate::sync::with_critical_section(list, || unsafe { + Self::next_unchecked(index, length, list) + }) } #[cfg(any(Py_LIMITED_API, PyPy))] { - self.inner.next(&self.list) + let length = length.0.min(list.len()); + let my_index = index.0; + + if index.0 < length { + let item = list.get_item(my_index).expect("get-item failed"); + index.0 += 1; + Some(item) + } else { + None + } } #[cfg(all(not(Py_GIL_DISABLED), not(Py_LIMITED_API), not(PyPy)))] { - unsafe { self.inner.next_unchecked(&self.list) } + unsafe { Self::next_unchecked(index, length, list) } } } @@ -647,9 +637,15 @@ impl<'py> Iterator for BoundListIterator<'py> { Self: Sized, F: FnMut(B, Self::Item) -> B, { - self.inner.with_critical_section(&self.list, |inner| { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + + crate::sync::with_critical_section(list, || { let mut accum = init; - while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { accum = f(accum, x); } accum @@ -664,9 +660,15 @@ impl<'py> Iterator for BoundListIterator<'py> { F: FnMut(B, Self::Item) -> R, R: std::ops::Try, { - self.inner.with_critical_section(&self.list, |inner| { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + + crate::sync::with_critical_section(list, || { let mut accum = init; - while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { accum = f(accum, x)? } R::from_output(accum) @@ -680,8 +682,14 @@ impl<'py> Iterator for BoundListIterator<'py> { Self: Sized, F: FnMut(Self::Item) -> bool, { - self.inner.with_critical_section(&self.list, |inner| { - while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + + crate::sync::with_critical_section(list, || { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { if !f(x) { return false; } @@ -697,8 +705,14 @@ impl<'py> Iterator for BoundListIterator<'py> { Self: Sized, F: FnMut(Self::Item) -> bool, { - self.inner.with_critical_section(&self.list, |inner| { - while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + + crate::sync::with_critical_section(list, || { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { if f(x) { return true; } @@ -714,8 +728,14 @@ impl<'py> Iterator for BoundListIterator<'py> { Self: Sized, P: FnMut(&Self::Item) -> bool, { - self.inner.with_critical_section(&self.list, |inner| { - while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + + crate::sync::with_critical_section(list, || { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { if predicate(&x) { return Some(x); } @@ -731,8 +751,14 @@ impl<'py> Iterator for BoundListIterator<'py> { Self: Sized, F: FnMut(Self::Item) -> Option, { - self.inner.with_critical_section(&self.list, |inner| { - while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + + crate::sync::with_critical_section(list, || { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { if let found @ Some(_) = f(x) { return found; } @@ -748,9 +774,15 @@ impl<'py> Iterator for BoundListIterator<'py> { Self: Sized, P: FnMut(Self::Item) -> bool, { - self.inner.with_critical_section(&self.list, |inner| { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + + crate::sync::with_critical_section(list, || { let mut acc = 0; - while let Some(x) = unsafe { inner.next_unchecked(&self.list) } { + while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } { if predicate(x) { return Some(acc); } @@ -764,20 +796,25 @@ impl<'py> Iterator for BoundListIterator<'py> { impl DoubleEndedIterator for BoundListIterator<'_> { #[inline] fn next_back(&mut self) -> Option { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + #[cfg(Py_GIL_DISABLED)] { - self.inner - .with_critical_section(&self.list, |inner| unsafe { - inner.next_back_unchecked(&self.list) - }) + crate::sync::with_critical_section(list, || unsafe { + Self::next_back_unchecked(index, length, list) + }) } #[cfg(any(Py_LIMITED_API, PyPy))] { - self.inner.next_back(&self.list) + Self::next_back(index, length, list) } #[cfg(all(not(Py_GIL_DISABLED), not(Py_LIMITED_API), not(PyPy)))] { - unsafe { self.inner.next_back_unchecked(&self.list) } + unsafe { Self::next_back_unchecked(index, length, list) } } } @@ -788,9 +825,15 @@ impl DoubleEndedIterator for BoundListIterator<'_> { Self: Sized, F: FnMut(B, Self::Item) -> B, { - self.inner.with_critical_section(&self.list, |inner| { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + + crate::sync::with_critical_section(list, || { let mut accum = init; - while let Some(x) = unsafe { inner.next_back_unchecked(&self.list) } { + while let Some(x) = unsafe { Self::next_back_unchecked(index, length, list) } { accum = f(accum, x); } accum @@ -805,9 +848,15 @@ impl DoubleEndedIterator for BoundListIterator<'_> { F: FnMut(B, Self::Item) -> R, R: std::ops::Try, { - self.inner.with_critical_section(&self.list, |inner| { + let Self { + ref mut index, + ref mut length, + ref list, + } = self; + + crate::sync::with_critical_section(list, || { let mut accum = init; - while let Some(x) = unsafe { inner.next_back_unchecked(&self.list) } { + while let Some(x) = unsafe { Self::next_back_unchecked(index, length, list) } { accum = f(accum, x)? } R::from_output(accum) @@ -817,7 +866,7 @@ impl DoubleEndedIterator for BoundListIterator<'_> { impl ExactSizeIterator for BoundListIterator<'_> { fn len(&self) -> usize { - self.inner.len() + self.length.0.saturating_sub(self.index.0) } }