Skip to content

Commit 4e01b15

Browse files
Fix overflow for interpolate::midpoint (#28)
Breaking changes: * Use NumOps when arithmetic is required for Interpolate implementations
1 parent dcc0cf1 commit 4e01b15

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

src/maybe_nan/impl_not_none.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use super::NotNone;
22
use num_traits::{FromPrimitive, ToPrimitive};
33
use std::cmp;
44
use std::fmt;
5-
use std::ops::{Add, Deref, DerefMut, Div, Mul, Sub};
5+
use std::ops::{Add, Deref, DerefMut, Div, Mul, Sub, Rem};
66

77
impl<T> Deref for NotNone<T> {
88
type Target = T;
@@ -96,6 +96,14 @@ impl<T: Div> Div for NotNone<T> {
9696
}
9797
}
9898

99+
impl<T: Rem> Rem for NotNone<T> {
100+
type Output = NotNone<T::Output>;
101+
#[inline]
102+
fn rem(self, rhs: Self) -> Self::Output {
103+
self.map(|v| v.rem(rhs.unwrap()))
104+
}
105+
}
106+
99107
impl<T: ToPrimitive> ToPrimitive for NotNone<T> {
100108
#[inline]
101109
fn to_isize(&self) -> Option<isize> {

src/quantile.rs

+11-5
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ use {MaybeNan, MaybeNanExt, Sort1dExt};
88
pub mod interpolate {
99
use ndarray::azip;
1010
use ndarray::prelude::*;
11-
use num_traits::{FromPrimitive, ToPrimitive};
12-
use std::ops::{Add, Div};
11+
use num_traits::{FromPrimitive, ToPrimitive, NumOps};
1312

1413
/// Used to provide an interpolation strategy to [`quantile_axis_mut`].
1514
///
@@ -116,7 +115,7 @@ pub mod interpolate {
116115

117116
impl<T> Interpolate<T> for Midpoint
118117
where
119-
T: Add<T, Output = T> + Div<T, Output = T> + Clone + FromPrimitive,
118+
T: NumOps + Clone + FromPrimitive,
120119
{
121120
fn needs_lower(_q: f64, _len: usize) -> bool {
122121
true
@@ -134,13 +133,20 @@ pub mod interpolate {
134133
D: Dimension,
135134
{
136135
let denom = T::from_u8(2).unwrap();
137-
(lower.unwrap() + higher.unwrap()).mapv_into(|x| x / denom.clone())
136+
let mut lower = lower.unwrap();
137+
let higher = higher.unwrap();
138+
azip!(
139+
mut lower, ref higher in {
140+
*lower = lower.clone() + (higher.clone() - lower.clone()) / denom.clone()
141+
}
142+
);
143+
lower
138144
}
139145
}
140146

141147
impl<T> Interpolate<T> for Linear
142148
where
143-
T: Add<T, Output = T> + Clone + FromPrimitive + ToPrimitive,
149+
T: NumOps + Clone + FromPrimitive + ToPrimitive,
144150
{
145151
fn needs_lower(_q: f64, _len: usize) -> bool {
146152
true

tests/quantile.rs

+11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use ndarray::prelude::*;
66
use ndarray_stats::{
77
interpolate::{Higher, Linear, Lower, Midpoint, Nearest},
88
QuantileExt,
9+
Quantile1dExt,
910
};
1011

1112
#[test]
@@ -148,3 +149,13 @@ fn test_quantile_axis_skipnan_mut_linear_opt_i32() {
148149
assert_eq!(q[0], Some(3));
149150
assert!(q[1].is_none());
150151
}
152+
153+
#[test]
154+
fn test_midpoint_overflow() {
155+
// Regression test
156+
// This triggered an overflow panic with a naive Midpoint implementation: (a+b)/2
157+
let mut a: Array1<u8> = array![129, 130, 130, 131];
158+
let median = a.quantile_mut::<Midpoint>(0.5).unwrap();
159+
let expected_median = 130;
160+
assert_eq!(median, expected_median);
161+
}

0 commit comments

Comments
 (0)