Skip to content

Accumulate methods #513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1877,4 +1877,136 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
}
})
}

/// Traverse an axis, and 'cumulatively fold' over self, i.e.
/// return an array A where the element at index i of `axis`,
/// and index [..] of other axes, is the result of
/// A.axis_iter(axis)
/// .skip(1)
/// .take(i-1)
/// .fold(A.subview(axis, 0)[..], |acc, subview| f(acc, subview[..])).
///
/// **panics** if the dimension of `self` along `axis` is 0.
///
/// # Example
/// ```
/// use ndarray::{arr1, arr2, Axis};
/// use std::ops::Add;
///
/// let a = arr2(&[[1., 2.],
/// [3., 4.],
/// [5., 6.]]);
/// let accumulated = a.accumulate_axis(Axis(0), |a, b| a+b);
/// assert!(accumulated
/// .all_close(&arr2(&[[1., 2.],
/// [4., 6.],
/// [9., 12.]]), 1e-12));
///
/// let b = arr1(&[1., 2., 3., 4.]);
/// let b_accumulated = b.accumulate_axis(Axis(0), |a, b| a*b);
/// assert!(b_accumulated.all_close(&arr1(&[1., 2., 6., 24.]), 1e-12));
/// ```
pub fn accumulate_axis<F>(&self, axis: Axis, mut f: F) -> Array<A, D>
where
A: Clone,
D: ::dimension::RemoveAxis,
F: FnMut(&A, &A) -> A,
{
let mut accum = unsafe{
let mut v = Vec::with_capacity(self.len());
v.set_len(self.len());
Array::<A, D>::from_shape_vec_unchecked(self.dim(), v)
};
let mut states = self.subview(axis, 0).to_owned();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This results in undefined behavior if A implements Drop and the length of axis axis is zero. (If the length of the axis is zero, this line will panic, which will cause the uninitialized array accum to be dropped, which in turn will cause each of the uninitialized elements to be dropped (which is undefined behavior).)

accum.subview_mut(axis, 0).assign(&states);
for (mut accum_i, self_i) in accum.axis_iter_mut(axis).skip(1)
.zip(self.axis_iter(axis).skip(1)) {
Zip::from(&mut accum_i)
.and(&mut states)
.and(&self_i)
.apply(|ac, st, se| {*st = f(st, se); *ac = st.clone();});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This results in undefined behavior if A implements Drop, because ac is being assigned without using ptr::write, so its old (uninitialized value) gets dropped. We also have to worry about f panicking, because a panic will cause accum to be dropped, which will cause each of its elements to be dropped, including uninitialized elements.

}
accum
}

/// Inplace version of `accumulate_axis`. See that method for more
/// documentation.
///
/// **panics** if the dimension of `self` along `axis` is 0.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not necessary to panic in this case. We can add an explicit check for this case and return an empty array of the correct shape.

///
/// # Example
/// ```
/// use ndarray::{arr1, arr2, Axis};
/// use std::ops::Add;
///
/// let mut a = arr2(&[[1., 2.],
/// [3., 4.],
/// [5., 6.]]);
/// a.accumulate_axis_inplace(Axis(0), |a, b| a+b);
/// assert!(a
/// .all_close(&arr2(&[[1., 2.],
/// [4., 6.],
/// [9., 12.]]), 1e-12));
///
/// let mut b = arr1(&[1., 2., 3., 4.]);
/// b.accumulate_axis_inplace(Axis(0), |a, b| a*b);
/// assert!(b.all_close(&arr1(&[1., 2., 6., 24.]), 1e-12));
/// ```
pub fn accumulate_axis_inplace<F>(&mut self, axis: Axis, mut f: F)
where
A: Clone,
D: ::dimension::RemoveAxis,
F: FnMut(&A, &A) -> A,
S: ::data_traits::DataMut,
{
let mut states = self.subview(axis, 0).to_owned();
for mut self_i in self.axis_iter_mut(axis).skip(1) {
Zip::from(&mut states)
.and(&mut self_i)
.apply(|st, se| {*st = f(st, se); *se = st.clone();});
}
}

/// Traverse an axis, applying f to each element and returning the result.
/// Maintains a mutable copy of `initial_state` for each element in the subview
/// obtained by traversing `self`.
///
/// This function is similar to `accumulate_axis`, but allows for a different
/// output type.
///
/// # Example
///
/// ```
/// use ndarray::{arr2, Axis};
///
/// let a = arr2(&[[1., 2.],
/// [3., 4.],
/// [5., 6.]]);
/// let scanned = a.scan_axis(Axis(0), 0., |acc, x| {*acc += x; *acc as i32});
/// assert_eq!((scanned - arr2(&[[1, 2],
/// [4, 6],
/// [9, 12]])).mapv(i32::abs).scalar_sum(), 0);
/// ```
pub fn scan_axis<St, B, F>(&self, axis: Axis, initial_state: St, mut f: F)
-> Array<B, D>
where
B: Clone,
D: ::dimension::RemoveAxis,
F: FnMut(&mut St, &A) -> B,
St: Copy,
{
let mut accum = unsafe{
let mut v = Vec::with_capacity(self.len());
v.set_len(self.len());
Array::<B, D>::from_shape_vec_unchecked(self.dim(), v)
};
let mut states = Array::<St, _>::from_elem(self.dim.remove_axis(axis), initial_state);
for (mut accum_i, self_i) in accum.axis_iter_mut(axis).zip(self.axis_iter(axis)) {
Zip::from(&mut accum_i)
.and(&mut states)
.and(&self_i)
.apply(|ac, st, se| *ac = f(st, se));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This results in undefined behavior if B implements Drop, because ac is being assigned without using ptr::write, so its old (uninitialized value) gets dropped. We also have to worry about f panicking, because a panic will cause accum to be dropped, which will cause each of its elements to be dropped, including uninitialized elements.

}
accum
}
}