diff --git a/benches/to_shape.rs b/benches/to_shape.rs new file mode 100644 index 000000000..a048eb774 --- /dev/null +++ b/benches/to_shape.rs @@ -0,0 +1,106 @@ +#![feature(test)] + +extern crate test; +use test::Bencher; + +use ndarray::prelude::*; +use ndarray::Order; + +#[bench] +fn to_shape2_1(bench: &mut Bencher) { + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape(4 * 5).unwrap() + }); +} + +#[bench] +fn to_shape2_2_same(bench: &mut Bencher) { + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((4, 5)).unwrap() + }); +} + +#[bench] +fn to_shape2_2_flip(bench: &mut Bencher) { + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((5, 4)).unwrap() + }); +} + +#[bench] +fn to_shape2_3(bench: &mut Bencher) { + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((2, 5, 2)).unwrap() + }); +} + +#[bench] +fn to_shape3_1(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape(3 * 4 * 5).unwrap() + }); +} + +#[bench] +fn to_shape3_2_order(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((12, 5)).unwrap() + }); +} + +#[bench] +fn to_shape3_2_outoforder(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((4, 15)).unwrap() + }); +} + +#[bench] +fn to_shape3_3c(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((3, 4, 5)).unwrap() + }); +} + +#[bench] +fn to_shape3_3f(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5).f()); + let view = a.view(); + bench.iter(|| { + view.to_shape(((3, 4, 5), Order::F)).unwrap() + }); +} + +#[bench] +fn to_shape3_4c(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape(((2, 3, 2, 5), Order::C)).unwrap() + }); +} + +#[bench] +fn to_shape3_4f(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5).f()); + let view = a.view(); + bench.iter(|| { + view.to_shape(((2, 3, 2, 5), Order::F)).unwrap() + }); +} diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index f4f46e764..4aa7c6641 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -9,10 +9,10 @@ use crate::error::{from_kind, ErrorKind, ShapeError}; use crate::slice::SliceArg; use crate::{Ix, Ixs, Slice, SliceInfoElem}; +use crate::shape_builder::Strides; use num_integer::div_floor; pub use self::axes::{Axes, AxisDescription}; -pub(crate) use self::axes::axes_of; pub use self::axis::Axis; pub use self::broadcast::DimMax; pub use self::conversion::IntoDimension; @@ -23,7 +23,8 @@ pub use self::ndindex::NdIndex; pub use self::ops::DimAdd; pub use self::remove_axis::RemoveAxis; -use crate::shape_builder::Strides; +pub(crate) use self::axes::axes_of; +pub(crate) use self::reshape::reshape_dim; use std::isize; use std::mem; @@ -40,6 +41,8 @@ mod dynindeximpl; mod ndindex; mod ops; mod remove_axis; +pub(crate) mod reshape; +mod sequence; /// Calculate offset from `Ix` stride converting sign properly #[inline(always)] diff --git a/src/dimension/reshape.rs b/src/dimension/reshape.rs new file mode 100644 index 000000000..c6e08848d --- /dev/null +++ b/src/dimension/reshape.rs @@ -0,0 +1,241 @@ + +use crate::{Dimension, Order, ShapeError, ErrorKind}; +use crate::dimension::sequence::{Sequence, SequenceMut, Forward, Reverse}; + +#[inline] +pub(crate) fn reshape_dim(from: &D, strides: &D, to: &E, order: Order) + -> Result +where + D: Dimension, + E: Dimension, +{ + debug_assert_eq!(from.ndim(), strides.ndim()); + let mut to_strides = E::zeros(to.ndim()); + match order { + Order::RowMajor => { + reshape_dim_c(&Forward(from), &Forward(strides), + &Forward(to), Forward(&mut to_strides))?; + } + Order::ColumnMajor => { + reshape_dim_c(&Reverse(from), &Reverse(strides), + &Reverse(to), Reverse(&mut to_strides))?; + } + } + Ok(to_strides) +} + +/// Try to reshape an array with dimensions `from_dim` and strides `from_strides` to the new +/// dimension `to_dim`, while keeping the same layout of elements in memory. The strides needed +/// if this is possible are stored into `to_strides`. +/// +/// This function uses RowMajor index ordering if the inputs are read in the forward direction +/// (index 0 is axis 0 etc) and ColumnMajor index ordering if the inputs are read in reversed +/// direction (as made possible with the Sequence trait). +/// +/// Preconditions: +/// +/// 1. from_dim and to_dim are valid dimensions (product of all non-zero axes +/// fits in isize::MAX). +/// 2. from_dim and to_dim are don't have any axes that are zero (that should be handled before +/// this function). +/// 3. `to_strides` should be an all-zeros or all-ones dimension of the right dimensionality +/// (but it will be overwritten after successful exit of this function). +/// +/// This function returns: +/// +/// - IncompatibleShape if the two shapes are not of matching number of elements +/// - IncompatibleLayout if the input shape and stride can not be remapped to the output shape +/// without moving the array data into a new memory layout. +/// - Ok if the from dim could be mapped to the new to dim. +fn reshape_dim_c(from_dim: &D, from_strides: &D, to_dim: &E, mut to_strides: E2) + -> Result<(), ShapeError> +where + D: Sequence, + E: Sequence, + E2: SequenceMut, +{ + // cursor indexes into the from and to dimensions + let mut fi = 0; // index into `from_dim` + let mut ti = 0; // index into `to_dim`. + + while fi < from_dim.len() && ti < to_dim.len() { + let mut fd = from_dim[fi]; + let mut fs = from_strides[fi] as isize; + let mut td = to_dim[ti]; + + if fd == td { + to_strides[ti] = from_strides[fi]; + fi += 1; + ti += 1; + continue + } + + if fd == 1 { + fi += 1; + continue; + } + + if td == 1 { + to_strides[ti] = 1; + ti += 1; + continue; + } + + if fd == 0 || td == 0 { + debug_assert!(false, "zero dim not handled by this function"); + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + // stride times element count is to be distributed out over a combination of axes. + let mut fstride_whole = fs * (fd as isize); + let mut fd_product = fd; // cumulative product of axis lengths in the combination (from) + let mut td_product = td; // cumulative product of axis lengths in the combination (to) + + // The two axis lengths are not a match, so try to combine multiple axes + // to get it to match up. + while fd_product != td_product { + if fd_product < td_product { + // Take another axis on the from side + fi += 1; + if fi >= from_dim.len() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + fd = from_dim[fi]; + fd_product *= fd; + if fd > 1 { + let fs_old = fs; + fs = from_strides[fi] as isize; + // check if this axis and the next are contiguous together + if fs_old != fd as isize * fs { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)); + } + } + } else { + // Take another axis on the `to` side + // First assign the stride to the axis we leave behind + fstride_whole /= td as isize; + to_strides[ti] = fstride_whole as usize; + ti += 1; + if ti >= to_dim.len() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + td = to_dim[ti]; + td_product *= td; + } + } + + fstride_whole /= td as isize; + to_strides[ti] = fstride_whole as usize; + + fi += 1; + ti += 1; + } + + // skip past 1-dims at the end + while fi < from_dim.len() && from_dim[fi] == 1 { + fi += 1; + } + + while ti < to_dim.len() && to_dim[ti] == 1 { + to_strides[ti] = 1; + ti += 1; + } + + if fi < from_dim.len() || ti < to_dim.len() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + Ok(()) +} + +#[cfg(feature = "std")] +#[test] +fn test_reshape() { + use crate::Dim; + + macro_rules! test_reshape { + (fail $order:ident from $from:expr, $stride:expr, to $to:expr) => { + let res = reshape_dim(&Dim($from), &Dim($stride), &Dim($to), Order::$order); + println!("Reshape {:?} {:?} to {:?}, order {:?}\n => {:?}", + $from, $stride, $to, Order::$order, res); + let _res = res.expect_err("Expected failed reshape"); + }; + (ok $order:ident from $from:expr, $stride:expr, to $to:expr, $to_stride:expr) => {{ + let res = reshape_dim(&Dim($from), &Dim($stride), &Dim($to), Order::$order); + println!("Reshape {:?} {:?} to {:?}, order {:?}\n => {:?}", + $from, $stride, $to, Order::$order, res); + println!("default stride for from dim: {:?}", Dim($from).default_strides()); + println!("default stride for to dim: {:?}", Dim($to).default_strides()); + let res = res.expect("Expected successful reshape"); + assert_eq!(res, Dim($to_stride), "mismatch in strides"); + }}; + } + + test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [1, 2, 3], [6, 3, 1]); + test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [2, 3], [3, 1]); + test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [6], [1]); + test_reshape!(fail C from [1, 2, 3], [6, 3, 1], to [1]); + test_reshape!(fail F from [1, 2, 3], [6, 3, 1], to [1]); + + test_reshape!(ok C from [6], [1], to [3, 2], [2, 1]); + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [4, 15], [15, 1]); + + test_reshape!(ok C from [4, 4, 4], [16, 4, 1], to [16, 4], [4, 1]); + + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 4, 1], [8, 4, 1, 1]); + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 4], [8, 4, 1]); + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 2, 2], [8, 4, 2, 1]); + + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 1, 4], [8, 4, 1, 1]); + + test_reshape!(ok C from [4, 4, 4], [16, 4, 1], to [16, 4], [4, 1]); + test_reshape!(ok C from [3, 4, 4], [16, 4, 1], to [3, 16], [16, 1]); + + test_reshape!(ok C from [4, 4], [8, 1], to [2, 2, 2, 2], [16, 8, 2, 1]); + + test_reshape!(fail C from [4, 4], [8, 1], to [2, 1, 4, 2]); + + test_reshape!(ok C from [16], [4], to [2, 2, 4], [32, 16, 4]); + test_reshape!(ok C from [16], [-4isize as usize], to [2, 2, 4], + [-32isize as usize, -16isize as usize, -4isize as usize]); + test_reshape!(ok F from [16], [4], to [2, 2, 4], [4, 8, 16]); + test_reshape!(ok F from [16], [-4isize as usize], to [2, 2, 4], + [-4isize as usize, -8isize as usize, -16isize as usize]); + + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [12, 5], [5, 1]); + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [4, 15], [15, 1]); + test_reshape!(fail F from [3, 4, 5], [20, 5, 1], to [4, 15]); + test_reshape!(ok C from [3, 4, 5, 7], [140, 35, 7, 1], to [28, 15], [15, 1]); + + // preserve stride if shape matches + test_reshape!(ok C from [10], [2], to [10], [2]); + test_reshape!(ok F from [10], [2], to [10], [2]); + test_reshape!(ok C from [2, 10], [1, 2], to [2, 10], [1, 2]); + test_reshape!(ok F from [2, 10], [1, 2], to [2, 10], [1, 2]); + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [3, 4, 5], [20, 5, 1]); + test_reshape!(ok F from [3, 4, 5], [20, 5, 1], to [3, 4, 5], [20, 5, 1]); + + test_reshape!(ok C from [3, 4, 5], [4, 1, 1], to [12, 5], [1, 1]); + test_reshape!(ok F from [3, 4, 5], [1, 3, 12], to [12, 5], [1, 12]); + test_reshape!(ok F from [3, 4, 5], [1, 3, 1], to [12, 5], [1, 1]); + + // broadcast shapes + test_reshape!(ok C from [3, 4, 5, 7], [0, 0, 7, 1], to [12, 35], [0, 1]); + test_reshape!(fail C from [3, 4, 5, 7], [0, 0, 7, 1], to [28, 15]); + + // one-filled shapes + test_reshape!(ok C from [10], [1], to [1, 10, 1, 1, 1], [1, 1, 1, 1, 1]); + test_reshape!(ok F from [10], [1], to [1, 10, 1, 1, 1], [1, 1, 1, 1, 1]); + test_reshape!(ok C from [1, 10], [10, 1], to [1, 10, 1, 1, 1], [10, 1, 1, 1, 1]); + test_reshape!(ok F from [1, 10], [10, 1], to [1, 10, 1, 1, 1], [10, 1, 1, 1, 1]); + test_reshape!(ok C from [1, 10], [1, 1], to [1, 5, 1, 1, 2], [1, 2, 2, 2, 1]); + test_reshape!(ok F from [1, 10], [1, 1], to [1, 5, 1, 1, 2], [1, 1, 5, 5, 5]); + test_reshape!(ok C from [10, 1, 1, 1, 1], [1, 1, 1, 1, 1], to [10], [1]); + test_reshape!(ok F from [10, 1, 1, 1, 1], [1, 1, 1, 1, 1], to [10], [1]); + test_reshape!(ok C from [1, 5, 1, 2, 1], [1, 2, 1, 1, 1], to [10], [1]); + test_reshape!(fail F from [1, 5, 1, 2, 1], [1, 2, 1, 1, 1], to [10]); + test_reshape!(ok F from [1, 5, 1, 2, 1], [1, 1, 1, 5, 1], to [10], [1]); + test_reshape!(fail C from [1, 5, 1, 2, 1], [1, 1, 1, 5, 1], to [10]); +} + diff --git a/src/dimension/sequence.rs b/src/dimension/sequence.rs new file mode 100644 index 000000000..835e00d18 --- /dev/null +++ b/src/dimension/sequence.rs @@ -0,0 +1,109 @@ +use std::ops::Index; +use std::ops::IndexMut; + +use crate::dimension::Dimension; + +pub(in crate::dimension) struct Forward(pub(crate) D); +pub(in crate::dimension) struct Reverse(pub(crate) D); + +impl Index for Forward<&D> +where + D: Dimension, +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize { + &self.0[index] + } +} + +impl Index for Forward<&mut D> +where + D: Dimension, +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize { + &self.0[index] + } +} + +impl IndexMut for Forward<&mut D> +where + D: Dimension, +{ + #[inline] + fn index_mut(&mut self, index: usize) -> &mut usize { + &mut self.0[index] + } +} + +impl Index for Reverse<&D> +where + D: Dimension, +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize { + &self.0[self.len() - index - 1] + } +} + +impl Index for Reverse<&mut D> +where + D: Dimension, +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize { + &self.0[self.len() - index - 1] + } +} + +impl IndexMut for Reverse<&mut D> +where + D: Dimension, +{ + #[inline] + fn index_mut(&mut self, index: usize) -> &mut usize { + let len = self.len(); + &mut self.0[len - index - 1] + } +} + +/// Indexable sequence with length +pub(in crate::dimension) trait Sequence: Index { + fn len(&self) -> usize; +} + +/// Indexable sequence with length (mut) +pub(in crate::dimension) trait SequenceMut: Sequence + IndexMut { } + +impl Sequence for Forward<&D> where D: Dimension { + #[inline] + fn len(&self) -> usize { self.0.ndim() } +} + +impl Sequence for Forward<&mut D> where D: Dimension { + #[inline] + fn len(&self) -> usize { self.0.ndim() } +} + +impl SequenceMut for Forward<&mut D> where D: Dimension { } + +impl Sequence for Reverse<&D> where D: Dimension { + #[inline] + fn len(&self) -> usize { self.0.ndim() } +} + +impl Sequence for Reverse<&mut D> where D: Dimension { + #[inline] + fn len(&self) -> usize { self.0.ndim() } +} + +impl SequenceMut for Reverse<&mut D> where D: Dimension { } + diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 9ef4277de..6c51b4515 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -23,11 +23,14 @@ use crate::dimension::{ offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; use crate::dimension::broadcast::co_broadcast; +use crate::dimension::reshape_dim; use crate::error::{self, ErrorKind, ShapeError, from_kind}; use crate::math_cell::MathCell; use crate::itertools::zip; -use crate::zip::{IntoNdProducer, Zip}; use crate::AxisDescription; +use crate::order::Order; +use crate::shape_builder::ShapeArg; +use crate::zip::{IntoNdProducer, Zip}; use crate::iter::{ AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut, @@ -1577,6 +1580,101 @@ where } } + /// Transform the array into `new_shape`; any shape with the same number of elements is + /// accepted. + /// + /// `order` specifies the *logical* order in which the array is to be read and reshaped. + /// The array is returned as a `CowArray`; a view if possible, otherwise an owned array. + /// + /// For example, when starting from the one-dimensional sequence 1 2 3 4 5 6, it would be + /// understood as a 2 x 3 array in row major ("C") order this way: + /// + /// ```text + /// 1 2 3 + /// 4 5 6 + /// ``` + /// + /// and as 2 x 3 in column major ("F") order this way: + /// + /// ```text + /// 1 3 5 + /// 2 4 6 + /// ``` + /// + /// This example should show that any time we "reflow" the elements in the array to a different + /// number of rows and columns (or more axes if applicable), it is important to pick an index + /// ordering, and that's the reason for the function parameter for `order`. + /// + /// **Errors** if the new shape doesn't have the same number of elements as the array's current + /// shape. + /// + /// ``` + /// use ndarray::array; + /// use ndarray::Order; + /// + /// assert!( + /// array![1., 2., 3., 4., 5., 6.].to_shape(((2, 3), Order::RowMajor)).unwrap() + /// == array![[1., 2., 3.], + /// [4., 5., 6.]] + /// ); + /// + /// assert!( + /// array![1., 2., 3., 4., 5., 6.].to_shape(((2, 3), Order::ColumnMajor)).unwrap() + /// == array![[1., 3., 5.], + /// [2., 4., 6.]] + /// ); + /// ``` + pub fn to_shape(&self, new_shape: E) -> Result, ShapeError> + where + E: ShapeArg, + A: Clone, + S: Data, + { + let (shape, order) = new_shape.into_shape_and_order(); + self.to_shape_order(shape, order.unwrap_or(Order::RowMajor)) + } + + fn to_shape_order(&self, shape: E, order: Order) + -> Result, ShapeError> + where + E: Dimension, + A: Clone, + S: Data, + { + let len = self.dim.size(); + if size_of_shape_checked(&shape) != Ok(len) { + return Err(error::incompatible_shapes(&self.dim, &shape)); + } + + // Create a view if the length is 0, safe because the array and new shape is empty. + if len == 0 { + unsafe { + return Ok(CowArray::from(ArrayView::from_shape_ptr(shape, self.as_ptr()))); + } + } + + // Try to reshape the array as a view into the existing data + match reshape_dim(&self.dim, &self.strides, &shape, order) { + Ok(to_strides) => unsafe { + return Ok(CowArray::from(ArrayView::new(self.ptr, shape, to_strides))); + } + Err(err) if err.kind() == ErrorKind::IncompatibleShape => { + return Err(error::incompatible_shapes(&self.dim, &shape)); + } + _otherwise => { } + } + + // otherwise create a new array and copy the elements + unsafe { + let (shape, view) = match order { + Order::RowMajor => (shape.set_f(false), self.view()), + Order::ColumnMajor => (shape.set_f(true), self.t()), + }; + Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked( + shape, view.into_iter(), A::clone))) + } + } + /// Transform the array into `shape`; any shape with the same number of /// elements is accepted, but the source array or view must be in standard /// or column-major (Fortran) layout. diff --git a/src/lib.rs b/src/lib.rs index a4c59eb66..dfce924e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -145,6 +145,7 @@ pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; +pub use crate::order::Order; pub use crate::slice::{ MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceInfoElem, SliceNextDim, }; @@ -162,7 +163,7 @@ pub use crate::stacking::{concatenate, stack, stack_new_axis}; pub use crate::math_cell::MathCell; pub use crate::impl_views::IndexLonger; -pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape}; +pub use crate::shape_builder::{Shape, ShapeBuilder, ShapeArg, StrideShape}; #[macro_use] mod macro_utils; @@ -202,6 +203,7 @@ mod linspace; mod logspace; mod math_cell; mod numeric_util; +mod order; mod partial; mod shape_builder; #[macro_use] diff --git a/src/order.rs b/src/order.rs new file mode 100644 index 000000000..e8d9c8db1 --- /dev/null +++ b/src/order.rs @@ -0,0 +1,83 @@ + +/// Array order +/// +/// Order refers to indexing order, or how a linear sequence is translated +/// into a two-dimensional or multi-dimensional array. +/// +/// - `RowMajor` means that the index along the row is the most rapidly changing +/// - `ColumnMajor` means that the index along the column is the most rapidly changing +/// +/// Given a sequence like: 1, 2, 3, 4, 5, 6 +/// +/// If it is laid it out in a 2 x 3 matrix using row major ordering, it results in: +/// +/// ```text +/// 1 2 3 +/// 4 5 6 +/// ``` +/// +/// If it is laid using column major ordering, it results in: +/// +/// ```text +/// 1 3 5 +/// 2 4 6 +/// ``` +/// +/// It can be seen as filling in "rows first" or "columns first". +/// +/// `Order` can be used both to refer to logical ordering as well as memory ordering or memory +/// layout. The orderings have common short names, also seen in other environments, where +/// row major is called "C" order (after the C programming language) and column major is called "F" +/// or "Fortran" order. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum Order { + /// Row major or "C" order + RowMajor, + /// Column major or "F" order + ColumnMajor, +} + +impl Order { + /// "C" is an alias for row major ordering + pub const C: Order = Order::RowMajor; + + /// "F" (for Fortran) is an alias for column major ordering + pub const F: Order = Order::ColumnMajor; + + /// Return true if input is Order::RowMajor, false otherwise + #[inline] + pub fn is_row_major(self) -> bool { + match self { + Order::RowMajor => true, + Order::ColumnMajor => false, + } + } + + /// Return true if input is Order::ColumnMajor, false otherwise + #[inline] + pub fn is_column_major(self) -> bool { + !self.is_row_major() + } + + /// Return Order::RowMajor if the input is true, Order::ColumnMajor otherwise + #[inline] + pub fn row_major(row_major: bool) -> Order { + if row_major { Order::RowMajor } else { Order::ColumnMajor } + } + + /// Return Order::ColumnMajor if the input is true, Order::RowMajor otherwise + #[inline] + pub fn column_major(column_major: bool) -> Order { + Self::row_major(!column_major) + } + + /// Return the transpose: row major becomes column major and vice versa. + #[inline] + pub fn transpose(self) -> Order { + match self { + Order::RowMajor => Order::ColumnMajor, + Order::ColumnMajor => Order::RowMajor, + } + } +} diff --git a/src/shape_builder.rs b/src/shape_builder.rs index dcfddc1b9..470374077 100644 --- a/src/shape_builder.rs +++ b/src/shape_builder.rs @@ -1,5 +1,6 @@ use crate::dimension::IntoDimension; use crate::Dimension; +use crate::order::Order; /// A contiguous array shape of n dimensions. /// @@ -184,3 +185,33 @@ where self.dim.size() } } + + +/// Array shape argument with optional order parameter +/// +/// Shape or array dimension argument, with optional [`Order`] parameter. +/// +/// This is an argument conversion trait that is used to accept an array shape and +/// (optionally) an ordering argument. +/// +/// See for example [`.to_shape()`](crate::ArrayBase::to_shape). +pub trait ShapeArg { + type Dim: Dimension; + fn into_shape_and_order(self) -> (Self::Dim, Option); +} + +impl ShapeArg for T where T: IntoDimension { + type Dim = T::Dim; + + fn into_shape_and_order(self) -> (Self::Dim, Option) { + (self.into_dimension(), None) + } +} + +impl ShapeArg for (T, Order) where T: IntoDimension { + type Dim = T::Dim; + + fn into_shape_and_order(self) -> (Self::Dim, Option) { + (self.0.into_dimension(), Some(self.1)) + } +} diff --git a/tests/array.rs b/tests/array.rs index 976824dfe..edd58adbc 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -8,7 +8,7 @@ )] use defmac::defmac; -use itertools::{enumerate, zip, Itertools}; +use itertools::{zip, Itertools}; use ndarray::prelude::*; use ndarray::{arr3, rcarr2}; use ndarray::indices; @@ -1370,64 +1370,6 @@ fn transpose_view_mut() { assert_eq!(at, arr2(&[[1, 4], [2, 5], [3, 7]])); } -#[test] -fn reshape() { - let data = [1, 2, 3, 4, 5, 6, 7, 8]; - let v = aview1(&data); - let u = v.into_shape((3, 3)); - assert!(u.is_err()); - let u = v.into_shape((2, 2, 2)); - assert!(u.is_ok()); - let u = u.unwrap(); - assert_eq!(u.shape(), &[2, 2, 2]); - let s = u.into_shape((4, 2)).unwrap(); - assert_eq!(s.shape(), &[4, 2]); - assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); -} - -#[test] -#[should_panic(expected = "IncompatibleShape")] -fn reshape_error1() { - let data = [1, 2, 3, 4, 5, 6, 7, 8]; - let v = aview1(&data); - let _u = v.into_shape((2, 5)).unwrap(); -} - -#[test] -#[should_panic(expected = "IncompatibleLayout")] -fn reshape_error2() { - let data = [1, 2, 3, 4, 5, 6, 7, 8]; - let v = aview1(&data); - let mut u = v.into_shape((2, 2, 2)).unwrap(); - u.swap_axes(0, 1); - let _s = u.into_shape((2, 4)).unwrap(); -} - -#[test] -fn reshape_f() { - let mut u = Array::zeros((3, 4).f()); - for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) { - *elt = i as i32; - } - let v = u.view(); - println!("{:?}", v); - - // noop ok - let v2 = v.into_shape((3, 4)); - assert!(v2.is_ok()); - assert_eq!(v, v2.unwrap()); - - let u = v.into_shape((3, 2, 2)); - assert!(u.is_ok()); - let u = u.unwrap(); - println!("{:?}", u); - assert_eq!(u.shape(), &[3, 2, 2]); - let s = u.into_shape((4, 3)).unwrap(); - println!("{:?}", s); - assert_eq!(s.shape(), &[4, 3]); - assert_eq!(s, aview2(&[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]])); -} - #[test] #[allow(clippy::cognitive_complexity)] fn insert_axis() { diff --git a/tests/reshape.rs b/tests/reshape.rs new file mode 100644 index 000000000..21fe407ea --- /dev/null +++ b/tests/reshape.rs @@ -0,0 +1,232 @@ +use ndarray::prelude::*; + +use itertools::enumerate; + +use ndarray::Order; + +#[test] +fn reshape() { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.into_shape((3, 3)); + assert!(u.is_err()); + let u = v.into_shape((2, 2, 2)); + assert!(u.is_ok()); + let u = u.unwrap(); + assert_eq!(u.shape(), &[2, 2, 2]); + let s = u.into_shape((4, 2)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); +} + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn reshape_error1() { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let _u = v.into_shape((2, 5)).unwrap(); +} + +#[test] +#[should_panic(expected = "IncompatibleLayout")] +fn reshape_error2() { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let mut u = v.into_shape((2, 2, 2)).unwrap(); + u.swap_axes(0, 1); + let _s = u.into_shape((2, 4)).unwrap(); +} + +#[test] +fn reshape_f() { + let mut u = Array::zeros((3, 4).f()); + for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) { + *elt = i as i32; + } + let v = u.view(); + println!("{:?}", v); + + // noop ok + let v2 = v.into_shape((3, 4)); + assert!(v2.is_ok()); + assert_eq!(v, v2.unwrap()); + + let u = v.into_shape((3, 2, 2)); + assert!(u.is_ok()); + let u = u.unwrap(); + println!("{:?}", u); + assert_eq!(u.shape(), &[3, 2, 2]); + let s = u.into_shape((4, 3)).unwrap(); + println!("{:?}", s); + assert_eq!(s.shape(), &[4, 3]); + assert_eq!(s, aview2(&[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]])); +} + + +#[test] +fn to_shape_easy() { + // 1D -> C -> C + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.to_shape(((3, 3), Order::RowMajor)); + assert!(u.is_err()); + + let u = v.to_shape(((2, 2, 2), Order::C)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert!(u.is_view()); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + + let s = u.to_shape((4, 2)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); + + // 1D -> F -> F + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.to_shape(((3, 3), Order::ColumnMajor)); + assert!(u.is_err()); + + let u = v.to_shape(((2, 2, 2), Order::ColumnMajor)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert!(u.is_view()); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]); + + let s = u.to_shape(((4, 2), Order::ColumnMajor)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, array![[1, 5], [2, 6], [3, 7], [4, 8]]); +} + +#[test] +fn to_shape_copy() { + // 1D -> C -> F + let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]); + let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap(); + assert_eq!(u.shape(), &[4, 2]); + assert_eq!(u, array![[1, 2], [3, 4], [5, 6], [7, 8]]); + + let u = u.to_shape(((2, 4), Order::ColumnMajor)).unwrap(); + assert_eq!(u.shape(), &[2, 4]); + assert_eq!(u, array![[1, 5, 2, 6], [3, 7, 4, 8]]); + + // 1D -> F -> C + let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]); + let u = v.to_shape(((4, 2), Order::ColumnMajor)).unwrap(); + assert_eq!(u.shape(), &[4, 2]); + assert_eq!(u, array![[1, 5], [2, 6], [3, 7], [4, 8]]); + + let u = u.to_shape((2, 4)).unwrap(); + assert_eq!(u.shape(), &[2, 4]); + assert_eq!(u, array![[1, 5, 2, 6], [3, 7, 4, 8]]); +} + +#[test] +fn to_shape_add_axis() { + // 1D -> C -> C + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap(); + + assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view()); + assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_view()); +} + + +#[test] +fn to_shape_copy_stride() { + let v = array![[1, 2, 3, 4], [5, 6, 7, 8]]; + let vs = v.slice(s![.., ..3]); + let lin1 = vs.to_shape(6).unwrap(); + assert_eq!(lin1, array![1, 2, 3, 5, 6, 7]); + assert!(lin1.is_owned()); + + let lin2 = vs.to_shape((6, Order::ColumnMajor)).unwrap(); + assert_eq!(lin2, array![1, 5, 2, 6, 3, 7]); + assert!(lin2.is_owned()); +} + + +#[test] +fn to_shape_zero_len() { + let v = array![[1, 2, 3, 4], [5, 6, 7, 8]]; + let vs = v.slice(s![.., ..0]); + let lin1 = vs.to_shape(0).unwrap(); + assert_eq!(lin1, array![]); + assert!(lin1.is_view()); +} + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn to_shape_error1() { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let _u = v.to_shape((2, 5)).unwrap(); +} + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn to_shape_error2() { + // overflow + let data = [3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let _u = v.to_shape((2, usize::MAX)).unwrap(); +} + +#[test] +fn to_shape_discontig() { + for &create_order in &[Order::C, Order::F] { + let a = Array::from_iter(0..64); + let mut a1 = a.to_shape(((4, 4, 4), create_order)).unwrap(); + a1.slice_collapse(s![.., ..;2, ..]); // now shape (4, 2, 4) + assert!(a1.as_slice_memory_order().is_none()); + + for &order in &[Order::C, Order::F] { + let v1 = a1.to_shape(((2, 2, 2, 2, 2), order)).unwrap(); + assert!(v1.is_view()); + let v1 = a1.to_shape(((4, 1, 2, 1, 2, 2), order)).unwrap(); + assert!(v1.is_view()); + let v1 = a1.to_shape(((4, 2, 4), order)).unwrap(); + assert!(v1.is_view()); + let v1 = a1.to_shape(((8, 4), order)).unwrap(); + assert_eq!(v1.is_view(), order == create_order && create_order == Order::C, + "failed for {:?}, {:?}", create_order, order); + let v1 = a1.to_shape(((4, 8), order)).unwrap(); + assert_eq!(v1.is_view(), order == create_order && create_order == Order::F, + "failed for {:?}, {:?}", create_order, order); + let v1 = a1.to_shape((32, order)).unwrap(); + assert!(!v1.is_view()); + } + } +} + +#[test] +fn to_shape_broadcast() { + for &create_order in &[Order::C, Order::F] { + let a = Array::from_iter(0..64); + let mut a1 = a.to_shape(((4, 4, 4), create_order)).unwrap(); + a1.slice_collapse(s![.., ..1, ..]); // now shape (4, 1, 4) + let v1 = a1.broadcast((4, 4, 4)).unwrap(); // Now shape (4, 4, 4) + assert!(v1.as_slice_memory_order().is_none()); + + for &order in &[Order::C, Order::F] { + let v2 = v1.to_shape(((2, 2, 2, 2, 2, 2), order)).unwrap(); + assert_eq!(v2.strides(), match (create_order, order) { + (Order::C, Order::C) => { &[32, 16, 0, 0, 2, 1] } + (Order::C, Order::F) => { &[16, 32, 0, 0, 1, 2] } + (Order::F, Order::C) => { &[2, 1, 0, 0, 32, 16] } + (Order::F, Order::F) => { &[1, 2, 0, 0, 16, 32] } + _other => unreachable!() + }); + + let v2 = v1.to_shape(((4, 4, 4), order)).unwrap(); + assert!(v2.is_view()); + let v2 = v1.to_shape(((8, 8), order)).unwrap(); + assert!(v2.is_owned()); + } + } +}