Skip to content

Commit

Permalink
FEAT: Implement order optimizations for Baseiter
Browse files Browse the repository at this point in the history
Implement axis merging - this preserves order of elements in the
iteration but might simplify iteration. For example, in a contiguous
matrix, a shape like [3, 4] can be merged into [1, 12].

Also allow arbitrary order optimization - we then try to iterate in
memory order by sorting all axes, currently.
  • Loading branch information
bluss committed Mar 31, 2024
1 parent d42ee96 commit 19fab7e
Showing 1 changed file with 227 additions and 8 deletions.
235 changes: 227 additions & 8 deletions src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,59 @@ use alloc::vec::Vec;
use std::iter::FromIterator;
use std::marker::PhantomData;
use std::ptr;
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};

use crate::Ix1;
use crate::imp_prelude::*;

use super::{ArrayBase, ArrayView, ArrayViewMut, Axis, Data, NdProducer, RemoveAxis};
use super::{Dimension, Ix, Ixs};
use super::NdProducer;

pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
pub use self::into_iter::IntoIter;
pub use self::lanes::{Lanes, LanesMut};
pub use self::windows::Windows;

use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
use crate::dimension;

/// No traversal optmizations that would change element order or axis dimensions are permitted.
///
/// This option is suitable for example for the indexed iterator.
pub(crate) enum NoOptimization {}

/// Preserve element iteration order, but modify dimensions if profitable; for example we can
/// change from shape [10, 1] to [1, 10], because that axis has len == 1, without consequence here.
///
/// This option is suitable for example for the default .iter() iterator.
pub(crate) enum PreserveOrder {}

/// Allow use of arbitrary element iteration order
///
/// This option is suitable for example for an arbitrary order iterator.
pub(crate) enum ArbitraryOrder {}

pub(crate) trait OrderOption
{
const ALLOW_REMOVE_REDUNDANT_AXES: bool = false;
const ALLOW_ARBITRARY_ORDER: bool = false;
}

impl OrderOption for NoOptimization {}

impl OrderOption for PreserveOrder
{
const ALLOW_REMOVE_REDUNDANT_AXES: bool = true;
}

impl OrderOption for ArbitraryOrder
{
const ALLOW_REMOVE_REDUNDANT_AXES: bool = true;
const ALLOW_ARBITRARY_ORDER: bool = true;
}

/// Base for iterators over all axes.
///
/// Iterator element type is `*mut A`.
///
/// `F` is for layout/iteration order flags
#[derive(Debug)]
pub(crate) struct Baseiter<A, D>
{
Expand All @@ -50,13 +87,46 @@ impl<A, D: Dimension> Baseiter<A, D>
/// to be correct to avoid performing an unsafe pointer offset while
/// iterating.
#[inline]
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D>
pub unsafe fn new(ptr: *mut A, dim: D, strides: D) -> Baseiter<A, D>
{
Self::new_with_order::<NoOptimization>(ptr, dim, strides)
}
}

impl<A, D: Dimension> Baseiter<A, D>
{
/// Creating a Baseiter is unsafe because shape and stride parameters need
/// to be correct to avoid performing an unsafe pointer offset while
/// iterating.
#[inline]
pub unsafe fn new_with_order<Flags: OrderOption>(mut ptr: *mut A, mut dim: D, mut strides: D) -> Baseiter<A, D>
{
debug_assert_eq!(dim.ndim(), strides.ndim());
if Flags::ALLOW_ARBITRARY_ORDER {
// iterate in memory order; merge axes if possible
// make all axes positive and put the pointer back to the first element in memory
let offset = dimension::offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides);
ptr = ptr.sub(offset);
for i in 0..strides.ndim() {
let s = strides.get_stride(Axis(i));
if s < 0 {
strides.set_stride(Axis(i), -s);
}
}
dimension::sort_axes_to_standard(&mut dim, &mut strides);
}

if Flags::ALLOW_REMOVE_REDUNDANT_AXES {
// preserve element order but shift dimensions
dimension::merge_axes_from_the_back(&mut dim, &mut strides);
dimension::squeeze(&mut dim, &mut strides);
}

Baseiter {
ptr,
index: len.first_index(),
dim: len,
strides: stride,
index: dim.first_index(),
dim,
strides,
}
}
}
Expand Down Expand Up @@ -1585,3 +1655,152 @@ where
debug_assert_eq!(size, result.len());
result
}

#[cfg(test)]
#[cfg(feature = "std")]
mod tests
{
use super::Baseiter;
use super::{ArbitraryOrder, NoOptimization, PreserveOrder};
use crate::prelude::*;
use itertools::assert_equal;
use itertools::Itertools;

// 3-d axis swaps
fn swaps() -> impl Iterator<Item = Vec<(usize, usize)>>
{
vec![
vec![],
vec![(0, 1)],
vec![(0, 2)],
vec![(1, 2)],
vec![(0, 1), (1, 2)],
vec![(0, 1), (0, 2)],
]
.into_iter()
}

// 3-d axis inverts
fn inverts() -> impl Iterator<Item = Vec<Axis>>
{
vec![
vec![],
vec![Axis(0)],
vec![Axis(1)],
vec![Axis(2)],
vec![Axis(0), Axis(1)],
vec![Axis(0), Axis(2)],
vec![Axis(1), Axis(2)],
vec![Axis(0), Axis(1), Axis(2)],
]
.into_iter()
}

#[test]
fn test_arbitrary_order()
{
for swap in swaps() {
for invert in inverts() {
for &slice in &[false, true] {
// pattern is 0, 1; 4, 5; 8, 9; etc..
let mut a = Array::from_iter(0..24).into_shape((3, 4, 2)).unwrap();
if slice {
a.slice_collapse(s![.., ..;2, ..]);
}
for &(i, j) in &swap {
a.swap_axes(i, j);
}
for &i in &invert {
a.invert_axis(i);
}
unsafe {
// Should have in-memory order for arbitrary order
let iter = Baseiter::new_with_order::<ArbitraryOrder>(a.as_mut_ptr(), a.dim, a.strides);
if !slice {
assert_equal(iter.map(|ptr| *ptr), 0..a.len());
} else {
assert_eq!(iter.map(|ptr| *ptr).collect_vec(),
(0..a.len() * 2).filter(|&x| (x / 2) % 2 == 0).collect_vec());
}
}
}
}
}
}

#[test]
fn test_logical_order()
{
for swap in swaps() {
for invert in inverts() {
for &slice in &[false, true] {
let mut a = Array::from_iter(0..24).into_shape((3, 4, 2)).unwrap();
for &(i, j) in &swap {
a.swap_axes(i, j);
}
for &i in &invert {
a.invert_axis(i);
}
if slice {
a.slice_collapse(s![.., ..;2, ..]);
}

unsafe {
let mut iter = Baseiter::new_with_order::<NoOptimization>(a.as_mut_ptr(), a.dim, a.strides);
let mut index = Dim([0, 0, 0]);
let mut elts = 0;
while let Some(elt) = iter.next() {
assert_eq!(*elt, a[index]);
if let Some(index_) = a.raw_dim().next_for(index) {
index = index_;
}
elts += 1;
}
assert_eq!(elts, a.len());
}
}
}
}
}

#[test]
fn test_preserve_order()
{
for swap in swaps() {
for invert in inverts() {
for &slice in &[false, true] {
let mut a = Array::from_iter(0..20).into_shape((2, 10, 1)).unwrap();
for &(i, j) in &swap {
a.swap_axes(i, j);
}
for &i in &invert {
a.invert_axis(i);
}
if slice {
a.slice_collapse(s![.., ..;2, ..]);
}

unsafe {
let mut iter = Baseiter::new_with_order::<PreserveOrder>(a.as_mut_ptr(), a.dim, a.strides);

// check that axes have been merged (when it's easy to check)
if a.shape() == &[2, 10, 1] && invert.is_empty() {
assert_eq!(iter.dim, Dim([1, 1, 20]));
}

let mut index = Dim([0, 0, 0]);
let mut elts = 0;
while let Some(elt) = iter.next() {
assert_eq!(*elt, a[index]);
if let Some(index_) = a.raw_dim().next_for(index) {
index = index_;
}
elts += 1;
}
assert_eq!(elts, a.len());
}
}
}
}
}
}

0 comments on commit 19fab7e

Please sign in to comment.