Skip to content

Commit 642a44c

Browse files
authored
Merge pull request #611 from jturner314/accumulate-axis-inplace
Add accumulate_axis_inplace method and implement NdProducer for RawArrayView/Mut
2 parents d4dd6f5 + 886b6a1 commit 642a44c

File tree

3 files changed

+218
-6
lines changed

3 files changed

+218
-6
lines changed

src/impl_methods.rs

+56
Original file line numberDiff line numberDiff line change
@@ -2231,4 +2231,60 @@ where
22312231
})
22322232
}
22332233
}
2234+
2235+
/// Iterates over pairs of consecutive elements along the axis.
2236+
///
2237+
/// The first argument to the closure is an element, and the second
2238+
/// argument is the next element along the axis. Iteration is guaranteed to
2239+
/// proceed in order along the specified axis, but in all other respects
2240+
/// the iteration order is unspecified.
2241+
///
2242+
/// # Example
2243+
///
2244+
/// For example, this can be used to compute the cumulative sum along an
2245+
/// axis:
2246+
///
2247+
/// ```
2248+
/// use ndarray::{array, Axis};
2249+
///
2250+
/// let mut arr = array![
2251+
/// [[1, 2], [3, 4], [5, 6]],
2252+
/// [[7, 8], [9, 10], [11, 12]],
2253+
/// ];
2254+
/// arr.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
2255+
/// assert_eq!(
2256+
/// arr,
2257+
/// array![
2258+
/// [[1, 2], [4, 6], [9, 12]],
2259+
/// [[7, 8], [16, 18], [27, 30]],
2260+
/// ],
2261+
/// );
2262+
/// ```
2263+
pub fn accumulate_axis_inplace<F>(&mut self, axis: Axis, mut f: F)
2264+
where
2265+
F: FnMut(&A, &mut A),
2266+
S: DataMut,
2267+
{
2268+
if self.len_of(axis) <= 1 {
2269+
return;
2270+
}
2271+
let mut curr = self.raw_view_mut(); // mut borrow of the array here
2272+
let mut prev = curr.raw_view(); // derive further raw views from the same borrow
2273+
prev.slice_axis_inplace(axis, Slice::from(..-1));
2274+
curr.slice_axis_inplace(axis, Slice::from(1..));
2275+
// This implementation relies on `Zip` iterating along `axis` in order.
2276+
Zip::from(prev).and(curr).apply(|prev, curr| unsafe {
2277+
// These pointer dereferences and borrows are safe because:
2278+
//
2279+
// 1. They're pointers to elements in the array.
2280+
//
2281+
// 2. `S: DataMut` guarantees that elements are safe to borrow
2282+
// mutably and that they don't alias.
2283+
//
2284+
// 3. The lifetimes of the borrows last only for the duration
2285+
// of the call to `f`, so aliasing across calls to `f`
2286+
// cannot occur.
2287+
f(&*prev, &mut *curr)
2288+
});
2289+
}
22342290
}

src/zip/mod.rs

+116-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ where
4747

4848
impl<S, D> ArrayBase<S, D>
4949
where
50-
S: Data,
50+
S: RawData,
5151
D: Dimension,
5252
{
5353
pub(crate) fn layout_impl(&self) -> Layout {
@@ -57,7 +57,7 @@ where
5757
} else {
5858
CORDER
5959
}
60-
} else if self.ndim() > 1 && self.t().is_standard_layout() {
60+
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
6161
FORDER
6262
} else {
6363
0
@@ -192,6 +192,14 @@ pub trait Offset: Copy {
192192
private_decl! {}
193193
}
194194

195+
impl<T> Offset for *const T {
196+
type Stride = isize;
197+
unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self {
198+
self.offset(s * (index as isize))
199+
}
200+
private_impl! {}
201+
}
202+
195203
impl<T> Offset for *mut T {
196204
type Stride = isize;
197205
unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self {
@@ -389,6 +397,112 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> {
389397
}
390398
}
391399

400+
impl<A, D: Dimension> NdProducer for RawArrayView<A, D> {
401+
type Item = *const A;
402+
type Dim = D;
403+
type Ptr = *const A;
404+
type Stride = isize;
405+
406+
private_impl! {}
407+
#[doc(hidden)]
408+
fn raw_dim(&self) -> Self::Dim {
409+
self.raw_dim()
410+
}
411+
412+
#[doc(hidden)]
413+
fn equal_dim(&self, dim: &Self::Dim) -> bool {
414+
self.dim.equal(dim)
415+
}
416+
417+
#[doc(hidden)]
418+
fn as_ptr(&self) -> *const A {
419+
self.as_ptr()
420+
}
421+
422+
#[doc(hidden)]
423+
fn layout(&self) -> Layout {
424+
self.layout_impl()
425+
}
426+
427+
#[doc(hidden)]
428+
unsafe fn as_ref(&self, ptr: *const A) -> *const A {
429+
ptr
430+
}
431+
432+
#[doc(hidden)]
433+
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A {
434+
self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
435+
}
436+
437+
#[doc(hidden)]
438+
fn stride_of(&self, axis: Axis) -> isize {
439+
self.stride_of(axis)
440+
}
441+
442+
#[inline(always)]
443+
fn contiguous_stride(&self) -> Self::Stride {
444+
1
445+
}
446+
447+
#[doc(hidden)]
448+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
449+
self.split_at(axis, index)
450+
}
451+
}
452+
453+
impl<A, D: Dimension> NdProducer for RawArrayViewMut<A, D> {
454+
type Item = *mut A;
455+
type Dim = D;
456+
type Ptr = *mut A;
457+
type Stride = isize;
458+
459+
private_impl! {}
460+
#[doc(hidden)]
461+
fn raw_dim(&self) -> Self::Dim {
462+
self.raw_dim()
463+
}
464+
465+
#[doc(hidden)]
466+
fn equal_dim(&self, dim: &Self::Dim) -> bool {
467+
self.dim.equal(dim)
468+
}
469+
470+
#[doc(hidden)]
471+
fn as_ptr(&self) -> *mut A {
472+
self.as_ptr() as _
473+
}
474+
475+
#[doc(hidden)]
476+
fn layout(&self) -> Layout {
477+
self.layout_impl()
478+
}
479+
480+
#[doc(hidden)]
481+
unsafe fn as_ref(&self, ptr: *mut A) -> *mut A {
482+
ptr
483+
}
484+
485+
#[doc(hidden)]
486+
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
487+
self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
488+
}
489+
490+
#[doc(hidden)]
491+
fn stride_of(&self, axis: Axis) -> isize {
492+
self.stride_of(axis)
493+
}
494+
495+
#[inline(always)]
496+
fn contiguous_stride(&self) -> Self::Stride {
497+
1
498+
}
499+
500+
#[doc(hidden)]
501+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
502+
self.split_at(axis, index)
503+
}
504+
}
505+
392506
/// Lock step function application across several arrays or other producers.
393507
///
394508
/// Zip allows matching several producers to each other elementwise and applying

tests/array.rs

+46-4
Original file line numberDiff line numberDiff line change
@@ -1105,7 +1105,7 @@ fn owned_array_with_stride() {
11051105

11061106
#[test]
11071107
fn owned_array_discontiguous() {
1108-
use ::std::iter::repeat;
1108+
use std::iter::repeat;
11091109
let v: Vec<_> = (0..12).flat_map(|x| repeat(x).take(2)).collect();
11101110
let dim = (3, 2, 2);
11111111
let strides = (8, 4, 2);
@@ -1118,9 +1118,9 @@ fn owned_array_discontiguous() {
11181118

11191119
#[test]
11201120
fn owned_array_discontiguous_drop() {
1121-
use ::std::cell::RefCell;
1122-
use ::std::collections::BTreeSet;
1123-
use ::std::rc::Rc;
1121+
use std::cell::RefCell;
1122+
use std::collections::BTreeSet;
1123+
use std::rc::Rc;
11241124

11251125
struct InsertOnDrop<T: Ord>(Rc<RefCell<BTreeSet<T>>>, Option<T>);
11261126
impl<T: Ord> Drop for InsertOnDrop<T> {
@@ -1959,6 +1959,48 @@ fn test_map_axis() {
19591959
itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4);
19601960
}
19611961

1962+
#[test]
1963+
fn test_accumulate_axis_inplace_noop() {
1964+
let mut a = Array2::<u8>::zeros((0, 3));
1965+
a.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
1966+
assert_eq!(a, Array2::zeros((0, 3)));
1967+
1968+
let mut a = Array2::<u8>::zeros((3, 1));
1969+
a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
1970+
assert_eq!(a, Array2::zeros((3, 1)));
1971+
}
1972+
1973+
#[rustfmt::skip] // Allow block array formatting
1974+
#[test]
1975+
fn test_accumulate_axis_inplace_nonstandard_layout() {
1976+
let a = arr2(&[[1, 2, 3],
1977+
[4, 5, 6],
1978+
[7, 8, 9],
1979+
[10,11,12]]);
1980+
1981+
let mut a_t = a.clone().reversed_axes();
1982+
a_t.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
1983+
assert_eq!(a_t, aview2(&[[1, 4, 7, 10],
1984+
[3, 9, 15, 21],
1985+
[6, 15, 24, 33]]));
1986+
1987+
let mut a0 = a.clone();
1988+
a0.invert_axis(Axis(0));
1989+
a0.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
1990+
assert_eq!(a0, aview2(&[[10, 11, 12],
1991+
[17, 19, 21],
1992+
[21, 24, 27],
1993+
[22, 26, 30]]));
1994+
1995+
let mut a1 = a.clone();
1996+
a1.invert_axis(Axis(1));
1997+
a1.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
1998+
assert_eq!(a1, aview2(&[[3, 5, 6],
1999+
[6, 11, 15],
2000+
[9, 17, 24],
2001+
[12, 23, 33]]));
2002+
}
2003+
19622004
#[test]
19632005
fn test_to_vec() {
19642006
let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);

0 commit comments

Comments
 (0)