@@ -1877,4 +1877,136 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
1877
1877
}
1878
1878
} )
1879
1879
}
1880
+
1881
+ /// Traverse an axis, and 'cumulatively fold' over self, i.e.
1882
+ /// return an array A where the element at index i of `axis`,
1883
+ /// and index [..] of other axes, is the result of
1884
+ /// A.axis_iter(axis)
1885
+ /// .skip(1)
1886
+ /// .take(i-1)
1887
+ /// .fold(A.subview(axis, 0)[..], |acc, subview| f(acc, subview[..])).
1888
+ ///
1889
+ /// **panics** if the dimension of `self` along `axis` is 0.
1890
+ ///
1891
+ /// # Example
1892
+ /// ```
1893
+ /// use ndarray::{arr2, Axis};
1894
+ /// use std::ops::Add;
1895
+ ///
1896
+ /// let a = arr2(&[[1., 2.],
1897
+ /// [3., 4.],
1898
+ /// [5., 6.]]);
1899
+ /// let accumulated = a.accumulate_axis(Axis(0), |a, b| a+b);
1900
+ /// assert!(accumulated
1901
+ /// .all_close(&arr2(&[[1., 2.],
1902
+ /// [4., 6.],
1903
+ /// [9., 12.]]), 1e-12));
1904
+ ///
1905
+ /// let b = arr1(&[1., 2., 3., 4.]);
1906
+ /// let b_accumulated = b.accumulate_axis_inplace(Axis(0), |a, b| a*b);
1907
+ /// assert!(b_accumulated.all_close(&arr1(&[1., 2., 6., 24.]), 1e-12));
1908
+ /// ```
1909
+ pub fn accumulate_axis < F > ( & self , axis : Axis , mut f : F ) -> Array < A , D >
1910
+ where
1911
+ A : Clone ,
1912
+ D : :: dimension:: RemoveAxis ,
1913
+ F : FnMut ( & A , & A ) -> A ,
1914
+ {
1915
+ let mut accum = unsafe {
1916
+ let mut v = Vec :: with_capacity ( self . len ( ) ) ;
1917
+ v. set_len ( self . len ( ) ) ;
1918
+ Array :: < A , D > :: from_shape_vec_unchecked ( self . dim ( ) , v)
1919
+ } ;
1920
+ let mut states = self . subview ( axis, 0 ) . to_owned ( ) ;
1921
+ accum. subview_mut ( axis, 0 ) . assign ( & states) ;
1922
+ for ( mut accum_i, self_i) in accum. axis_iter_mut ( axis) . skip ( 1 )
1923
+ . zip ( self . axis_iter ( axis) . skip ( 1 ) ) {
1924
+ Zip :: from ( & mut accum_i)
1925
+ . and ( & mut states)
1926
+ . and ( & self_i)
1927
+ . apply ( |ac, st, se| { * st = f ( st, se) ; * ac = st. clone ( ) ; } ) ;
1928
+ }
1929
+ accum
1930
+ }
1931
+
1932
+ /// Inplace version of `accumulate_axis`. See that method for more
1933
+ /// documentation.
1934
+ ///
1935
+ /// **panics** if the dimension of `self` along `axis` is 0.
1936
+ ///
1937
+ /// # Example
1938
+ /// ```
1939
+ /// use ndarray::{arr1, arr2, Axis};
1940
+ /// use std::ops::Add;
1941
+ ///
1942
+ /// let mut a = arr2(&[[1., 2.],
1943
+ /// [3., 4.],
1944
+ /// [5., 6.]]);
1945
+ /// a.accumulate_axis_inplace(Axis(0), |a, b| a+b);
1946
+ /// assert!(a
1947
+ /// .all_close(&arr2(&[[1., 2.],
1948
+ /// [4., 6.],
1949
+ /// [9., 12.]]), 1e-12));
1950
+ ///
1951
+ /// let mut b = arr1(&[1., 2., 3., 4.]);
1952
+ /// b.accumulate_axis_inplace(Axis(0), |a, b| a*b);
1953
+ /// assert!(b.all_close(&arr1(&[1., 2., 6., 24.]), 1e-12));
1954
+ /// ```
1955
+ pub fn accumulate_axis_inplace < F > ( & mut self , axis : Axis , mut f : F )
1956
+ where
1957
+ A : Clone ,
1958
+ D : :: dimension:: RemoveAxis ,
1959
+ F : FnMut ( & A , & A ) -> A ,
1960
+ S : :: data_traits:: DataMut ,
1961
+ {
1962
+ let mut states = self . subview ( axis, 0 ) . to_owned ( ) ;
1963
+ for mut self_i in self . axis_iter_mut ( axis) . skip ( 1 ) {
1964
+ Zip :: from ( & mut states)
1965
+ . and ( & mut self_i)
1966
+ . apply ( |st, se| { * st = f ( st, se) ; * se = st. clone ( ) ; } ) ;
1967
+ }
1968
+ }
1969
+
1970
+ /// Traverse an axis, applying f to each element and returning the result.
1971
+ /// Maintains a mutable copy of `initial_state` for each element in the subview
1972
+ /// obtained by traversing `self`.
1973
+ ///
1974
+ /// This function is similar to `accumulate_axis`, but allows for a different
1975
+ /// output type.
1976
+ ///
1977
+ /// # Example
1978
+ ///
1979
+ /// ```
1980
+ /// use ndarray::{arr2, Axis};
1981
+ ///
1982
+ /// let a = arr2(&[[1., 2.],
1983
+ /// [3., 4.],
1984
+ /// [5., 6.]]);
1985
+ /// let scanned = a.scan_axis(Axis(0), 0., |acc, x| {*acc += x; *acc as i32});
1986
+ /// assert_eq!((scanned - arr2(&[[1, 2],
1987
+ /// [4, 6],
1988
+ /// [9, 12]])).mapv(i32::abs).scalar_sum(), 0);
1989
+ /// ```
1990
+ pub fn scan_axis < St , B , F > ( & self , axis : Axis , initial_state : St , mut f : F )
1991
+ -> Array < B , D >
1992
+ where
1993
+ B : Clone ,
1994
+ D : :: dimension:: RemoveAxis ,
1995
+ F : FnMut ( & mut St , & A ) -> B ,
1996
+ St : Copy ,
1997
+ {
1998
+ let mut accum = unsafe {
1999
+ let mut v = Vec :: with_capacity ( self . len ( ) ) ;
2000
+ v. set_len ( self . len ( ) ) ;
2001
+ Array :: < B , D > :: from_shape_vec_unchecked ( self . dim ( ) , v)
2002
+ } ;
2003
+ let mut states = Array :: < St , _ > :: from_elem ( self . dim . remove_axis ( axis) , initial_state) ;
2004
+ for ( mut accum_i, self_i) in accum. axis_iter_mut ( axis) . zip ( self . axis_iter ( axis) ) {
2005
+ Zip :: from ( & mut accum_i)
2006
+ . and ( & mut states)
2007
+ . and ( & self_i)
2008
+ . apply ( |ac, st, se| * ac = f ( st, se) ) ;
2009
+ }
2010
+ accum
2011
+ }
1880
2012
}
0 commit comments