diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 076e155ba..53e06f2ce 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -46,6 +46,7 @@ use crate::iter::{ AxisChunksIterMut, AxisIter, AxisIterMut, + AxisWindows, ExactChunks, ExactChunksMut, IndexedIter, @@ -1521,7 +1522,7 @@ where /// assert_eq!(window.shape(), &[4, 3, 2]); /// } /// ``` - pub fn axis_windows(&self, axis: Axis, window_size: usize) -> Windows<'_, A, D> + pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D> where S: Data { let axis_index = axis.index(); @@ -1537,10 +1538,7 @@ where self.shape() ); - let mut size = self.raw_dim(); - size[axis_index] = window_size; - - Windows::new(self.view(), size) + AxisWindows::new(self.view(), axis, window_size) } // Return (length, stride) for diagonal diff --git a/src/iterators/iter.rs b/src/iterators/iter.rs index 5c5acb9d7..478987ee0 100644 --- a/src/iterators/iter.rs +++ b/src/iterators/iter.rs @@ -13,6 +13,7 @@ pub use crate::iterators::{ AxisChunksIterMut, AxisIter, AxisIterMut, + AxisWindows, ExactChunks, ExactChunksIter, ExactChunksIterMut, diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 4851b2827..d49ffe2d0 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -28,7 +28,7 @@ use super::{Dimension, Ix, Ixs}; pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut}; pub use self::into_iter::IntoIter; pub use self::lanes::{Lanes, LanesMut}; -pub use self::windows::Windows; +pub use self::windows::{AxisWindows, Windows}; use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index ec1afb634..453ef5024 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -41,41 +41,7 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> let strides = axis_strides.into_dimension(); let window_strides = a.strides.clone(); - ndassert!( - a.ndim() == window.ndim(), - concat!( - "Window dimension {} does not match array dimension {} ", - "(with array of shape {:?})" - ), - window.ndim(), - a.ndim(), - a.shape() - ); - - ndassert!( - a.ndim() == strides.ndim(), - concat!( - "Stride dimension {} does not match array dimension {} ", - "(with array of shape {:?})" - ), - strides.ndim(), - a.ndim(), - a.shape() - ); - - let mut base = a; - base.slice_each_axis_inplace(|ax_desc| { - let len = ax_desc.len; - let wsz = window[ax_desc.axis.index()]; - let stride = strides[ax_desc.axis.index()]; - - if len < wsz { - Slice::new(0, Some(0), 1) - } else { - Slice::new(0, Some((len - wsz + 1) as isize), stride as isize) - } - }); - + let base = build_base(a, window.clone(), strides); Windows { base: base.into_raw_view(), life: PhantomData, @@ -160,3 +126,166 @@ impl_iterator! { send_sync_read_only!(Windows); send_sync_read_only!(WindowsIter); + +/// Window producer and iterable +/// +/// See [`.axis_windows()`](ArrayBase::axis_windows) for more +/// information. +pub struct AxisWindows<'a, A, D> +{ + base: ArrayView<'a, A, D>, + axis_idx: usize, + window: D, + strides: D, +} + +impl<'a, A, D: Dimension> AxisWindows<'a, A, D> +{ + pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self + { + let window_strides = a.strides.clone(); + let axis_idx = axis.index(); + + let mut window = a.raw_dim(); + window[axis_idx] = window_size; + + let ndim = window.ndim(); + let mut unit_stride = D::zeros(ndim); + unit_stride.slice_mut().fill(1); + + let base = build_base(a, window.clone(), unit_stride); + AxisWindows { + base, + axis_idx, + window, + strides: window_strides, + } + } +} + +impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D> +{ + type Item = ArrayView<'a, A, D>; + type Dim = Ix1; + type Ptr = *mut A; + type Stride = isize; + + fn raw_dim(&self) -> Ix1 + { + Ix1(self.base.raw_dim()[self.axis_idx]) + } + + fn layout(&self) -> Layout + { + self.base.layout() + } + + fn as_ptr(&self) -> *mut A + { + self.base.as_ptr() as *mut _ + } + + fn contiguous_stride(&self) -> isize + { + self.base.contiguous_stride() + } + + unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item + { + ArrayView::new_(ptr, self.window.clone(), self.strides.clone()) + } + + unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A + { + let mut d = D::zeros(self.base.ndim()); + d[self.axis_idx] = i[0]; + self.base.uget_ptr(&d) + } + + fn stride_of(&self, axis: Axis) -> isize + { + assert_eq!(axis, Axis(0)); + self.base.stride_of(Axis(self.axis_idx)) + } + + fn split_at(self, axis: Axis, index: usize) -> (Self, Self) + { + assert_eq!(axis, Axis(0)); + let (a, b) = self.base.split_at(Axis(self.axis_idx), index); + ( + AxisWindows { + base: a, + axis_idx: self.axis_idx, + window: self.window.clone(), + strides: self.strides.clone(), + }, + AxisWindows { + base: b, + axis_idx: self.axis_idx, + window: self.window, + strides: self.strides, + }, + ) + } + + private_impl!{} +} + +impl<'a, A, D> IntoIterator for AxisWindows<'a, A, D> +where + D: Dimension, + A: 'a, +{ + type Item = ::Item; + type IntoIter = WindowsIter<'a, A, D>; + fn into_iter(self) -> Self::IntoIter + { + WindowsIter { + iter: self.base.into_base_iter(), + life: PhantomData, + window: self.window, + strides: self.strides, + } + } +} + +/// build the base array of the `Windows` and `AxisWindows` structs +fn build_base(a: ArrayView, window: D, strides: D) -> ArrayView +where D: Dimension +{ + ndassert!( + a.ndim() == window.ndim(), + concat!( + "Window dimension {} does not match array dimension {} ", + "(with array of shape {:?})" + ), + window.ndim(), + a.ndim(), + a.shape() + ); + + ndassert!( + a.ndim() == strides.ndim(), + concat!( + "Stride dimension {} does not match array dimension {} ", + "(with array of shape {:?})" + ), + strides.ndim(), + a.ndim(), + a.shape() + ); + + let mut base = a; + base.slice_each_axis_inplace(|ax_desc| { + let len = ax_desc.len; + let wsz = window[ax_desc.axis.index()]; + let stride = strides[ax_desc.axis.index()]; + + if len < wsz { + Slice::new(0, Some(0), 1) + } else { + Slice::new(0, Some((len - wsz + 1) as isize), stride as isize) + } + }); + base +} diff --git a/tests/windows.rs b/tests/windows.rs index d8d5b699e..6506d8301 100644 --- a/tests/windows.rs +++ b/tests/windows.rs @@ -278,6 +278,22 @@ fn test_axis_windows_3d() ]); } +#[test] +fn tests_axis_windows_3d_zips_with_1d() +{ + let a = Array::from_iter(0..27) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + let mut b = Array::zeros(2); + + Zip::from(b.view_mut()) + .and(a.axis_windows(Axis(1), 2)) + .for_each(|b, a| { + *b = a.sum(); + }); + assert_eq!(b,arr1(&[207, 261])); +} + #[test] fn test_window_neg_stride() {