From 602fbbab44b7de0e0fbfde95ba3fb1ebace3685c Mon Sep 17 00:00:00 2001 From: Bas Schoenmaeckers <7943856+bschoenmaeckers@users.noreply.github.com> Date: Thu, 24 Oct 2024 09:14:03 +0200 Subject: [PATCH] Make `PyDict` iterator compatible with free-threaded build (#4439) * Iterate over dict items in `DictIterator` * Use python instead of C API to get dict items * Use plain dict iterator when not on free-threaded & not a subclass * Copy dict on free-threaded builds to prevent concurrent modifications * Add test for dict subclass iters * Implement `PyDict::locked_for_each` * Lock `BoundDictIterator::next` on each iteration * Implement locked `fold` & `try_fold` * Implement `all`,`any`,`find`,`find_map`,`position` when not on nightly * Add changelog * Use critical section wrapper * Make `dict::locked_for_each` available on all builds * Remove item() iter * Add tests for `PyDict::iter()` reducers * Add more docs to locked_for_each * Move iter implementation into inner struct --- newsfragments/4439.changed.md | 3 + src/lib.rs | 5 +- src/types/dict.rs | 404 +++++++++++++++++++++++++++++----- 3 files changed, 355 insertions(+), 57 deletions(-) create mode 100644 newsfragments/4439.changed.md diff --git a/newsfragments/4439.changed.md b/newsfragments/4439.changed.md new file mode 100644 index 00000000000..9cb01a4d2b8 --- /dev/null +++ b/newsfragments/4439.changed.md @@ -0,0 +1,3 @@ +* Make `PyDict` iterator compatible with free-threaded build +* Added `PyDict::locked_for_each` method to iterate on free-threaded builds to prevent the dict being mutated during iteration +* Iterate over `dict.items()` when dict is subclassed from `PyDict` diff --git a/src/lib.rs b/src/lib.rs index 247b42ac372..73a22e94103 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,8 @@ #![warn(missing_docs)] -#![cfg_attr(feature = "nightly", feature(auto_traits, negative_impls))] +#![cfg_attr( + feature = "nightly", + feature(auto_traits, negative_impls, try_trait_v2) +)] #![cfg_attr(docsrs, 
feature(doc_cfg, doc_auto_cfg))] // Deny some lints in doctests. // Use `#[allow(...)]` locally to override. diff --git a/src/types/dict.rs b/src/types/dict.rs index 9b7d8697d20..6987e4525ec 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -1,11 +1,9 @@ -use super::PyMapping; use crate::err::{self, PyErr, PyResult}; use crate::ffi::Py_ssize_t; use crate::ffi_ptr_ext::FfiPtrExt; use crate::instance::{Borrowed, Bound}; use crate::py_result_ext::PyResultExt; -use crate::types::any::PyAnyMethods; -use crate::types::{PyAny, PyList}; +use crate::types::{PyAny, PyAnyMethods, PyList, PyMapping}; use crate::{ffi, BoundObject, IntoPyObject, Python}; /// Represents a Python `dict`. @@ -180,6 +178,19 @@ pub trait PyDictMethods<'py>: crate::sealed::Sealed { /// so long as the set of keys does not change. fn iter(&self) -> BoundDictIterator<'py>; + /// Iterates over the contents of this dictionary while holding a critical section on the dict. + /// This is useful when the GIL is disabled and the dictionary is shared between threads. + /// It is not guaranteed that the dictionary will not be modified during iteration when the + /// closure calls arbitrary Python code that releases the current critical section. + /// + /// This method is a small performance optimization over `.iter().try_for_each()` when the + /// nightly feature is not enabled because we cannot implement an optimised version of + /// `iter().try_fold()` on stable yet. If your iteration is infallible then this method has the + /// same performance as `.iter().for_each()`. + fn locked_for_each<F>(&self, closure: F) -> PyResult<()> + where + F: Fn(Bound<'py, PyAny>, Bound<'py, PyAny>) -> PyResult<()>; + /// Returns `self` cast as a `PyMapping`. 
fn as_mapping(&self) -> &Bound<'py, PyMapping>; @@ -357,6 +368,25 @@ impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> { BoundDictIterator::new(self.clone()) } + fn locked_for_each<F>(&self, f: F) -> PyResult<()> + where + F: Fn(Bound<'py, PyAny>, Bound<'py, PyAny>) -> PyResult<()>, + { + #[cfg(feature = "nightly")] + { + // We don't need a critical section when the nightly feature is enabled because + // try_for_each is locked by the implementation of try_fold. + self.iter().try_for_each(|(key, value)| f(key, value)) + } + + #[cfg(not(feature = "nightly"))] + { + crate::sync::with_critical_section(self, || { + self.iter().try_for_each(|(key, value)| f(key, value)) + }) + } + } + fn as_mapping(&self) -> &Bound<'py, PyMapping> { unsafe { self.downcast_unchecked() } } @@ -403,9 +433,86 @@ fn dict_len(dict: &Bound<'_, PyDict>) -> Py_ssize_t { /// PyO3 implementation of an iterator for a Python `dict` object. pub struct BoundDictIterator<'py> { dict: Bound<'py, PyDict>, - ppos: ffi::Py_ssize_t, - di_used: ffi::Py_ssize_t, - len: ffi::Py_ssize_t, + inner: DictIterImpl, +} + +enum DictIterImpl { + DictIter { + ppos: ffi::Py_ssize_t, + di_used: ffi::Py_ssize_t, + remaining: ffi::Py_ssize_t, + }, +} + +impl DictIterImpl { + #[inline] + fn next<'py>( + &mut self, + dict: &Bound<'py, PyDict>, + ) -> Option<(Bound<'py, PyAny>, Bound<'py, PyAny>)> { + match self { + Self::DictIter { + di_used, + remaining, + ppos, + .. + } => crate::sync::with_critical_section(dict, || { + let ma_used = dict_len(dict); + + // These checks are similar to what CPython does. + // + // If the dimension of the dict changes e.g. 
key-value pairs are removed + // or added during iteration, this will panic next time when `next` is called + if *di_used != ma_used { + *di_used = -1; + panic!("dictionary changed size during iteration"); + }; + + // If the dict is changed in such a way that the length remains constant + // then this will panic at the end of iteration - similar to this: + // + // d = {"a":1, "b":2, "c": 3} + // + // for k, v in d.items(): + // d[f"{k}_"] = 4 + // del d[k] + // print(k) + // + if *remaining == -1 { + *di_used = -1; + panic!("dictionary keys changed during iteration"); + }; + + let mut key: *mut ffi::PyObject = std::ptr::null_mut(); + let mut value: *mut ffi::PyObject = std::ptr::null_mut(); + + if unsafe { ffi::PyDict_Next(dict.as_ptr(), ppos, &mut key, &mut value) } != 0 { + *remaining -= 1; + let py = dict.py(); + // Safety: + // - PyDict_Next returns borrowed values + // - we have already checked that `PyDict_Next` succeeded, so we can assume these to be non-null + Some(( + unsafe { key.assume_borrowed_unchecked(py) }.to_owned(), + unsafe { value.assume_borrowed_unchecked(py) }.to_owned(), + )) + } else { + None + } + }), + } + } + + #[cfg(Py_GIL_DISABLED)] + #[inline] + fn with_critical_section<F, R>(&mut self, dict: &Bound<'_, PyDict>, f: F) -> R + where + F: FnOnce(&mut Self) -> R, + { + match self { + Self::DictIter { .. } => crate::sync::with_critical_section(dict, || f(self)), + } + } } impl<'py> Iterator for BoundDictIterator<'py> { @@ -413,50 +520,7 @@ impl<'py> Iterator for BoundDictIterator<'py> { #[inline] fn next(&mut self) -> Option<Self::Item> { - let ma_used = dict_len(&self.dict); - - // These checks are similar to what CPython does. - // - // If the dimension of the dict changes e.g. 
key-value pairs are removed - // or added during iteration, this will panic next time when `next` is called - if self.di_used != ma_used { - self.di_used = -1; - panic!("dictionary changed size during iteration"); - }; - - // If the dict is changed in such a way that the length remains constant - // then this will panic at the end of iteration - similar to this: - // - // d = {"a":1, "b":2, "c": 3} - // - // for k, v in d.items(): - // d[f"{k}_"] = 4 - // del d[k] - // print(k) - // - if self.len == -1 { - self.di_used = -1; - panic!("dictionary keys changed during iteration"); - }; - - let mut key: *mut ffi::PyObject = std::ptr::null_mut(); - let mut value: *mut ffi::PyObject = std::ptr::null_mut(); - - if unsafe { ffi::PyDict_Next(self.dict.as_ptr(), &mut self.ppos, &mut key, &mut value) } - != 0 - { - self.len -= 1; - let py = self.dict.py(); - // Safety: - // - PyDict_Next returns borrowed values - // - we have already checked that `PyDict_Next` succeeded, so we can assume these to be non-null - Some(( - unsafe { key.assume_borrowed_unchecked(py) }.to_owned(), - unsafe { value.assume_borrowed_unchecked(py) }.to_owned(), - )) - } else { - None - } + self.inner.next(&self.dict) } #[inline] @@ -464,22 +528,147 @@ impl<'py> Iterator for BoundDictIterator<'py> { let len = self.len(); (len, Some(len)) } + + #[inline] + #[cfg(Py_GIL_DISABLED)] + fn fold<B, F>(mut self, init: B, mut f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + self.inner.with_critical_section(&self.dict, |inner| { + let mut accum = init; + while let Some(x) = inner.next(&self.dict) { + accum = f(accum, x); + } + accum + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, feature = "nightly"))] + fn try_fold<B, F, R>(&mut self, init: B, mut f: F) -> R + where + Self: Sized, + F: FnMut(B, Self::Item) -> R, + R: std::ops::Try<Output = B>, + { + self.inner.with_critical_section(&self.dict, |inner| { + let mut accum = init; + while let Some(x) = inner.next(&self.dict) { + accum = f(accum, x)? 
+ } + R::from_output(accum) + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn all<F>(&mut self, mut f: F) -> bool + where + Self: Sized, + F: FnMut(Self::Item) -> bool, + { + self.inner.with_critical_section(&self.dict, |inner| { + while let Some(x) = inner.next(&self.dict) { + if !f(x) { + return false; + } + } + true + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn any<F>(&mut self, mut f: F) -> bool + where + Self: Sized, + F: FnMut(Self::Item) -> bool, + { + self.inner.with_critical_section(&self.dict, |inner| { + while let Some(x) = inner.next(&self.dict) { + if f(x) { + return true; + } + } + false + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn find<P>(&mut self, mut predicate: P) -> Option<Self::Item> + where + Self: Sized, + P: FnMut(&Self::Item) -> bool, + { + self.inner.with_critical_section(&self.dict, |inner| { + while let Some(x) = inner.next(&self.dict) { + if predicate(&x) { + return Some(x); + } + } + None + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn find_map<B, F>(&mut self, mut f: F) -> Option<B> + where + Self: Sized, + F: FnMut(Self::Item) -> Option<B>, + { + self.inner.with_critical_section(&self.dict, |inner| { + while let Some(x) = inner.next(&self.dict) { + if let found @ Some(_) = f(x) { + return found; + } + } + None + }) + } + + #[inline] + #[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))] + fn position<P>(&mut self, mut predicate: P) -> Option<usize> + where + Self: Sized, + P: FnMut(Self::Item) -> bool, + { + self.inner.with_critical_section(&self.dict, |inner| { + let mut acc = 0; + while let Some(x) = inner.next(&self.dict) { + if predicate(x) { + return Some(acc); + } + acc += 1; + } + None + }) + } } impl ExactSizeIterator for BoundDictIterator<'_> { fn len(&self) -> usize { - self.len as usize + match self.inner { + DictIterImpl::DictIter { remaining, .. } => remaining as usize, + } } } impl<'py> BoundDictIterator<'py> { fn new(dict: Bound<'py, PyDict>) -> Self { - let len = dict_len(&dict); - BoundDictIterator { + let remaining = dict_len(&dict); + + Self { dict, - ppos: 0, - di_used: len, - len, + inner: DictIterImpl::DictIter { + ppos: 0, + di_used: remaining, + remaining, + }, } } } @@ -1360,4 +1549,107 @@ mod tests { ); }) } + + #[test] + fn test_iter_all() { + Python::with_gil(|py| { + let dict = [(1, true), (2, true), (3, true)].into_py_dict(py).unwrap(); + assert!(dict.iter().all(|(_, v)| v.extract::<bool>().unwrap())); + + let dict = [(1, true), (2, false), (3, true)].into_py_dict(py).unwrap(); + assert!(!dict.iter().all(|(_, v)| v.extract::<bool>().unwrap())); + }); + } + + #[test] + fn test_iter_any() { + Python::with_gil(|py| { + let dict = [(1, true), (2, false), (3, false)] + .into_py_dict(py) + .unwrap(); + assert!(dict.iter().any(|(_, v)| v.extract::<bool>().unwrap())); + + let dict = [(1, false), (2, false), (3, false)] + .into_py_dict(py) + .unwrap(); + assert!(!dict.iter().any(|(_, v)| v.extract::<bool>().unwrap())); + }); + } + + #[test] + #[allow(clippy::search_is_some)] + fn test_iter_find() { + Python::with_gil(|py| { + let dict = [(1, false), (2, true), (3, false)] + .into_py_dict(py) + .unwrap(); + + assert_eq!( + Some((2, true)), + dict.iter() + .find(|(_, v)| v.extract::<bool>().unwrap()) + .map(|(k, v)| (k.extract().unwrap(), v.extract().unwrap())) + ); + + let dict = [(1, false), (2, false), (3, false)] + .into_py_dict(py) + .unwrap(); + + assert!(dict + .iter() + 
.find(|(_, v)| v.extract::<bool>().unwrap()) + .is_none()); + }); + } + + #[test] + #[allow(clippy::search_is_some)] + fn test_iter_position() { + Python::with_gil(|py| { + let dict = [(1, false), (2, false), (3, true)] + .into_py_dict(py) + .unwrap(); + assert_eq!( + Some(2), + dict.iter().position(|(_, v)| v.extract::<bool>().unwrap()) + ); + + let dict = [(1, false), (2, false), (3, false)] + .into_py_dict(py) + .unwrap(); + assert!(dict + .iter() + .position(|(_, v)| v.extract::<bool>().unwrap()) + .is_none()); + }); + } + + #[test] + fn test_iter_fold() { + Python::with_gil(|py| { + let dict = [(1, 1), (2, 2), (3, 3)].into_py_dict(py).unwrap(); + let sum = dict + .iter() + .fold(0, |acc, (_, v)| acc + v.extract::<i32>().unwrap()); + assert_eq!(sum, 6); + }); + } + + #[test] + fn test_iter_try_fold() { + Python::with_gil(|py| { + let dict = [(1, 1), (2, 2), (3, 3)].into_py_dict(py).unwrap(); + let sum = dict + .iter() + .try_fold(0, |acc, (_, v)| PyResult::Ok(acc + v.extract::<i32>()?)) + .unwrap(); + assert_eq!(sum, 6); + + let dict = [(1, "foo"), (2, "bar")].into_py_dict(py).unwrap(); + assert!(dict + .iter() + .try_fold(0, |acc, (_, v)| PyResult::Ok(acc + v.extract::<i32>()?)) + .is_err()); + }); + } }