diff --git a/src/entropy.rs b/src/entropy.rs
new file mode 100644
index 00000000..8daac6d8
--- /dev/null
+++ b/src/entropy.rs
@@ -0,0 +1,396 @@
+//! Information theory (e.g. entropy, KL divergence, etc.).
+use crate::errors::ShapeMismatch;
+use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
+use num_traits::Float;
+
+/// Extension trait for `ArrayBase` providing methods
+/// to compute information theory quantities
+/// (e.g. entropy, Kullback–Leibler divergence, etc.).
+pub trait EntropyExt<A, S, D>
+where
+    S: Data<Elem = A>,
+    D: Dimension,
+{
+    /// Computes the [entropy] *S* of the array values, defined as
+    ///
+    /// ```text
+    ///       n
+    /// S = - ∑ xᵢ ln(xᵢ)
+    ///      i=1
+    /// ```
+    ///
+    /// If the array is empty, `None` is returned.
+    ///
+    /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
+    ///
+    /// ## Remarks
+    ///
+    /// The entropy is a measure used in [Information Theory]
+    /// to describe a probability distribution: it only makes sense
+    /// when the array values sum to 1, with each entry between
+    /// 0 and 1 (extremes included).
+    ///
+    /// The array values are **not** normalised by this function before
+    /// computing the entropy to avoid introducing potentially
+    /// unnecessary numerical errors (e.g. if the array were already normalised).
+    ///
+    /// By definition, *xᵢ ln(xᵢ)* is set to 0 if *xᵢ* is 0.
+    ///
+    /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
+    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
+    fn entropy(&self) -> Option<A>
+    where
+        A: Float;
+
+    /// Computes the [Kullback-Leibler divergence] *Dₖₗ(p,q)* between two arrays,
+    /// where `self`=*p*.
+    ///
+    /// The Kullback-Leibler divergence is defined as:
+    ///
+    /// ```text
+    ///              n
+    /// Dₖₗ(p,q) = - ∑ pᵢ ln(qᵢ/pᵢ)
+    ///             i=1
+    /// ```
+    ///
+    /// If the arrays are empty, `Ok(None)` is returned.
+    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
+    ///
+    /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
+    /// *ln(qᵢ/pᵢ)* causes a panic for `A`.
+    ///
+    /// ## Remarks
+    ///
+    /// The Kullback-Leibler divergence is a measure used in [Information Theory]
+    /// to describe the relationship between two probability distributions: it only makes sense
+    /// when each array sums to 1 with entries between 0 and 1 (extremes included).
+    ///
+    /// The array values are **not** normalised by this function before
+    /// computing the divergence to avoid introducing potentially
+    /// unnecessary numerical errors (e.g. if the arrays were already normalised).
+    ///
+    /// By definition, *pᵢ ln(qᵢ/pᵢ)* is set to 0 if *pᵢ* is 0.
+    ///
+    /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
+    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
+    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    where
+        S2: Data<Elem = A>,
+        A: Float;
+
+    /// Computes the [cross entropy] *H(p,q)* between two arrays,
+    /// where `self`=*p*.
+    ///
+    /// The cross entropy is defined as:
+    ///
+    /// ```text
+    ///            n
+    /// H(p,q) = - ∑ pᵢ ln(qᵢ)
+    ///           i=1
+    /// ```
+    ///
+    /// If the arrays are empty, `Ok(None)` is returned.
+    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
+    ///
+    /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
+    /// causes a panic for `A`.
+    ///
+    /// ## Remarks
+    ///
+    /// The cross entropy is a measure used in [Information Theory]
+    /// to describe the relationship between two probability distributions: it only makes sense
+    /// when each array sums to 1 with entries between 0 and 1 (extremes included).
+    ///
+    /// The array values are **not** normalised by this function before
+    /// computing the cross entropy to avoid introducing potentially
+    /// unnecessary numerical errors (e.g. if the arrays were already normalised).
+    ///
+    /// The cross entropy is often used as an objective/loss function in
+    /// [optimization problems], including [machine learning].
+    ///
+    /// By definition, *pᵢ ln(qᵢ)* is set to 0 if *pᵢ* is 0.
+    ///
+    /// [cross entropy]: https://en.wikipedia.org/wiki/Cross-entropy
+    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
+    /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
+    /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
+    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    where
+        S2: Data<Elem = A>,
+        A: Float;
+}
+
+impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
+where
+    S: Data<Elem = A>,
+    D: Dimension,
+{
+    fn entropy(&self) -> Option<A>
+    where
+        A: Float,
+    {
+        if self.len() == 0 {
+            None
+        } else {
+            let entropy = self
+                .mapv(|x| {
+                    if x == A::zero() {
+                        A::zero()
+                    } else {
+                        x * x.ln()
+                    }
+                })
+                .sum();
+            Some(-entropy)
+        }
+    }
+
+    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    where
+        A: Float,
+        S2: Data<Elem = A>,
+    {
+        if self.len() == 0 {
+            return Ok(None);
+        }
+        if self.shape() != q.shape() {
+            return Err(ShapeMismatch {
+                first_shape: self.shape().to_vec(),
+                second_shape: q.shape().to_vec(),
+            });
+        }
+
+        let mut temp = Array::zeros(self.raw_dim());
+        Zip::from(&mut temp)
+            .and(self)
+            .and(q)
+            .apply(|result, &p, &q| {
+                *result = {
+                    if p == A::zero() {
+                        A::zero()
+                    } else {
+                        p * (q / p).ln()
+                    }
+                }
+            });
+        let kl_divergence = -temp.sum();
+        Ok(Some(kl_divergence))
+    }
+
+    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    where
+        S2: Data<Elem = A>,
+        A: Float,
+    {
+        if self.len() == 0 {
+            return Ok(None);
+        }
+        if self.shape() != q.shape() {
+            return Err(ShapeMismatch {
+                first_shape: self.shape().to_vec(),
+                second_shape: q.shape().to_vec(),
+            });
+        }
+
+        let mut temp = Array::zeros(self.raw_dim());
+        Zip::from(&mut temp)
+            .and(self)
+            .and(q)
+            .apply(|result, &p, &q| {
+                *result = {
+                    if p == A::zero() {
+                        A::zero()
+                    } else {
+                        p * q.ln()
+                    }
+                }
+            });
+        let cross_entropy = -temp.sum();
+        Ok(Some(cross_entropy))
+    }
+}
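Not part of the patch, but a quick illustration of the conventions implemented above: a uniform distribution over four outcomes has entropy ln 4, zero entries contribute nothing thanks to the *xᵢ ln(xᵢ) = 0* convention, and an empty array yields `None`. A minimal sketch, assuming the crate is consumed as `ndarray_stats` with the `EntropyExt` re-export added in `src/lib.rs` further down this diff:

```rust
use ndarray::array;
use ndarray_stats::EntropyExt;

fn main() {
    // Uniform distribution over 4 outcomes: S = -4 * (1/4) * ln(1/4) = ln(4).
    let uniform = array![0.25, 0.25, 0.25, 0.25];
    assert!((uniform.entropy().unwrap() - 4f64.ln()).abs() < 1e-12);

    // Zero entries are skipped (0 * ln(0) is taken to be 0),
    // so a point mass has zero entropy.
    let point_mass = array![0.0, 1.0, 0.0];
    assert_eq!(point_mass.entropy().unwrap(), 0.0);

    // An empty array has no well-defined entropy: `None` is returned.
    let empty: ndarray::Array1<f64> = array![];
    assert!(empty.entropy().is_none());
}
```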
+
+#[cfg(test)]
+mod tests {
+    use super::EntropyExt;
+    use approx::assert_abs_diff_eq;
+    use crate::errors::ShapeMismatch;
+    use ndarray::{array, Array1};
+    use noisy_float::types::n64;
+    use std::f64;
+
+    #[test]
+    fn test_entropy_with_nan_values() {
+        let a = array![f64::NAN, 1.];
+        assert!(a.entropy().unwrap().is_nan());
+    }
+
+    #[test]
+    fn test_entropy_with_empty_array_of_floats() {
+        let a: Array1<f64> = array![];
+        assert!(a.entropy().is_none());
+    }
+
+    #[test]
+    fn test_entropy_with_array_of_floats() {
+        // Array of probability values - normalized and positive.
+        let a: Array1<f64> = array![
+            0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, 0.03368976, 0.00065396,
+            0.02906146, 0.00063687, 0.01597306, 0.00787625, 0.00208243, 0.01450896, 0.01803418,
+            0.02055336, 0.03029759, 0.03323628, 0.01218822, 0.0001873, 0.01734179, 0.03521668,
+            0.02564429, 0.02421992, 0.03540229, 0.03497635, 0.03582331, 0.026558, 0.02460495,
+            0.02437716, 0.01212838, 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588,
+            0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, 0.00976694, 0.02864634,
+            0.00802828, 0.03464088, 0.03557152, 0.01398894, 0.01831756, 0.0227171, 0.00736204,
+            0.01866295,
+        ];
+        // Computed using scipy.stats.entropy
+        let expected_entropy = 3.721606155686918;
+
+        assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6);
+    }
+
+    #[test]
+    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
+        let a = array![f64::NAN, 1.];
+        let b = array![2., 1.];
+        assert!(a.cross_entropy(&b)?.unwrap().is_nan());
+        assert!(b.cross_entropy(&a)?.unwrap().is_nan());
+        assert!(a.kl_divergence(&b)?.unwrap().is_nan());
+        assert!(b.kl_divergence(&a)?.unwrap().is_nan());
+        Ok(())
+    }
+
+    #[test]
+    fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() {
+        let p = array![f64::NAN, 1.];
+        let q = array![2., 1., 5.];
+        assert!(q.cross_entropy(&p).is_err());
+        assert!(p.cross_entropy(&q).is_err());
+        assert!(q.kl_divergence(&p).is_err());
+        assert!(p.kl_divergence(&q).is_err());
+    }
+
+    #[test]
+    fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() {
+        // p: 3x2, 6 elements
+        let p = array![[f64::NAN, 1.], [6., 7.], [10., 20.]];
+        // q: 2x3, 6 elements
+        let q = array![[2., 1., 5.], [1., 1., 7.]];
+        assert!(q.cross_entropy(&p).is_err());
+        assert!(p.cross_entropy(&q).is_err());
+        assert!(q.kl_divergence(&p).is_err());
+        assert!(p.kl_divergence(&q).is_err());
+    }
+
+    #[test]
+    fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
+        let p: Array1<f64> = array![];
+        let q: Array1<f64> = array![];
+        assert!(p.cross_entropy(&q)?.is_none());
+        assert!(p.kl_divergence(&q)?.is_none());
+        Ok(())
+    }
+
+    #[test]
+    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
+        let p = array![1.];
+        let q = array![-1.];
+        let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
+        let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
+        assert!(cross_entropy.is_nan());
+        assert!(kl_divergence.is_nan());
+        Ok(())
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_cross_entropy_with_noisy_negative_qs() {
+        let p = array![n64(1.)];
+        let q = array![n64(-1.)];
+        let _ = p.cross_entropy(&q);
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_kl_with_noisy_negative_qs() {
+        let p = array![n64(1.)];
+        let q = array![n64(-1.)];
+        let _ = p.kl_divergence(&q);
+    }
+
+    #[test]
+    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
+        let p = array![0., 0.];
+        let q = array![0., 0.5];
+        assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
+        assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
+        Ok(())
+    }
+
+    #[test]
+    fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
+    ) -> Result<(), ShapeMismatch> {
+        let p = array![0.5, 0.5];
+        let mut q = array![0.5, 0.];
+        assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
+        assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
+        Ok(())
+    }
+
+    #[test]
+    fn test_cross_entropy() -> Result<(), ShapeMismatch> {
+        // Arrays of probability values - normalized and positive.
+        let p: Array1<f64> = array![
+            0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
+            0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, 0.01959158, 0.05020174,
+            0.03801479, 0.00092234, 0.08515856, 0.00580683, 0.0156542, 0.0860375, 0.0724246,
+            0.00727477, 0.01004402, 0.01854399, 0.03504082,
+        ];
+        let q: Array1<f64> = array![
+            0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, 0.05604812,
+            0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, 0.0625685, 0.07381292,
+            0.05489067, 0.01385491, 0.03639174, 0.00511611, 0.05700415, 0.05183825, 0.06703064,
+            0.01813342, 0.0007763, 0.0735472, 0.05857833,
+        ];
+        // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
+        let expected_cross_entropy = 3.385347705020779;
+
+        assert_abs_diff_eq!(
+            p.cross_entropy(&q)?.unwrap(),
+            expected_cross_entropy,
+            epsilon = 1e-6
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_kl() -> Result<(), ShapeMismatch> {
+        // Arrays of probability values - normalized and positive.
+        let p: Array1<f64> = array![
+            0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
+            0.02213802, 0.02745017, 0.02163975, 0.0324602, 0.03622766, 0.00782343, 0.00222498,
+            0.03028156, 0.02346124, 0.00071105, 0.00794496, 0.0127609, 0.02899124, 0.01281487,
+            0.0230803, 0.01531864, 0.00518158, 0.02233383, 0.0220279, 0.03196097, 0.03710063,
+            0.01817856, 0.03524661, 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151,
+            0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375, 0.01988341, 0.02621831,
+            0.03564644, 0.01389121, 0.03151622, 0.03195532, 0.00717521, 0.03547256, 0.00371394,
+            0.01108706,
+        ];
+        let q: Array1<f64> = array![
+            0.02038386, 0.03143914, 0.02630206, 0.0171595, 0.0067072, 0.00911324, 0.02635717,
+            0.01269113, 0.0302361, 0.02243133, 0.01902902, 0.01297185, 0.02118908, 0.03309548,
+            0.01266687, 0.0184529, 0.01830936, 0.03430437, 0.02898924, 0.02238251, 0.0139771,
+            0.01879774, 0.02396583, 0.03019978, 0.01421278, 0.02078981, 0.03542451, 0.02887438,
+            0.01261783, 0.01014241, 0.03263407, 0.0095969, 0.01923903, 0.0051315, 0.00924686,
+            0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, 0.03315135, 0.02099325,
+            0.03251755, 0.00337555, 0.03432165, 0.01763753, 0.02038337, 0.01923023, 0.01438769,
+            0.02082707,
+        ];
+        // Computed using scipy.stats.entropy(p, q)
+        let expected_kl = 0.3555862567800096;
+
+        assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
+        Ok(())
+    }
+}
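The expected value in `test_cross_entropy` is computed as `scipy.stats.entropy(p) + scipy.stats.entropy(p, q)`, i.e. it relies on the identity *H(p, q) = S(p) + Dₖₗ(p, q)*. Not part of the patch, but a small sketch of the same decomposition through this trait (assuming the `ndarray_stats` crate name and the public paths added in `src/lib.rs` below):

```rust
use ndarray::array;
use ndarray_stats::EntropyExt;

fn main() -> Result<(), ndarray_stats::errors::ShapeMismatch> {
    // Two small, already-normalised distributions.
    let p = array![0.5, 0.25, 0.25];
    let q = array![0.25, 0.25, 0.5];

    let s_p = p.entropy().unwrap();
    let d_kl = p.kl_divergence(&q)?.unwrap();
    let h_pq = p.cross_entropy(&q)?.unwrap();

    // Cross entropy decomposes into entropy plus KL divergence:
    // H(p, q) = S(p) + D_KL(p, q).
    assert!((h_pq - (s_p + d_kl)).abs() < 1e-12);

    // The KL divergence of a distribution with itself is zero.
    assert_eq!(p.kl_divergence(&p)?.unwrap(), 0.0);
    Ok(())
}
```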
diff --git a/src/errors.rs b/src/errors.rs
new file mode 100644
index 00000000..4bbeea46
--- /dev/null
+++ b/src/errors.rs
@@ -0,0 +1,24 @@
+//! Custom errors returned from our methods and functions.
+use std::error::Error;
+use std::fmt;
+
+#[derive(Debug)]
+/// An error used by methods and functions that take two arrays as arguments and
+/// expect them to have exactly the same shape
+/// (e.g. `ShapeMismatch` is returned when `a.shape() == b.shape()` evaluates to `false`).
+pub struct ShapeMismatch {
+    pub first_shape: Vec<usize>,
+    pub second_shape: Vec<usize>,
+}
+
+impl fmt::Display for ShapeMismatch {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "Array shapes do not match: {:?} and {:?}.",
+            self.first_shape, self.second_shape
+        )
+    }
+}
+
+impl Error for ShapeMismatch {}
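When the two input arrays disagree in shape, `kl_divergence` and `cross_entropy` return this `ShapeMismatch` value rather than panicking. Not part of the patch, but a minimal sketch of handling it, again assuming the `ndarray_stats::errors` path made public in `src/lib.rs` below:

```rust
use ndarray::array;
use ndarray_stats::errors::ShapeMismatch;
use ndarray_stats::EntropyExt;

fn main() {
    let p = array![0.5, 0.5];
    let q = array![0.2, 0.3, 0.5]; // different shape: (2,) vs (3,)

    match p.kl_divergence(&q) {
        Ok(_) => unreachable!("shapes differ, so this branch is never taken"),
        // The error carries both shapes and has a human-readable `Display` impl:
        // "Array shapes do not match: [2] and [3]."
        Err(ShapeMismatch {
            first_shape,
            second_shape,
        }) => {
            assert_eq!(first_shape, vec![2]);
            assert_eq!(second_shape, vec![3]);
        }
    }
}
```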
diff --git a/src/lib.rs b/src/lib.rs
index 7fbcc94f..9cf586f1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,12 +2,13 @@
 //! the *n*-dimensional array data structure provided by [`ndarray`].
 //!
 //! Currently available routines include:
-//! - [`order statistics`] (minimum, maximum, quantiles, etc.);
-//! - [`partitioning`];
-//! - [`correlation analysis`] (covariance, pearson correlation);
-//! - [`histogram computation`].
+//! - [order statistics] (minimum, maximum, quantiles, etc.);
+//! - [partitioning];
+//! - [correlation analysis] (covariance, pearson correlation);
+//! - [measures from information theory] (entropy, KL divergence, etc.);
+//! - [histogram computation].
 //!
-//! Please feel free to contribute new functionality! A roadmap can be found [`here`].
+//! Please feel free to contribute new functionality! A roadmap can be found [here].
 //!
 //! Our work is inspired by other existing statistical packages such as
 //! [`NumPy`] (Python) and [`StatsBase.jl`] (Julia) - any contribution bringing us closer to
@@ -15,11 +16,12 @@
 //!
 //! [`ndarray-stats`]: https://github.com/jturner314/ndarray-stats/
 //! [`ndarray`]: https://github.com/rust-ndarray/ndarray
-//! [`order statistics`]: trait.QuantileExt.html
-//! [`partitioning`]: trait.Sort1dExt.html
-//! [`correlation analysis`]: trait.CorrelationExt.html
-//! [`histogram computation`]: histogram/index.html
-//! [`here`]: https://github.com/jturner314/ndarray-stats/issues/1
+//! [order statistics]: trait.QuantileExt.html
+//! [partitioning]: trait.Sort1dExt.html
+//! [correlation analysis]: trait.CorrelationExt.html
+//! [measures from information theory]: trait.EntropyExt.html
+//! [histogram computation]: histogram/index.html
+//! [here]: https://github.com/jturner314/ndarray-stats/issues/1
 //! [`NumPy`]: https://docs.scipy.org/doc/numpy-1.14.1/reference/routines.statistics.html
 //! [`StatsBase.jl`]: https://juliastats.github.io/StatsBase.jl/latest/
 
@@ -37,6 +39,7 @@ extern crate ndarray_rand;
 extern crate quickcheck;
 
 pub use correlation::CorrelationExt;
+pub use entropy::EntropyExt;
 pub use histogram::HistogramExt;
 pub use maybe_nan::{MaybeNan, MaybeNanExt};
 pub use quantile::{interpolate, Quantile1dExt, QuantileExt};
@@ -44,6 +47,8 @@ pub use sort::Sort1dExt;
 pub use summary_statistics::SummaryStatisticsExt;
 
 mod correlation;
+mod entropy;
+pub mod errors;
 pub mod histogram;
 mod maybe_nan;
 mod quantile;