-
Notifications
You must be signed in to change notification settings - Fork 321
Accumulate methods #513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Accumulate methods #513
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1877,4 +1877,136 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension | |
} | ||
}) | ||
} | ||
|
||
/// Traverse an axis, and 'cumulatively fold' over self, i.e. | ||
/// return an array A where the element at index i of `axis`, | ||
/// and index [..] of other axes, is the result of | ||
/// A.axis_iter(axis) | ||
/// .skip(1) | ||
/// .take(i-1) | ||
/// .fold(A.subview(axis, 0)[..], |acc, subview| f(acc, subview[..])). | ||
/// | ||
/// **panics** if the dimension of `self` along `axis` is 0. | ||
/// | ||
/// # Example | ||
/// ``` | ||
/// use ndarray::{arr1, arr2, Axis}; | ||
/// use std::ops::Add; | ||
/// | ||
/// let a = arr2(&[[1., 2.], | ||
/// [3., 4.], | ||
/// [5., 6.]]); | ||
/// let accumulated = a.accumulate_axis(Axis(0), |a, b| a+b); | ||
/// assert!(accumulated | ||
/// .all_close(&arr2(&[[1., 2.], | ||
/// [4., 6.], | ||
/// [9., 12.]]), 1e-12)); | ||
/// | ||
/// let b = arr1(&[1., 2., 3., 4.]); | ||
/// let b_accumulated = b.accumulate_axis(Axis(0), |a, b| a*b); | ||
/// assert!(b_accumulated.all_close(&arr1(&[1., 2., 6., 24.]), 1e-12)); | ||
/// ``` | ||
pub fn accumulate_axis<F>(&self, axis: Axis, mut f: F) -> Array<A, D> | ||
where | ||
A: Clone, | ||
D: ::dimension::RemoveAxis, | ||
F: FnMut(&A, &A) -> A, | ||
{ | ||
let mut accum = unsafe{ | ||
let mut v = Vec::with_capacity(self.len()); | ||
v.set_len(self.len()); | ||
Array::<A, D>::from_shape_vec_unchecked(self.dim(), v) | ||
}; | ||
let mut states = self.subview(axis, 0).to_owned(); | ||
accum.subview_mut(axis, 0).assign(&states); | ||
for (mut accum_i, self_i) in accum.axis_iter_mut(axis).skip(1) | ||
.zip(self.axis_iter(axis).skip(1)) { | ||
Zip::from(&mut accum_i) | ||
.and(&mut states) | ||
.and(&self_i) | ||
.apply(|ac, st, se| {*st = f(st, se); *ac = st.clone();}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This results in undefined behavior if |
||
} | ||
accum | ||
} | ||
|
||
/// Inplace version of `accumulate_axis`. See that method for more | ||
/// documentation. | ||
/// | ||
/// **panics** if the dimension of `self` along `axis` is 0. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not necessary to panic in this case. We can add an explicit check for this case and return an empty array of the correct shape. |
||
/// | ||
/// # Example | ||
/// ``` | ||
/// use ndarray::{arr1, arr2, Axis}; | ||
/// use std::ops::Add; | ||
/// | ||
/// let mut a = arr2(&[[1., 2.], | ||
/// [3., 4.], | ||
/// [5., 6.]]); | ||
/// a.accumulate_axis_inplace(Axis(0), |a, b| a+b); | ||
/// assert!(a | ||
/// .all_close(&arr2(&[[1., 2.], | ||
/// [4., 6.], | ||
/// [9., 12.]]), 1e-12)); | ||
/// | ||
/// let mut b = arr1(&[1., 2., 3., 4.]); | ||
/// b.accumulate_axis_inplace(Axis(0), |a, b| a*b); | ||
/// assert!(b.all_close(&arr1(&[1., 2., 6., 24.]), 1e-12)); | ||
/// ``` | ||
pub fn accumulate_axis_inplace<F>(&mut self, axis: Axis, mut f: F) | ||
where | ||
A: Clone, | ||
D: ::dimension::RemoveAxis, | ||
F: FnMut(&A, &A) -> A, | ||
S: ::data_traits::DataMut, | ||
{ | ||
let mut states = self.subview(axis, 0).to_owned(); | ||
for mut self_i in self.axis_iter_mut(axis).skip(1) { | ||
Zip::from(&mut states) | ||
.and(&mut self_i) | ||
.apply(|st, se| {*st = f(st, se); *se = st.clone();}); | ||
} | ||
} | ||
|
||
/// Traverse an axis, applying f to each element and returning the result. | ||
/// Maintains a mutable copy of `initial_state` for each element in the subview | ||
/// obtained by traversing `self`. | ||
/// | ||
/// This function is similar to `accumulate_axis`, but allows for a different | ||
/// output type. | ||
/// | ||
/// # Example | ||
/// | ||
/// ``` | ||
/// use ndarray::{arr2, Axis}; | ||
/// | ||
/// let a = arr2(&[[1., 2.], | ||
/// [3., 4.], | ||
/// [5., 6.]]); | ||
/// let scanned = a.scan_axis(Axis(0), 0., |acc, x| {*acc += x; *acc as i32}); | ||
/// assert_eq!((scanned - arr2(&[[1, 2], | ||
/// [4, 6], | ||
/// [9, 12]])).mapv(i32::abs).scalar_sum(), 0); | ||
/// ``` | ||
pub fn scan_axis<St, B, F>(&self, axis: Axis, initial_state: St, mut f: F) | ||
-> Array<B, D> | ||
where | ||
B: Clone, | ||
D: ::dimension::RemoveAxis, | ||
F: FnMut(&mut St, &A) -> B, | ||
St: Copy, | ||
{ | ||
let mut accum = unsafe{ | ||
let mut v = Vec::with_capacity(self.len()); | ||
v.set_len(self.len()); | ||
Array::<B, D>::from_shape_vec_unchecked(self.dim(), v) | ||
}; | ||
let mut states = Array::<St, _>::from_elem(self.dim.remove_axis(axis), initial_state); | ||
for (mut accum_i, self_i) in accum.axis_iter_mut(axis).zip(self.axis_iter(axis)) { | ||
Zip::from(&mut accum_i) | ||
.and(&mut states) | ||
.and(&self_i) | ||
.apply(|ac, st, se| *ac = f(st, se)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This results in undefined behavior if |
||
} | ||
accum | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This results in undefined behavior if
A
implementsDrop
and the length of axisaxis
is zero. (If the length of the axis is zero, this line will panic, which will cause the uninitialized arrayaccum
to be dropped, which in turn will cause each of the uninitialized elements to be dropped (which is undefined behavior).)