diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 477d93e71..d9738ad02 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1877,4 +1877,136 @@ impl ArrayBase where S: Data, D: Dimension } }) } + + /// Traverse an axis, and 'cumulatively fold' over self, i.e. + /// return an array A where the element at index i of `axis`, + /// and index [..] of other axes, is the result of + /// A.axis_iter(axis) + /// .skip(1) + /// .take(i-1) + /// .fold(A.subview(axis, 0)[..], |acc, subview| f(acc, subview[..])). + /// + /// **panics** if the dimension of `self` along `axis` is 0. + /// + /// # Example + /// ``` + /// use ndarray::{arr1, arr2, Axis}; + /// use std::ops::Add; + /// + /// let a = arr2(&[[1., 2.], + /// [3., 4.], + /// [5., 6.]]); + /// let accumulated = a.accumulate_axis(Axis(0), |a, b| a+b); + /// assert!(accumulated + /// .all_close(&arr2(&[[1., 2.], + /// [4., 6.], + /// [9., 12.]]), 1e-12)); + /// + /// let b = arr1(&[1., 2., 3., 4.]); + /// let b_accumulated = b.accumulate_axis(Axis(0), |a, b| a*b); + /// assert!(b_accumulated.all_close(&arr1(&[1., 2., 6., 24.]), 1e-12)); + /// ``` + pub fn accumulate_axis(&self, axis: Axis, mut f: F) -> Array + where + A: Clone, + D: ::dimension::RemoveAxis, + F: FnMut(&A, &A) -> A, + { + let mut accum = unsafe{ + let mut v = Vec::with_capacity(self.len()); + v.set_len(self.len()); + Array::::from_shape_vec_unchecked(self.dim(), v) + }; + let mut states = self.subview(axis, 0).to_owned(); + accum.subview_mut(axis, 0).assign(&states); + for (mut accum_i, self_i) in accum.axis_iter_mut(axis).skip(1) + .zip(self.axis_iter(axis).skip(1)) { + Zip::from(&mut accum_i) + .and(&mut states) + .and(&self_i) + .apply(|ac, st, se| {*st = f(st, se); *ac = st.clone();}); + } + accum + } + + /// Inplace version of `accumulate_axis`. See that method for more + /// documentation. + /// + /// **panics** if the dimension of `self` along `axis` is 0. + /// + /// # Example + /// ``` + /// use ndarray::{arr1, arr2, Axis}; + /// use std::ops::Add; + /// + /// let mut a = arr2(&[[1., 2.], + /// [3., 4.], + /// [5., 6.]]); + /// a.accumulate_axis_inplace(Axis(0), |a, b| a+b); + /// assert!(a + /// .all_close(&arr2(&[[1., 2.], + /// [4., 6.], + /// [9., 12.]]), 1e-12)); + /// + /// let mut b = arr1(&[1., 2., 3., 4.]); + /// b.accumulate_axis_inplace(Axis(0), |a, b| a*b); + /// assert!(b.all_close(&arr1(&[1., 2., 6., 24.]), 1e-12)); + /// ``` + pub fn accumulate_axis_inplace(&mut self, axis: Axis, mut f: F) + where + A: Clone, + D: ::dimension::RemoveAxis, + F: FnMut(&A, &A) -> A, + S: ::data_traits::DataMut, + { + let mut states = self.subview(axis, 0).to_owned(); + for mut self_i in self.axis_iter_mut(axis).skip(1) { + Zip::from(&mut states) + .and(&mut self_i) + .apply(|st, se| {*st = f(st, se); *se = st.clone();}); + } + } + + /// Traverse an axis, applying f to each element and returning the result. + /// Maintains a mutable copy of `initial_state` for each element in the subview + /// obtained by traversing `self`. + /// + /// This function is similar to `accumulate_axis`, but allows for a different + /// output type. + /// + /// # Example + /// + /// ``` + /// use ndarray::{arr2, Axis}; + /// + /// let a = arr2(&[[1., 2.], + /// [3., 4.], + /// [5., 6.]]); + /// let scanned = a.scan_axis(Axis(0), 0., |acc, x| {*acc += x; *acc as i32}); + /// assert_eq!((scanned - arr2(&[[1, 2], + /// [4, 6], + /// [9, 12]])).mapv(i32::abs).scalar_sum(), 0); + /// ``` + pub fn scan_axis(&self, axis: Axis, initial_state: St, mut f: F) + -> Array + where + B: Clone, + D: ::dimension::RemoveAxis, + F: FnMut(&mut St, &A) -> B, + St: Copy, + { + let mut accum = unsafe{ + let mut v = Vec::with_capacity(self.len()); + v.set_len(self.len()); + Array::::from_shape_vec_unchecked(self.dim(), v) + }; + let mut states = Array::::from_elem(self.dim.remove_axis(axis), initial_state); + for (mut accum_i, self_i) in accum.axis_iter_mut(axis).zip(self.axis_iter(axis)) { + Zip::from(&mut accum_i) + .and(&mut states) + .and(&self_i) + .apply(|ac, st, se| *ac = f(st, se)); + } + accum + } }