Skip to content

Commit abb7bef

Browse files
committed
implement accumulate_axis, accumulate_axis_inplace, scan_axis
1 parent 408f42b commit abb7bef

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

src/impl_methods.rs

+132
Original file line numberDiff line numberDiff line change
@@ -1877,4 +1877,136 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
18771877
}
18781878
})
18791879
}
1880+
1881+
/// Traverse an axis, and 'cumulatively fold' over self, i.e.
1882+
/// return an array A where the element at index i of `axis`,
1883+
/// and index [..] of other axes, is the result of
1884+
/// A.axis_iter(axis)
1885+
/// .skip(1)
1886+
/// .take(i-1)
1887+
/// .fold(A.subview(axis, 0)[..], |acc, subview| f(acc, subview[..])).
1888+
///
1889+
/// **panics** if the dimension of `self` along `axis` is 0.
1890+
///
1891+
/// # Example
1892+
/// ```
1893+
/// use ndarray::{arr2, Axis};
1894+
/// use std::ops::Add;
1895+
///
1896+
/// let a = arr2(&[[1., 2.],
1897+
/// [3., 4.],
1898+
/// [5., 6.]]);
1899+
/// let accumulated = a.accumulate_axis(Axis(0), |a, b| a+b);
1900+
/// assert!(accumulated
1901+
/// .all_close(&arr2(&[[1., 2.],
1902+
/// [4., 6.],
1903+
/// [9., 12.]]), 1e-12));
1904+
///
1905+
/// let b = arr1(&[1., 2., 3., 4.]);
1906+
/// let b_accumulated = b.accumulate_axis_inplace(Axis(0), |a, b| a*b);
1907+
/// assert!(b_accumulated.all_close(&arr1(&[1., 2., 6., 24.]), 1e-12));
1908+
/// ```
1909+
pub fn accumulate_axis<F>(&self, axis: Axis, mut f: F) -> Array<A, D>
1910+
where
1911+
A: Clone,
1912+
D: ::dimension::RemoveAxis,
1913+
F: FnMut(&A, &A) -> A,
1914+
{
1915+
let mut accum = unsafe{
1916+
let mut v = Vec::with_capacity(self.len());
1917+
v.set_len(self.len());
1918+
Array::<A, D>::from_shape_vec_unchecked(self.dim(), v)
1919+
};
1920+
let mut states = self.subview(axis, 0).to_owned();
1921+
accum.subview_mut(axis, 0).assign(&states);
1922+
for (mut accum_i, self_i) in accum.axis_iter_mut(axis).skip(1)
1923+
.zip(self.axis_iter(axis).skip(1)) {
1924+
Zip::from(&mut accum_i)
1925+
.and(&mut states)
1926+
.and(&self_i)
1927+
.apply(|ac, st, se| {*st = f(st, se); *ac = st.clone();});
1928+
}
1929+
accum
1930+
}
1931+
1932+
/// Inplace version of `accumulate_axis`. See that method for more
1933+
/// documentation.
1934+
///
1935+
/// **panics** if the dimension of `self` along `axis` is 0.
1936+
///
1937+
/// # Example
1938+
/// ```
1939+
/// use ndarray::{arr1, arr2, Axis};
1940+
/// use std::ops::Add;
1941+
///
1942+
/// let mut a = arr2(&[[1., 2.],
1943+
/// [3., 4.],
1944+
/// [5., 6.]]);
1945+
/// a.accumulate_axis_inplace(Axis(0), |a, b| a+b);
1946+
/// assert!(a
1947+
/// .all_close(&arr2(&[[1., 2.],
1948+
/// [4., 6.],
1949+
/// [9., 12.]]), 1e-12));
1950+
///
1951+
/// let mut b = arr1(&[1., 2., 3., 4.]);
1952+
/// b.accumulate_axis_inplace(Axis(0), |a, b| a*b);
1953+
/// assert!(b.all_close(&arr1(&[1., 2., 6., 24.]), 1e-12));
1954+
/// ```
1955+
pub fn accumulate_axis_inplace<F>(&mut self, axis: Axis, mut f: F)
1956+
where
1957+
A: Clone,
1958+
D: ::dimension::RemoveAxis,
1959+
F: FnMut(&A, &A) -> A,
1960+
S: ::data_traits::DataMut,
1961+
{
1962+
let mut states = self.subview(axis, 0).to_owned();
1963+
for mut self_i in self.axis_iter_mut(axis).skip(1) {
1964+
Zip::from(&mut states)
1965+
.and(&mut self_i)
1966+
.apply(|st, se| {*st = f(st, se); *se = st.clone();});
1967+
}
1968+
}
1969+
1970+
/// Traverse an axis, applying f to each element and returning the result.
1971+
/// Maintains a mutable copy of `initial_state` for each element in the subview
1972+
/// obtained by traversing `self`.
1973+
///
1974+
/// This function is similar to `accumulate_axis`, but allows for a different
1975+
/// output type.
1976+
///
1977+
/// # Example
1978+
///
1979+
/// ```
1980+
/// use ndarray::{arr2, Axis};
1981+
///
1982+
/// let a = arr2(&[[1., 2.],
1983+
/// [3., 4.],
1984+
/// [5., 6.]]);
1985+
/// let scanned = a.scan_axis(Axis(0), 0., |acc, x| {*acc += x; *acc as i32});
1986+
/// assert_eq!((scanned - arr2(&[[1, 2],
1987+
/// [4, 6],
1988+
/// [9, 12]])).mapv(i32::abs).scalar_sum(), 0);
1989+
/// ```
1990+
pub fn scan_axis<St, B, F>(&self, axis: Axis, initial_state: St, mut f: F)
1991+
-> Array<B, D>
1992+
where
1993+
B: Clone,
1994+
D: ::dimension::RemoveAxis,
1995+
F: FnMut(&mut St, &A) -> B,
1996+
St: Copy,
1997+
{
1998+
let mut accum = unsafe{
1999+
let mut v = Vec::with_capacity(self.len());
2000+
v.set_len(self.len());
2001+
Array::<B, D>::from_shape_vec_unchecked(self.dim(), v)
2002+
};
2003+
let mut states = Array::<St, _>::from_elem(self.dim.remove_axis(axis), initial_state);
2004+
for (mut accum_i, self_i) in accum.axis_iter_mut(axis).zip(self.axis_iter(axis)) {
2005+
Zip::from(&mut accum_i)
2006+
.and(&mut states)
2007+
.and(&self_i)
2008+
.apply(|ac, st, se| *ac = f(st, se));
2009+
}
2010+
accum
2011+
}
18802012
}

0 commit comments

Comments
 (0)