Skip to content

Commit 19fab7e

Browse files
committed
FEAT: Implement order optimizations for Baseiter
Implement axis merging - this preserves the order of elements in the iteration but can simplify the iteration itself. For example, in a contiguous matrix, a shape like [3, 4] can be merged into [1, 12]. Also allow arbitrary-order optimization - for now, we then try to iterate in memory order by sorting all axes.
1 parent d42ee96 commit 19fab7e

File tree

1 file changed

+227
-8
lines changed

1 file changed

+227
-8
lines changed

src/iterators/mod.rs

+227-8
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,59 @@ use alloc::vec::Vec;
1919
use std::iter::FromIterator;
2020
use std::marker::PhantomData;
2121
use std::ptr;
22+
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
2223

23-
use crate::Ix1;
24+
use crate::imp_prelude::*;
2425

25-
use super::{ArrayBase, ArrayView, ArrayViewMut, Axis, Data, NdProducer, RemoveAxis};
26-
use super::{Dimension, Ix, Ixs};
26+
use super::NdProducer;
2727

2828
pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
2929
pub use self::into_iter::IntoIter;
3030
pub use self::lanes::{Lanes, LanesMut};
3131
pub use self::windows::Windows;
3232

33-
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
33+
use crate::dimension;
34+
35+
/// No traversal optimizations that would change element order or axis dimensions are permitted.
36+
///
37+
/// This option is suitable for example for the indexed iterator.
38+
pub(crate) enum NoOptimization {}
39+
40+
/// Preserve element iteration order, but modify dimensions if profitable; for example we can
41+
/// change from shape [10, 1] to [1, 10], because that axis has len == 1, without consequence here.
42+
///
43+
/// This option is suitable for example for the default .iter() iterator.
44+
pub(crate) enum PreserveOrder {}
45+
46+
/// Allow use of arbitrary element iteration order
47+
///
48+
/// This option is suitable for example for an arbitrary order iterator.
49+
pub(crate) enum ArbitraryOrder {}
50+
51+
pub(crate) trait OrderOption
52+
{
53+
const ALLOW_REMOVE_REDUNDANT_AXES: bool = false;
54+
const ALLOW_ARBITRARY_ORDER: bool = false;
55+
}
56+
57+
impl OrderOption for NoOptimization {}
58+
59+
impl OrderOption for PreserveOrder
60+
{
61+
const ALLOW_REMOVE_REDUNDANT_AXES: bool = true;
62+
}
63+
64+
impl OrderOption for ArbitraryOrder
65+
{
66+
const ALLOW_REMOVE_REDUNDANT_AXES: bool = true;
67+
const ALLOW_ARBITRARY_ORDER: bool = true;
68+
}
3469

3570
/// Base for iterators over all axes.
3671
///
3772
/// Iterator element type is `*mut A`.
73+
///
74+
/// Layout/iteration order flags are chosen through the `Flags` type parameter of `new_with_order`.
3875
#[derive(Debug)]
3976
pub(crate) struct Baseiter<A, D>
4077
{
@@ -50,13 +87,46 @@ impl<A, D: Dimension> Baseiter<A, D>
5087
/// to be correct to avoid performing an unsafe pointer offset while
5188
/// iterating.
5289
#[inline]
53-
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D>
90+
pub unsafe fn new(ptr: *mut A, dim: D, strides: D) -> Baseiter<A, D>
5491
{
92+
Self::new_with_order::<NoOptimization>(ptr, dim, strides)
93+
}
94+
}
95+
96+
impl<A, D: Dimension> Baseiter<A, D>
97+
{
98+
/// Creating a Baseiter is unsafe because shape and stride parameters need
99+
/// to be correct to avoid performing an unsafe pointer offset while
100+
/// iterating.
101+
#[inline]
102+
pub unsafe fn new_with_order<Flags: OrderOption>(mut ptr: *mut A, mut dim: D, mut strides: D) -> Baseiter<A, D>
103+
{
104+
debug_assert_eq!(dim.ndim(), strides.ndim());
105+
if Flags::ALLOW_ARBITRARY_ORDER {
106+
// iterate in memory order; merge axes if possible
107+
// make all strides positive and move the pointer back to the first element in memory
108+
let offset = dimension::offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides);
109+
ptr = ptr.sub(offset);
110+
for i in 0..strides.ndim() {
111+
let s = strides.get_stride(Axis(i));
112+
if s < 0 {
113+
strides.set_stride(Axis(i), -s);
114+
}
115+
}
116+
dimension::sort_axes_to_standard(&mut dim, &mut strides);
117+
}
118+
119+
if Flags::ALLOW_REMOVE_REDUNDANT_AXES {
120+
// preserve element order but shift dimensions
121+
dimension::merge_axes_from_the_back(&mut dim, &mut strides);
122+
dimension::squeeze(&mut dim, &mut strides);
123+
}
124+
55125
Baseiter {
56126
ptr,
57-
index: len.first_index(),
58-
dim: len,
59-
strides: stride,
127+
index: dim.first_index(),
128+
dim,
129+
strides,
60130
}
61131
}
62132
}
@@ -1585,3 +1655,152 @@ where
15851655
debug_assert_eq!(size, result.len());
15861656
result
15871657
}
1658+
1659+
#[cfg(test)]
1660+
#[cfg(feature = "std")]
1661+
mod tests
1662+
{
1663+
use super::Baseiter;
1664+
use super::{ArbitraryOrder, NoOptimization, PreserveOrder};
1665+
use crate::prelude::*;
1666+
use itertools::assert_equal;
1667+
use itertools::Itertools;
1668+
1669+
// 3-d axis swaps
1670+
fn swaps() -> impl Iterator<Item = Vec<(usize, usize)>>
1671+
{
1672+
vec![
1673+
vec![],
1674+
vec![(0, 1)],
1675+
vec![(0, 2)],
1676+
vec![(1, 2)],
1677+
vec![(0, 1), (1, 2)],
1678+
vec![(0, 1), (0, 2)],
1679+
]
1680+
.into_iter()
1681+
}
1682+
1683+
// 3-d axis inverts
1684+
fn inverts() -> impl Iterator<Item = Vec<Axis>>
1685+
{
1686+
vec![
1687+
vec![],
1688+
vec![Axis(0)],
1689+
vec![Axis(1)],
1690+
vec![Axis(2)],
1691+
vec![Axis(0), Axis(1)],
1692+
vec![Axis(0), Axis(2)],
1693+
vec![Axis(1), Axis(2)],
1694+
vec![Axis(0), Axis(1), Axis(2)],
1695+
]
1696+
.into_iter()
1697+
}
1698+
1699+
#[test]
1700+
fn test_arbitrary_order()
1701+
{
1702+
for swap in swaps() {
1703+
for invert in inverts() {
1704+
for &slice in &[false, true] {
1705+
// when sliced, the expected pattern is 0, 1; 4, 5; 8, 9; etc.
1706+
let mut a = Array::from_iter(0..24).into_shape((3, 4, 2)).unwrap();
1707+
if slice {
1708+
a.slice_collapse(s![.., ..;2, ..]);
1709+
}
1710+
for &(i, j) in &swap {
1711+
a.swap_axes(i, j);
1712+
}
1713+
for &i in &invert {
1714+
a.invert_axis(i);
1715+
}
1716+
unsafe {
1717+
// Should have in-memory order for arbitrary order
1718+
let iter = Baseiter::new_with_order::<ArbitraryOrder>(a.as_mut_ptr(), a.dim, a.strides);
1719+
if !slice {
1720+
assert_equal(iter.map(|ptr| *ptr), 0..a.len());
1721+
} else {
1722+
assert_eq!(iter.map(|ptr| *ptr).collect_vec(),
1723+
(0..a.len() * 2).filter(|&x| (x / 2) % 2 == 0).collect_vec());
1724+
}
1725+
}
1726+
}
1727+
}
1728+
}
1729+
}
1730+
1731+
#[test]
1732+
fn test_logical_order()
1733+
{
1734+
for swap in swaps() {
1735+
for invert in inverts() {
1736+
for &slice in &[false, true] {
1737+
let mut a = Array::from_iter(0..24).into_shape((3, 4, 2)).unwrap();
1738+
for &(i, j) in &swap {
1739+
a.swap_axes(i, j);
1740+
}
1741+
for &i in &invert {
1742+
a.invert_axis(i);
1743+
}
1744+
if slice {
1745+
a.slice_collapse(s![.., ..;2, ..]);
1746+
}
1747+
1748+
unsafe {
1749+
let mut iter = Baseiter::new_with_order::<NoOptimization>(a.as_mut_ptr(), a.dim, a.strides);
1750+
let mut index = Dim([0, 0, 0]);
1751+
let mut elts = 0;
1752+
while let Some(elt) = iter.next() {
1753+
assert_eq!(*elt, a[index]);
1754+
if let Some(index_) = a.raw_dim().next_for(index) {
1755+
index = index_;
1756+
}
1757+
elts += 1;
1758+
}
1759+
assert_eq!(elts, a.len());
1760+
}
1761+
}
1762+
}
1763+
}
1764+
}
1765+
1766+
#[test]
1767+
fn test_preserve_order()
1768+
{
1769+
for swap in swaps() {
1770+
for invert in inverts() {
1771+
for &slice in &[false, true] {
1772+
let mut a = Array::from_iter(0..20).into_shape((2, 10, 1)).unwrap();
1773+
for &(i, j) in &swap {
1774+
a.swap_axes(i, j);
1775+
}
1776+
for &i in &invert {
1777+
a.invert_axis(i);
1778+
}
1779+
if slice {
1780+
a.slice_collapse(s![.., ..;2, ..]);
1781+
}
1782+
1783+
unsafe {
1784+
let mut iter = Baseiter::new_with_order::<PreserveOrder>(a.as_mut_ptr(), a.dim, a.strides);
1785+
1786+
// check that axes have been merged (when it's easy to check)
1787+
if a.shape() == &[2, 10, 1] && invert.is_empty() {
1788+
assert_eq!(iter.dim, Dim([1, 1, 20]));
1789+
}
1790+
1791+
let mut index = Dim([0, 0, 0]);
1792+
let mut elts = 0;
1793+
while let Some(elt) = iter.next() {
1794+
assert_eq!(*elt, a[index]);
1795+
if let Some(index_) = a.raw_dim().next_for(index) {
1796+
index = index_;
1797+
}
1798+
elts += 1;
1799+
}
1800+
assert_eq!(elts, a.len());
1801+
}
1802+
}
1803+
}
1804+
}
1805+
}
1806+
}

0 commit comments

Comments
 (0)