Skip to content

Commit fce6034

Browse files
authored
Add diff method as an equivalent to numpy.diff (#1437)
* implement forward finite differneces on arrays * implement tests for the method * remove some heap allocations
1 parent c7ebd35 commit fce6034

File tree

2 files changed

+125
-1
lines changed

2 files changed

+125
-1
lines changed

src/numeric/impl_numeric.rs

+57-1
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
use num_traits::Float;
1111
use num_traits::One;
1212
use num_traits::{FromPrimitive, Zero};
13-
use std::ops::{Add, Div, Mul};
13+
use std::ops::{Add, Div, Mul, Sub};
1414

1515
use crate::imp_prelude::*;
1616
use crate::numeric_util;
17+
use crate::Slice;
1718

1819
/// # Numerical Methods for Arrays
1920
impl<A, S, D> ArrayBase<S, D>
@@ -437,4 +438,59 @@ where
437438
{
438439
self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
439440
}
441+
442+
/// Calculates the (forward) finite differences of order `n`, along the `axis`.
443+
/// For the 1D-case, `n==1`, this means: `diff[i] == arr[i+1] - arr[i]`
444+
///
445+
/// For `n>=2`, the process is iterated:
446+
/// ```
447+
/// use ndarray::{array, Axis};
448+
/// let arr = array![1.0, 2.0, 5.0];
449+
/// assert_eq!(arr.diff(2, Axis(0)), arr.diff(1, Axis(0)).diff(1, Axis(0)))
450+
/// ```
451+
/// **Panics** if `axis` is out of bounds
452+
///
453+
/// **Panics** if `n` is too big / the array is to short:
454+
/// ```should_panic
455+
/// use ndarray::{array, Axis};
456+
/// array![1.0, 2.0, 3.0].diff(10, Axis(0));
457+
/// ```
458+
pub fn diff(&self, n: usize, axis: Axis) -> Array<A, D>
459+
where A: Sub<A, Output = A> + Zero + Clone
460+
{
461+
if n == 0 {
462+
return self.to_owned();
463+
}
464+
assert!(axis.0 < self.ndim(), "The array has only ndim {}, but `axis` {:?} is given.", self.ndim(), axis);
465+
assert!(
466+
n < self.shape()[axis.0],
467+
"The array must have length at least `n+1`=={} in the direction of `axis`. It has length {}",
468+
n + 1,
469+
self.shape()[axis.0]
470+
);
471+
472+
let mut inp = self.to_owned();
473+
let mut out = Array::zeros({
474+
let mut inp_dim = self.raw_dim();
475+
// inp_dim[axis.0] >= 1 as per the 2nd assertion.
476+
inp_dim[axis.0] -= 1;
477+
inp_dim
478+
});
479+
for _ in 0..n {
480+
let head = inp.slice_axis(axis, Slice::from(..-1));
481+
let tail = inp.slice_axis(axis, Slice::from(1..));
482+
483+
azip!((o in &mut out, h in head, t in tail) *o = t.clone() - h.clone());
484+
485+
// feed the output as the input to the next iteration
486+
std::mem::swap(&mut inp, &mut out);
487+
488+
// adjust the new output array width along `axis`.
489+
// Current situation: width of `inp`: k, `out`: k+1
490+
// needed width: `inp`: k, `out`: k-1
491+
// slice is possible, since k >= 1.
492+
out.slice_axis_inplace(axis, Slice::from(..-2));
493+
}
494+
inp
495+
}
440496
}

tests/numeric.rs

+68
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,71 @@ fn std_axis_empty_axis()
336336
assert_eq!(v.shape(), &[2]);
337337
v.mapv(|x| assert!(x.is_nan()));
338338
}
339+
340+
#[test]
341+
fn diff_1d_order1()
342+
{
343+
let data = array![1.0, 2.0, 4.0, 7.0];
344+
let expected = array![1.0, 2.0, 3.0];
345+
assert_eq!(data.diff(1, Axis(0)), expected);
346+
}
347+
348+
#[test]
349+
fn diff_1d_order2()
350+
{
351+
let data = array![1.0, 2.0, 4.0, 7.0];
352+
assert_eq!(
353+
data.diff(2, Axis(0)),
354+
data.diff(1, Axis(0)).diff(1, Axis(0))
355+
);
356+
}
357+
358+
#[test]
359+
fn diff_1d_order3()
360+
{
361+
let data = array![1.0, 2.0, 4.0, 7.0];
362+
assert_eq!(
363+
data.diff(3, Axis(0)),
364+
data.diff(1, Axis(0)).diff(1, Axis(0)).diff(1, Axis(0))
365+
);
366+
}
367+
368+
#[test]
369+
fn diff_2d_order1_ax0()
370+
{
371+
let data = array![
372+
[1.0, 2.0, 4.0, 7.0],
373+
[1.0, 3.0, 6.0, 6.0],
374+
[1.5, 3.5, 5.5, 5.5]
375+
];
376+
let expected = array![[0.0, 1.0, 2.0, -1.0], [0.5, 0.5, -0.5, -0.5]];
377+
assert_eq!(data.diff(1, Axis(0)), expected);
378+
}
379+
380+
#[test]
381+
fn diff_2d_order1_ax1()
382+
{
383+
let data = array![
384+
[1.0, 2.0, 4.0, 7.0],
385+
[1.0, 3.0, 6.0, 6.0],
386+
[1.5, 3.5, 5.5, 5.5]
387+
];
388+
let expected = array![[1.0, 2.0, 3.0], [2.0, 3.0, 0.0], [2.0, 2.0, 0.0]];
389+
assert_eq!(data.diff(1, Axis(1)), expected);
390+
}
391+
392+
#[test]
393+
#[should_panic]
394+
fn diff_panic_n_too_big()
395+
{
396+
let data = array![1.0, 2.0, 4.0, 7.0];
397+
data.diff(10, Axis(0));
398+
}
399+
400+
#[test]
401+
#[should_panic]
402+
fn diff_panic_axis_out_of_bounds()
403+
{
404+
let data = array![1, 2, 4, 7];
405+
data.diff(1, Axis(2));
406+
}

0 commit comments

Comments
 (0)