diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index e03199dba..010e4dd45 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -19,22 +19,59 @@ use alloc::vec::Vec; use std::iter::FromIterator; use std::marker::PhantomData; use std::ptr; +use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; -use crate::Ix1; +use crate::imp_prelude::*; -use super::{ArrayBase, ArrayView, ArrayViewMut, Axis, Data, NdProducer, RemoveAxis}; -use super::{Dimension, Ix, Ixs}; +use super::NdProducer; pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut}; pub use self::into_iter::IntoIter; pub use self::lanes::{Lanes, LanesMut}; pub use self::windows::Windows; -use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; +use crate::dimension; + +/// No traversal optmizations that would change element order or axis dimensions are permitted. +/// +/// This option is suitable for example for the indexed iterator. +pub(crate) enum NoOptimization {} + +/// Preserve element iteration order, but modify dimensions if profitable; for example we can +/// change from shape [10, 1] to [1, 10], because that axis has len == 1, without consequence here. +/// +/// This option is suitable for example for the default .iter() iterator. +pub(crate) enum PreserveOrder {} + +/// Allow use of arbitrary element iteration order +/// +/// This option is suitable for example for an arbitrary order iterator. +pub(crate) enum ArbitraryOrder {} + +pub(crate) trait OrderOption +{ + const ALLOW_REMOVE_REDUNDANT_AXES: bool = false; + const ALLOW_ARBITRARY_ORDER: bool = false; +} + +impl OrderOption for NoOptimization {} + +impl OrderOption for PreserveOrder +{ + const ALLOW_REMOVE_REDUNDANT_AXES: bool = true; +} + +impl OrderOption for ArbitraryOrder +{ + const ALLOW_REMOVE_REDUNDANT_AXES: bool = true; + const ALLOW_ARBITRARY_ORDER: bool = true; +} /// Base for iterators over all axes. /// /// Iterator element type is `*mut A`. +/// +/// `F` is for layout/iteration order flags #[derive(Debug)] pub(crate) struct Baseiter { @@ -50,13 +87,46 @@ impl Baseiter /// to be correct to avoid performing an unsafe pointer offset while /// iterating. #[inline] - pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter + pub unsafe fn new(ptr: *mut A, dim: D, strides: D) -> Baseiter { + Self::new_with_order::(ptr, dim, strides) + } +} + +impl Baseiter +{ + /// Creating a Baseiter is unsafe because shape and stride parameters need + /// to be correct to avoid performing an unsafe pointer offset while + /// iterating. + #[inline] + pub unsafe fn new_with_order(mut ptr: *mut A, mut dim: D, mut strides: D) -> Baseiter + { + debug_assert_eq!(dim.ndim(), strides.ndim()); + if Flags::ALLOW_ARBITRARY_ORDER { + // iterate in memory order; merge axes if possible + // make all axes positive and put the pointer back to the first element in memory + let offset = dimension::offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides); + ptr = ptr.sub(offset); + for i in 0..strides.ndim() { + let s = strides.get_stride(Axis(i)); + if s < 0 { + strides.set_stride(Axis(i), -s); + } + } + dimension::sort_axes_to_standard(&mut dim, &mut strides); + } + + if Flags::ALLOW_REMOVE_REDUNDANT_AXES { + // preserve element order but shift dimensions + dimension::merge_axes_from_the_back(&mut dim, &mut strides); + dimension::squeeze(&mut dim, &mut strides); + } + Baseiter { ptr, - index: len.first_index(), - dim: len, - strides: stride, + index: dim.first_index(), + dim, + strides, } } } @@ -1585,3 +1655,152 @@ where debug_assert_eq!(size, result.len()); result } + +#[cfg(test)] +#[cfg(feature = "std")] +mod tests +{ + use super::Baseiter; + use super::{ArbitraryOrder, NoOptimization, PreserveOrder}; + use crate::prelude::*; + use itertools::assert_equal; + use itertools::Itertools; + + // 3-d axis swaps + fn swaps() -> impl Iterator> + { + vec![ + vec![], + vec![(0, 1)], + vec![(0, 2)], + vec![(1, 2)], + vec![(0, 1), (1, 2)], + vec![(0, 1), (0, 2)], + ] + .into_iter() + } + + // 3-d axis inverts + fn inverts() -> impl Iterator> + { + vec![ + vec![], + vec![Axis(0)], + vec![Axis(1)], + vec![Axis(2)], + vec![Axis(0), Axis(1)], + vec![Axis(0), Axis(2)], + vec![Axis(1), Axis(2)], + vec![Axis(0), Axis(1), Axis(2)], + ] + .into_iter() + } + + #[test] + fn test_arbitrary_order() + { + for swap in swaps() { + for invert in inverts() { + for &slice in &[false, true] { + // pattern is 0, 1; 4, 5; 8, 9; etc.. + let mut a = Array::from_iter(0..24).into_shape((3, 4, 2)).unwrap(); + if slice { + a.slice_collapse(s![.., ..;2, ..]); + } + for &(i, j) in &swap { + a.swap_axes(i, j); + } + for &i in &invert { + a.invert_axis(i); + } + unsafe { + // Should have in-memory order for arbitrary order + let iter = Baseiter::new_with_order::(a.as_mut_ptr(), a.dim, a.strides); + if !slice { + assert_equal(iter.map(|ptr| *ptr), 0..a.len()); + } else { + assert_eq!(iter.map(|ptr| *ptr).collect_vec(), + (0..a.len() * 2).filter(|&x| (x / 2) % 2 == 0).collect_vec()); + } + } + } + } + } + } + + #[test] + fn test_logical_order() + { + for swap in swaps() { + for invert in inverts() { + for &slice in &[false, true] { + let mut a = Array::from_iter(0..24).into_shape((3, 4, 2)).unwrap(); + for &(i, j) in &swap { + a.swap_axes(i, j); + } + for &i in &invert { + a.invert_axis(i); + } + if slice { + a.slice_collapse(s![.., ..;2, ..]); + } + + unsafe { + let mut iter = Baseiter::new_with_order::(a.as_mut_ptr(), a.dim, a.strides); + let mut index = Dim([0, 0, 0]); + let mut elts = 0; + while let Some(elt) = iter.next() { + assert_eq!(*elt, a[index]); + if let Some(index_) = a.raw_dim().next_for(index) { + index = index_; + } + elts += 1; + } + assert_eq!(elts, a.len()); + } + } + } + } + } + + #[test] + fn test_preserve_order() + { + for swap in swaps() { + for invert in inverts() { + for &slice in &[false, true] { + let mut a = Array::from_iter(0..20).into_shape((2, 10, 1)).unwrap(); + for &(i, j) in &swap { + a.swap_axes(i, j); + } + for &i in &invert { + a.invert_axis(i); + } + if slice { + a.slice_collapse(s![.., ..;2, ..]); + } + + unsafe { + let mut iter = Baseiter::new_with_order::(a.as_mut_ptr(), a.dim, a.strides); + + // check that axes have been merged (when it's easy to check) + if a.shape() == &[2, 10, 1] && invert.is_empty() { + assert_eq!(iter.dim, Dim([1, 1, 20])); + } + + let mut index = Dim([0, 0, 0]); + let mut elts = 0; + while let Some(elt) = iter.next() { + assert_eq!(*elt, a[index]); + if let Some(index_) = a.raw_dim().next_for(index) { + index = index_; + } + elts += 1; + } + assert_eq!(elts, a.len()); + } + } + } + } + } +}