Skip to content

Commit 86e5ca4

Browse files
Histogram error handling (#25)
* Use Option as return type where things might fail
* Test suite aligned with docs
* Equispaced does not panic anymore
* Fixed some tests
* Fixed FD tests
* Fixed wrong condition in IF
* Fixed wrong test
* Added new test for EquiSpaced and fixed old one
* Fixed doc tests
* Fix docs.
* Fix docs.
* Fix docs.
* Fmt
* Create StrategyError
* Fmt
* Return Result. Fix Equispaced, Sqrt and Rice
* Fix Rice
* Fixed Sturges
* Fix strategies
* Fix match
* Tests compile
* Fix assertion
* Fmt
* Add more error types
* Rename StrategyError to BinsBuildError
* Make GridBuilder::from_array return Result

  This is nice because it doesn't lose information. (Returning an `Option` combines the two error variants into a single case.)

* Make BinsBuildError enum non-exhaustive

  Once the `#[non_exhaustive]` attribute is stable, we should replace the hidden enum variant with that attribute on the enum.

* Use lazy OR operator.
* Use lazy OR operator.
1 parent d838ee7 commit 86e5ca4

File tree

10 files changed

+423
-211
lines changed

10 files changed

+423
-211
lines changed

src/entropy.rs

+48-49
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
//! Information theory (e.g. entropy, KL divergence, etc.).
2-
use crate::errors::ShapeMismatch;
2+
use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
33
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
44
use num_traits::Float;
55

@@ -19,7 +19,7 @@ where
1919
/// i=1
2020
/// ```
2121
///
22-
/// If the array is empty, `None` is returned.
22+
/// If the array is empty, `Err(EmptyInput)` is returned.
2323
///
2424
/// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
2525
///
@@ -38,7 +38,7 @@ where
3838
///
3939
/// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
4040
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
41-
fn entropy(&self) -> Option<A>
41+
fn entropy(&self) -> Result<A, EmptyInput>
4242
where
4343
A: Float;
4444

@@ -53,8 +53,9 @@ where
5353
/// i=1
5454
/// ```
5555
///
56-
/// If the arrays are empty, Ok(`None`) is returned.
57-
/// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
56+
/// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
57+
/// If the array shapes are not identical,
58+
/// `Err(MultiInputError::ShapeMismatch)` is returned.
5859
///
5960
/// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
6061
/// *ln(qᵢ/pᵢ)* panics for `A`.
@@ -73,7 +74,7 @@ where
7374
///
7475
/// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
7576
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
76-
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
77+
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
7778
where
7879
S2: Data<Elem = A>,
7980
A: Float;
@@ -89,8 +90,9 @@ where
8990
/// i=1
9091
/// ```
9192
///
92-
/// If the arrays are empty, Ok(`None`) is returned.
93-
/// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
93+
/// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
94+
/// If the array shapes are not identical,
95+
/// `Err(MultiInputError::ShapeMismatch)` is returned.
9496
///
9597
/// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
9698
/// panics for `A`.
@@ -114,7 +116,7 @@ where
114116
/// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
115117
/// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
116118
/// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
117-
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
119+
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
118120
where
119121
S2: Data<Elem = A>,
120122
A: Float;
@@ -125,14 +127,14 @@ where
125127
S: Data<Elem = A>,
126128
D: Dimension,
127129
{
128-
fn entropy(&self) -> Option<A>
130+
fn entropy(&self) -> Result<A, EmptyInput>
129131
where
130132
A: Float,
131133
{
132134
if self.len() == 0 {
133-
None
135+
Err(EmptyInput)
134136
} else {
135-
let entropy = self
137+
let entropy = -self
136138
.mapv(|x| {
137139
if x == A::zero() {
138140
A::zero()
@@ -141,23 +143,24 @@ where
141143
}
142144
})
143145
.sum();
144-
Some(-entropy)
146+
Ok(entropy)
145147
}
146148
}
147149

148-
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
150+
fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
149151
where
150152
A: Float,
151153
S2: Data<Elem = A>,
152154
{
153155
if self.len() == 0 {
154-
return Ok(None);
156+
return Err(MultiInputError::EmptyInput);
155157
}
156158
if self.shape() != q.shape() {
157159
return Err(ShapeMismatch {
158160
first_shape: self.shape().to_vec(),
159161
second_shape: q.shape().to_vec(),
160-
});
162+
}
163+
.into());
161164
}
162165

163166
let mut temp = Array::zeros(self.raw_dim());
@@ -174,22 +177,23 @@ where
174177
}
175178
});
176179
let kl_divergence = -temp.sum();
177-
Ok(Some(kl_divergence))
180+
Ok(kl_divergence)
178181
}
179182

180-
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
183+
fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
181184
where
182185
S2: Data<Elem = A>,
183186
A: Float,
184187
{
185188
if self.len() == 0 {
186-
return Ok(None);
189+
return Err(MultiInputError::EmptyInput);
187190
}
188191
if self.shape() != q.shape() {
189192
return Err(ShapeMismatch {
190193
first_shape: self.shape().to_vec(),
191194
second_shape: q.shape().to_vec(),
192-
});
195+
}
196+
.into());
193197
}
194198

195199
let mut temp = Array::zeros(self.raw_dim());
@@ -206,15 +210,15 @@ where
206210
}
207211
});
208212
let cross_entropy = -temp.sum();
209-
Ok(Some(cross_entropy))
213+
Ok(cross_entropy)
210214
}
211215
}
212216

213217
#[cfg(test)]
214218
mod tests {
215219
use super::EntropyExt;
216220
use approx::assert_abs_diff_eq;
217-
use errors::ShapeMismatch;
221+
use errors::{EmptyInput, MultiInputError};
218222
use ndarray::{array, Array1};
219223
use noisy_float::types::n64;
220224
use std::f64;
@@ -228,7 +232,7 @@ mod tests {
228232
#[test]
229233
fn test_entropy_with_empty_array_of_floats() {
230234
let a: Array1<f64> = array![];
231-
assert!(a.entropy().is_none());
235+
assert_eq!(a.entropy(), Err(EmptyInput));
232236
}
233237

234238
#[test]
@@ -251,13 +255,13 @@ mod tests {
251255
}
252256

253257
#[test]
254-
fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
258+
fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
255259
let a = array![f64::NAN, 1.];
256260
let b = array![2., 1.];
257-
assert!(a.cross_entropy(&b)?.unwrap().is_nan());
258-
assert!(b.cross_entropy(&a)?.unwrap().is_nan());
259-
assert!(a.kl_divergence(&b)?.unwrap().is_nan());
260-
assert!(b.kl_divergence(&a)?.unwrap().is_nan());
261+
assert!(a.cross_entropy(&b)?.is_nan());
262+
assert!(b.cross_entropy(&a)?.is_nan());
263+
assert!(a.kl_divergence(&b)?.is_nan());
264+
assert!(b.kl_divergence(&a)?.is_nan());
261265
Ok(())
262266
}
263267

@@ -284,20 +288,19 @@ mod tests {
284288
}
285289

286290
#[test]
287-
fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
291+
fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
288292
let p: Array1<f64> = array![];
289293
let q: Array1<f64> = array![];
290-
assert!(p.cross_entropy(&q)?.is_none());
291-
assert!(p.kl_divergence(&q)?.is_none());
292-
Ok(())
294+
assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
295+
assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
293296
}
294297

295298
#[test]
296-
fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
299+
fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
297300
let p = array![1.];
298301
let q = array![-1.];
299-
let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
300-
let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
302+
let cross_entropy: f64 = p.cross_entropy(&q)?;
303+
let kl_divergence: f64 = p.kl_divergence(&q)?;
301304
assert!(cross_entropy.is_nan());
302305
assert!(kl_divergence.is_nan());
303306
Ok(())
@@ -320,26 +323,26 @@ mod tests {
320323
}
321324

322325
#[test]
323-
fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
326+
fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
324327
let p = array![0., 0.];
325328
let q = array![0., 0.5];
326-
assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
327-
assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
329+
assert_eq!(p.cross_entropy(&q)?, 0.);
330+
assert_eq!(p.kl_divergence(&q)?, 0.);
328331
Ok(())
329332
}
330333

331334
#[test]
332335
fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
333-
) -> Result<(), ShapeMismatch> {
336+
) -> Result<(), MultiInputError> {
334337
let p = array![0.5, 0.5];
335338
let mut q = array![0.5, 0.];
336-
assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
337-
assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
339+
assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
340+
assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
338341
Ok(())
339342
}
340343

341344
#[test]
342-
fn test_cross_entropy() -> Result<(), ShapeMismatch> {
345+
fn test_cross_entropy() -> Result<(), MultiInputError> {
343346
// Arrays of probability values - normalized and positive.
344347
let p: Array1<f64> = array![
345348
0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
@@ -356,16 +359,12 @@ mod tests {
356359
// Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
357360
let expected_cross_entropy = 3.385347705020779;
358361

359-
assert_abs_diff_eq!(
360-
p.cross_entropy(&q)?.unwrap(),
361-
expected_cross_entropy,
362-
epsilon = 1e-6
363-
);
362+
assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
364363
Ok(())
365364
}
366365

367366
#[test]
368-
fn test_kl() -> Result<(), ShapeMismatch> {
367+
fn test_kl() -> Result<(), MultiInputError> {
369368
// Arrays of probability values - normalized and positive.
370369
let p: Array1<f64> = array![
371370
0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
@@ -390,7 +389,7 @@ mod tests {
390389
// Computed using scipy.stats.entropy(p, q)
391390
let expected_kl = 0.3555862567800096;
392391

393-
assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
392+
assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
394393
Ok(())
395394
}
396395
}

src/errors.rs

+91-1
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,50 @@
22
use std::error::Error;
33
use std::fmt;
44

5-
#[derive(Debug)]
5+
/// An error that indicates that the input array was empty.
6+
#[derive(Clone, Debug, Eq, PartialEq)]
7+
pub struct EmptyInput;
8+
9+
impl fmt::Display for EmptyInput {
10+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
11+
write!(f, "Empty input.")
12+
}
13+
}
14+
15+
impl Error for EmptyInput {}
16+
17+
/// An error computing a minimum/maximum value.
18+
#[derive(Clone, Debug, Eq, PartialEq)]
19+
pub enum MinMaxError {
20+
/// The input was empty.
21+
EmptyInput,
22+
/// The ordering between a tested pair of values was undefined.
23+
UndefinedOrder,
24+
}
25+
26+
impl fmt::Display for MinMaxError {
27+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
28+
match self {
29+
MinMaxError::EmptyInput => write!(f, "Empty input."),
30+
MinMaxError::UndefinedOrder => {
31+
write!(f, "Undefined ordering between a tested pair of values.")
32+
}
33+
}
34+
}
35+
}
36+
37+
impl Error for MinMaxError {}
38+
39+
impl From<EmptyInput> for MinMaxError {
40+
fn from(_: EmptyInput) -> MinMaxError {
41+
MinMaxError::EmptyInput
42+
}
43+
}
44+
645
/// An error used by methods and functions that take two arrays as argument and
746
/// expect them to have exactly the same shape
847
/// (e.g. `ShapeMismatch` is returned when `a.shape() == b.shape()` evaluates to `false`).
48+
#[derive(Clone, Debug)]
949
pub struct ShapeMismatch {
1050
pub first_shape: Vec<usize>,
1151
pub second_shape: Vec<usize>,
@@ -22,3 +62,53 @@ impl fmt::Display for ShapeMismatch {
2262
}
2363

2464
impl Error for ShapeMismatch {}
65+
66+
/// An error for methods that take multiple non-empty array inputs.
67+
#[derive(Clone, Debug)]
68+
pub enum MultiInputError {
69+
/// One or more of the arrays were empty.
70+
EmptyInput,
71+
/// The arrays did not have the same shape.
72+
ShapeMismatch(ShapeMismatch),
73+
}
74+
75+
impl MultiInputError {
76+
/// Returns whether `self` is the `EmptyInput` variant.
77+
pub fn is_empty_input(&self) -> bool {
78+
match self {
79+
MultiInputError::EmptyInput => true,
80+
_ => false,
81+
}
82+
}
83+
84+
/// Returns whether `self` is the `ShapeMismatch` variant.
85+
pub fn is_shape_mismatch(&self) -> bool {
86+
match self {
87+
MultiInputError::ShapeMismatch(_) => true,
88+
_ => false,
89+
}
90+
}
91+
}
92+
93+
impl fmt::Display for MultiInputError {
94+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
95+
match self {
96+
MultiInputError::EmptyInput => write!(f, "Empty input."),
97+
MultiInputError::ShapeMismatch(e) => write!(f, "Shape mismatch: {}", e),
98+
}
99+
}
100+
}
101+
102+
impl Error for MultiInputError {}
103+
104+
impl From<EmptyInput> for MultiInputError {
105+
fn from(_: EmptyInput) -> Self {
106+
MultiInputError::EmptyInput
107+
}
108+
}
109+
110+
impl From<ShapeMismatch> for MultiInputError {
111+
fn from(err: ShapeMismatch) -> Self {
112+
MultiInputError::ShapeMismatch(err)
113+
}
114+
}

0 commit comments

Comments
 (0)