Skip to content

Commit d838ee7

Browse files
phunglesonjturner314
authored andcommitted
Add argmin/max_skipnan and indexed_fold_skipnan (#33)
* Implement argmin_skipnan * Implement argmax_skipnan * Loosen the rule for argmin max related methods * Make returning code clearer * Add quickcheck for argmin_skipnan, argmax_skipnan * Use `fold` instead of `for` * Add indexed_fold_skipnan to MaybeNanExt * Impl argmin/max_skipnan using indexed_fold_skipnan * Fix argmin/max_skipnan quickcheck tests The old tests were incorrect because `min`/`max` return `None` when there are *any* NaN values (or the array is empty), while `argmin/max_skipnan` should return `None` only when *all* the values are NaNs (or the array is empty). This wasn't caught earlier because the `quickcheck::Arbitrary` implementation for `f32` generates only finite values. To make sure the behavior with NaN values is properly tested, the element type in the test has been changed to `Option<i32>`. * Replace min/max.map with if for clarity * Add () to make the match clearer
1 parent 7df0728 commit d838ee7

File tree

3 files changed

+186
-0
lines changed

3 files changed

+186
-0
lines changed

src/maybe_nan/mod.rs

+23
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,15 @@ where
241241
A: 'a,
242242
F: FnMut(B, &'a A::NotNan) -> B;
243243

244+
/// Traverse the non-NaN elements and their indices and apply a fold,
245+
/// returning the resulting value.
246+
///
247+
/// Elements are visited in arbitrary order.
248+
fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B
249+
where
250+
A: 'a,
251+
F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B;
252+
244253
/// Visit each non-NaN element in the array by calling `f` on each element.
245254
///
246255
/// Elements are visited in arbitrary order.
@@ -302,6 +311,20 @@ where
302311
})
303312
}
304313

314+
fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B
315+
where
316+
A: 'a,
317+
F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B,
318+
{
319+
self.indexed_iter().fold(init, |acc, (idx, elem)| {
320+
if let Some(not_nan) = elem.try_as_not_nan() {
321+
f(acc, (idx, not_nan))
322+
} else {
323+
acc
324+
}
325+
})
326+
}
327+
305328
fn visit_skipnan<'a, F>(&'a self, mut f: F)
306329
where
307330
A: 'a,

src/quantile.rs

+98
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,33 @@ where
211211
where
212212
A: PartialOrd;
213213

214+
/// Finds the index of the minimum value of the array skipping NaN values.
215+
///
216+
/// Returns `None` if the array is empty or none of the values in the array
217+
/// are non-NaN values.
218+
///
219+
/// Even if there are multiple (equal) elements that are minima, only one
220+
/// index is returned. (Which one is returned is unspecified and may depend
221+
/// on the memory layout of the array.)
222+
///
223+
/// # Example
224+
///
225+
/// ```
226+
/// extern crate ndarray;
227+
/// extern crate ndarray_stats;
228+
///
229+
/// use ndarray::array;
230+
/// use ndarray_stats::QuantileExt;
231+
///
232+
/// let a = array![[::std::f64::NAN, 3., 5.],
233+
/// [2., 0., 6.]];
234+
/// assert_eq!(a.argmin_skipnan(), Some((1, 1)));
235+
/// ```
236+
fn argmin_skipnan(&self) -> Option<D::Pattern>
237+
where
238+
A: MaybeNan,
239+
A::NotNan: Ord;
240+
214241
/// Finds the elementwise minimum of the array.
215242
///
216243
/// Returns `None` if any of the pairwise orderings tested by the function
@@ -269,6 +296,33 @@ where
269296
where
270297
A: PartialOrd;
271298

299+
/// Finds the index of the maximum value of the array skipping NaN values.
300+
///
301+
/// Returns `None` if the array is empty or none of the values in the array
302+
/// are non-NaN values.
303+
///
304+
/// Even if there are multiple (equal) elements that are maxima, only one
305+
/// index is returned. (Which one is returned is unspecified and may depend
306+
/// on the memory layout of the array.)
307+
///
308+
/// # Example
309+
///
310+
/// ```
311+
/// extern crate ndarray;
312+
/// extern crate ndarray_stats;
313+
///
314+
/// use ndarray::array;
315+
/// use ndarray_stats::QuantileExt;
316+
///
317+
/// let a = array![[::std::f64::NAN, 3., 5.],
318+
/// [2., 0., 6.]];
319+
/// assert_eq!(a.argmax_skipnan(), Some((1, 2)));
320+
/// ```
321+
fn argmax_skipnan(&self) -> Option<D::Pattern>
322+
where
323+
A: MaybeNan,
324+
A::NotNan: Ord;
325+
272326
/// Finds the elementwise maximum of the array.
273327
///
274328
/// Returns `None` if any of the pairwise orderings tested by the function
@@ -369,6 +423,28 @@ where
369423
Some(current_pattern_min)
370424
}
371425

426+
fn argmin_skipnan(&self) -> Option<D::Pattern>
427+
where
428+
A: MaybeNan,
429+
A::NotNan: Ord,
430+
{
431+
let mut pattern_min = D::zeros(self.ndim()).into_pattern();
432+
let min = self.indexed_fold_skipnan(None, |current_min, (pattern, elem)| {
433+
Some(match current_min {
434+
Some(m) if (m <= elem) => m,
435+
_ => {
436+
pattern_min = pattern;
437+
elem
438+
}
439+
})
440+
});
441+
if min.is_some() {
442+
Some(pattern_min)
443+
} else {
444+
None
445+
}
446+
}
447+
372448
fn min(&self) -> Option<&A>
373449
where
374450
A: PartialOrd,
@@ -411,6 +487,28 @@ where
411487
Some(current_pattern_max)
412488
}
413489

490+
fn argmax_skipnan(&self) -> Option<D::Pattern>
491+
where
492+
A: MaybeNan,
493+
A::NotNan: Ord,
494+
{
495+
let mut pattern_max = D::zeros(self.ndim()).into_pattern();
496+
let max = self.indexed_fold_skipnan(None, |current_max, (pattern, elem)| {
497+
Some(match current_max {
498+
Some(m) if m >= elem => m,
499+
_ => {
500+
pattern_max = pattern;
501+
elem
502+
}
503+
})
504+
});
505+
if max.is_some() {
506+
Some(pattern_max)
507+
} else {
508+
None
509+
}
510+
}
511+
414512
fn max(&self) -> Option<&A>
415513
where
416514
A: PartialOrd,

tests/quantile.rs

+65
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,37 @@ quickcheck! {
3232
}
3333
}
3434

35+
#[test]
36+
fn test_argmin_skipnan() {
37+
let a = array![[1., 5., 3.], [2., 0., 6.]];
38+
assert_eq!(a.argmin_skipnan(), Some((1, 1)));
39+
40+
let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]];
41+
assert_eq!(a.argmin_skipnan(), Some((0, 0)));
42+
43+
let a = array![[::std::f64::NAN, 5., 3.], [2., ::std::f64::NAN, 6.]];
44+
assert_eq!(a.argmin_skipnan(), Some((1, 0)));
45+
46+
let a: Array2<f64> = array![[], []];
47+
assert_eq!(a.argmin_skipnan(), None);
48+
49+
let a = arr2(&[[::std::f64::NAN; 2]; 2]);
50+
assert_eq!(a.argmin_skipnan(), None);
51+
}
52+
53+
quickcheck! {
54+
fn argmin_skipnan_matches_min_skipnan(data: Vec<Option<i32>>) -> bool {
55+
let a = Array1::from(data);
56+
let min = a.min_skipnan();
57+
let argmin = a.argmin_skipnan();
58+
if min.is_none() {
59+
argmin == None
60+
} else {
61+
a[argmin.unwrap()] == *min
62+
}
63+
}
64+
}
65+
3566
#[test]
3667
fn test_min() {
3768
let a = array![[1, 5, 3], [2, 0, 6]];
@@ -81,6 +112,40 @@ quickcheck! {
81112
}
82113
}
83114

115+
#[test]
116+
fn test_argmax_skipnan() {
117+
let a = array![[1., 5., 3.], [2., 0., 6.]];
118+
assert_eq!(a.argmax_skipnan(), Some((1, 2)));
119+
120+
let a = array![[1., 5., 3.], [2., ::std::f64::NAN, ::std::f64::NAN]];
121+
assert_eq!(a.argmax_skipnan(), Some((0, 1)));
122+
123+
let a = array![
124+
[::std::f64::NAN, ::std::f64::NAN, 3.],
125+
[2., ::std::f64::NAN, 6.]
126+
];
127+
assert_eq!(a.argmax_skipnan(), Some((1, 2)));
128+
129+
let a: Array2<f64> = array![[], []];
130+
assert_eq!(a.argmax_skipnan(), None);
131+
132+
let a = arr2(&[[::std::f64::NAN; 2]; 2]);
133+
assert_eq!(a.argmax_skipnan(), None);
134+
}
135+
136+
quickcheck! {
137+
fn argmax_skipnan_matches_max_skipnan(data: Vec<Option<i32>>) -> bool {
138+
let a = Array1::from(data);
139+
let max = a.max_skipnan();
140+
let argmax = a.argmax_skipnan();
141+
if max.is_none() {
142+
argmax == None
143+
} else {
144+
a[argmax.unwrap()] == *max
145+
}
146+
}
147+
}
148+
84149
#[test]
85150
fn test_max() {
86151
let a = array![[1, 5, 7], [2, 0, 6]];

0 commit comments

Comments
 (0)