From 76986a9bcfde8eff020b78372c269a79e24a440e Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 22 Jan 2019 08:55:00 +0000 Subject: [PATCH 01/47] Add entropy trait --- src/entropy.rs | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 ++ 2 files changed, 56 insertions(+) create mode 100644 src/entropy.rs diff --git a/src/entropy.rs b/src/entropy.rs new file mode 100644 index 00000000..3fa40fd0 --- /dev/null +++ b/src/entropy.rs @@ -0,0 +1,54 @@ +//! Summary statistics (e.g. mean, variance, etc.). +use ndarray::{ArrayBase, Data, Dimension}; +use num_traits::{FromPrimitive, Float, Zero}; +use std::ops::{Add, Div}; + +/// Extension trait for `ArrayBase` providing methods +/// to compute information theory quantities +/// (e.g. entropy, Kullback–Leibler divergence, etc.). +pub trait EntropyExt + where + S: Data, + D: Dimension, +{ + /// Computes the [entropy] *S* of the array values, defined as + /// + /// ```text + /// n + /// S = - ∑ xᵢ ln(xᵢ) + /// i=1 + /// ``` + /// + /// If the array is empty, `None` is returned. + /// + /// **Panics** if any element in the array is negative. + /// + /// ## Remarks + /// + /// The entropy is a measure used in [Information Theory] + /// to describe a probability distribution: it only make sense + /// when the array values sum to 1, with each entry between + /// 0 and 1 (extremes included). + /// + /// By definition, *xᵢ ln(xᵢ)* is set to 0 if *xᵢ* is 0. + /// + /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory) + /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory + fn entropy(&self) -> Option + where + A: Float + FromPrimitive; +} + + +impl EntropyExt for ArrayBase + where + S: Data, + D: Dimension, +{ + fn entropy(&self) -> Option + where + A: Float + FromPrimitive + { + unimplemented!() + } +} diff --git a/src/lib.rs b/src/lib.rs index 5d764a67..1ecedf27 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,10 +43,12 @@ pub use sort::Sort1dExt; pub use correlation::CorrelationExt; pub use histogram::HistogramExt; pub use summary_statistics::SummaryStatisticsExt; +pub use entropy::EntropyExt; mod maybe_nan; mod quantile; mod sort; mod correlation; +mod entropy; mod summary_statistics; pub mod histogram; From 139585939c502bc86967a498c1583564d54a1f9e Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 22 Jan 2019 08:55:26 +0000 Subject: [PATCH 02/47] Remove unnecessary imports --- src/entropy.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 3fa40fd0..78a54f17 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -1,7 +1,6 @@ //! Summary statistics (e.g. mean, variance, etc.). use ndarray::{ArrayBase, Data, Dimension}; -use num_traits::{FromPrimitive, Float, Zero}; -use std::ops::{Add, Div}; +use num_traits::{FromPrimitive, Float}; /// Extension trait for `ArrayBase` providing methods /// to compute information theory quantities From b741b3d68a9a317eecb1aa0f2c07b63501d031fb Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 22 Jan 2019 09:11:30 +0000 Subject: [PATCH 03/47] Implemented entropy --- src/entropy.rs | 61 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 57 insertions(+), 4 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 78a54f17..e31593b3 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -1,6 +1,6 @@ //! Summary statistics (e.g. mean, variance, etc.). 
use ndarray::{ArrayBase, Data, Dimension}; -use num_traits::{FromPrimitive, Float}; +use num_traits::Float; /// Extension trait for `ArrayBase` providing methods /// to compute information theory quantities @@ -35,7 +35,7 @@ pub trait EntropyExt /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory fn entropy(&self) -> Option where - A: Float + FromPrimitive; + A: Float; } @@ -46,8 +46,61 @@ impl EntropyExt for ArrayBase { fn entropy(&self) -> Option where - A: Float + FromPrimitive + A: Float { - unimplemented!() + if self.len() == 0 { + None + } else { + let entropy = self.map( + |x| { + if *x == A::zero() { + A::zero() + } else { + *x * x.ln() + } + } + ).sum(); + Some(entropy) + } + } +} + +#[cfg(test)] +mod tests { + use super::EntropyExt; + use std::f64; + use approx::abs_diff_eq; + use ndarray::{array, Array1}; + + #[test] + fn test_entropy_with_nan_values() { + let a = array![f64::NAN, 1.]; + assert!(a.entropy().unwrap().is_nan()); + } + + #[test] + fn test_entropy_with_empty_array_of_floats() { + let a: Array1 = array![]; + assert!(a.entropy().is_none()); + } + + #[test] + fn test_entropy_with_array_of_floats() { + let a: Array1 = array![ + 0.70850547, 0.32496524, 0.4512601 , 0.19634812, 0.52430767, + 0.77200268, 0.30947147, 0.01089479, 0.04280482, 0.18548377, + 0.7886273 , 0.23487162, 0.54353668, 0.43455954, 0.8224537 , + 0.60031256, 0.69876954, 0.95906628, 0.20305543, 0.85397668, + 0.50892232, 0.65533253, 0.64384601, 0.86091271, 0.31692328, + 0.45576697, 0.66077109, 0.23469551, 0.42808089, 0.20234666, + 0.14972765, 0.34240363, 0.59198436, 0.05764641, 0.10238259, + 0.06544647, 0.74466137, 0.58182716, 0.5583189 , 0.36093108, + 0.60681015, 0.45062613, 0.83282631, 0.77114486, 0.35229367, + 0.36383337, 0.78485847, 0.56853643, 0.80326787, 0.04409981, + ]; + // Computed using scipy.stats.entropy + let expected_entropy = 3.7371557453896727; + + abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = f64::EPSILON); } } From 7b4f6d5a3c290105994b07336da8ba5bfa1186e3 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 22 Jan 2019 09:12:48 +0000 Subject: [PATCH 04/47] Return entropy with reversed sign, as per definition --- src/entropy.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index e31593b3..2a99a8a2 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -60,7 +60,7 @@ impl EntropyExt for ArrayBase } } ).sum(); - Some(entropy) + Some(-entropy) } } } @@ -98,8 +98,8 @@ mod tests { 0.60681015, 0.45062613, 0.83282631, 0.77114486, 0.35229367, 0.36383337, 0.78485847, 0.56853643, 0.80326787, 0.04409981, ]; - // Computed using scipy.stats.entropy - let expected_entropy = 3.7371557453896727; + // Computed using scipy.stats.entropy, sign has been reversed + let expected_entropy = -3.7371557453896727; abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = f64::EPSILON); } From cc221e218b31425faf125eec335c22908d9f7f39 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 22 Jan 2019 09:20:25 +0000 Subject: [PATCH 05/47] Fixed tests --- src/entropy.rs | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 2a99a8a2..fe4b796d 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -69,7 +69,7 @@ impl EntropyExt for ArrayBase mod tests { use super::EntropyExt; use std::f64; - use approx::abs_diff_eq; + use approx::assert_abs_diff_eq; use ndarray::{array, Array1}; #[test] @@ -86,21 +86,22 @@ mod tests { #[test] fn 
test_entropy_with_array_of_floats() { + // Array of probability values - normalized and positive. let a: Array1 = array![ - 0.70850547, 0.32496524, 0.4512601 , 0.19634812, 0.52430767, - 0.77200268, 0.30947147, 0.01089479, 0.04280482, 0.18548377, - 0.7886273 , 0.23487162, 0.54353668, 0.43455954, 0.8224537 , - 0.60031256, 0.69876954, 0.95906628, 0.20305543, 0.85397668, - 0.50892232, 0.65533253, 0.64384601, 0.86091271, 0.31692328, - 0.45576697, 0.66077109, 0.23469551, 0.42808089, 0.20234666, - 0.14972765, 0.34240363, 0.59198436, 0.05764641, 0.10238259, - 0.06544647, 0.74466137, 0.58182716, 0.5583189 , 0.36093108, - 0.60681015, 0.45062613, 0.83282631, 0.77114486, 0.35229367, - 0.36383337, 0.78485847, 0.56853643, 0.80326787, 0.04409981, + 0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, + 0.03368976, 0.00065396, 0.02906146, 0.00063687, 0.01597306, + 0.00787625, 0.00208243, 0.01450896, 0.01803418, 0.02055336, + 0.03029759, 0.03323628, 0.01218822, 0.0001873 , 0.01734179, + 0.03521668, 0.02564429, 0.02421992, 0.03540229, 0.03497635, + 0.03582331, 0.026558 , 0.02460495, 0.02437716, 0.01212838, + 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588, + 0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, + 0.00976694, 0.02864634, 0.00802828, 0.03464088, 0.03557152, + 0.01398894, 0.01831756, 0.0227171 , 0.00736204, 0.01866295, ]; - // Computed using scipy.stats.entropy, sign has been reversed - let expected_entropy = -3.7371557453896727; + // Computed using scipy.stats.entropy + let expected_entropy = 3.721606155686918; - abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = f64::EPSILON); + assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6); } } From 747381573b76516f840cfa0c9e9448615a121141 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 23 Jan 2019 09:04:37 +0000 Subject: [PATCH 06/47] Added signature for cross entropy --- src/entropy.rs | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/entropy.rs b/src/entropy.rs index fe4b796d..da718978 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -36,6 +36,40 @@ pub trait EntropyExt fn entropy(&self) -> Option where A: Float; + + /// Computes the [cross entropy] *H(p,q)* between two arrays, + /// where `self`=*p*. + /// + /// The cross entropy is defined as: + /// + /// ```text + /// n + /// H(p,q) = - ∑ pᵢ ln(qᵢ) + /// i=1 + /// ``` + /// + /// If the arrays are empty or their lenghts are not equal, `None` is returned. + /// + /// **Panics** if any element in *q* is negative. + /// + /// ## Remarks + /// + /// The cross entropy is a measure used in [Information Theory] + /// to describe the relationship between two probability distribution: it only make sense + /// when each array sums to 1 with entries between 0 and 1 (extremes included). + /// + /// The cross entropy is often used as an objective/loss function in + /// [optimization problems], including [machine learning]. + /// + /// By definition, *pᵢ ln(qᵢ)* is set to 0 if *pᵢ* is 0. 
+ /// + /// [cross entropy]: https://en.wikipedia.org/wiki/Cross-entropy + /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory + /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method + /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression + fn cross_entropy(&self, q: &Self) -> Option + where + A: Float; } @@ -63,6 +97,13 @@ impl EntropyExt for ArrayBase Some(-entropy) } } + + fn cross_entropy(&self, q: &Self) -> Option + where + A: Float + { + unimplemented!() + } } #[cfg(test)] From ea3e81ff426a00ea86ac01fc1264b9e6ca766eb9 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 23 Jan 2019 09:05:08 +0000 Subject: [PATCH 07/47] Fixed typo. --- src/entropy.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/entropy.rs b/src/entropy.rs index da718978..b4ea2958 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -48,7 +48,7 @@ pub trait EntropyExt /// i=1 /// ``` /// - /// If the arrays are empty or their lenghts are not equal, `None` is returned. + /// If the arrays are empty or their lengths are not equal, `None` is returned. /// /// **Panics** if any element in *q* is negative. /// From 2998395f44be62b60e92d3eee047fcd2a54a52d9 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sun, 27 Jan 2019 16:28:37 +0000 Subject: [PATCH 08/47] Implemented cross_entropy --- src/entropy.rs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index b4ea2958..f7eae8a5 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -1,5 +1,5 @@ //! Summary statistics (e.g. mean, variance, etc.). -use ndarray::{ArrayBase, Data, Dimension}; +use ndarray::{Array1, ArrayBase, Data, Dimension}; use num_traits::Float; /// Extension trait for `ArrayBase` providing methods @@ -102,7 +102,20 @@ impl EntropyExt for ArrayBase where A: Float { - unimplemented!() + if (self.len() == 0) | (self.len() != q.len()) { + None + } else { + let cross_entropy: A = self.iter().zip(q.iter()).map( + |(p, q)| { + if *p == A::zero() { + A::zero() + } else { + *p * q.ln() + } + } + ).collect::>().sum(); + Some(-cross_entropy) + } } } From bfc3d222bd18f611485e925e2c91d5e146300f30 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sun, 27 Jan 2019 16:31:43 +0000 Subject: [PATCH 09/47] Added tests --- src/entropy.rs | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/entropy.rs b/src/entropy.rs index f7eae8a5..f8fb554d 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -158,4 +158,35 @@ mod tests { assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6); } + + #[test] + fn test_cross_entropy_with_nan_values() { + let a = array![f64::NAN, 1.]; + let b = array![2., 1.]; + assert!(a.cross_entropy(&b).unwrap().is_nan()); + assert!(b.cross_entropy(&a).unwrap().is_nan()); + } + + #[test] + fn test_cross_entropy_with_dimension_mismatch() { + let p = array![f64::NAN, 1.]; + let q = array![2., 1., 5.]; + assert!(q.cross_entropy(&p).is_none()); + assert!(p.cross_entropy(&q).is_none()); + } + + #[test] + fn test_cross_entropy_with_empty_array_of_floats() { + let p: Array1 = array![]; + let q: Array1 = array![]; + assert!(p.cross_entropy(&q).is_none()); + } + + #[test] + #[should_panic] + fn test_cross_entropy_with_negative_qs() { + let p = array![2., 1., 5.]; + let q = array![2., -1., 5.]; + assert!(p.cross_entropy(&q).is_none()); + } } From a69ed91bbd78a617c1fa44c98aaa8c65f3a2b92d Mon Sep 17 00:00:00 
2001 From: LukeMathWalker Date: Sun, 27 Jan 2019 16:42:29 +0000 Subject: [PATCH 10/47] Refined panic condition --- src/entropy.rs | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index f8fb554d..bbfd353e 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -50,7 +50,8 @@ pub trait EntropyExt /// /// If the arrays are empty or their lengths are not equal, `None` is returned. /// - /// **Panics** if any element in *q* is negative. + /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number + /// is a panic cause for `A`. /// /// ## Remarks /// @@ -124,6 +125,7 @@ mod tests { use super::EntropyExt; use std::f64; use approx::assert_abs_diff_eq; + use noisy_float::types::n64; use ndarray::{array, Array1}; #[test] @@ -183,10 +185,25 @@ mod tests { } #[test] - #[should_panic] fn test_cross_entropy_with_negative_qs() { - let p = array![2., 1., 5.]; - let q = array![2., -1., 5.]; - assert!(p.cross_entropy(&q).is_none()); + let p = array![1.]; + let q = array![-1.]; + let cross_entropy: f64 = p.cross_entropy(&q).unwrap(); + assert!(cross_entropy.is_nan()); + } + + #[test] + #[should_panic] + fn test_cross_entropy_with_noisy_negative_qs() { + let p = array![n64(1.)]; + let q = array![n64(-1.)]; + p.cross_entropy(&q); + } + + #[test] + fn test_cross_entropy_with_zeroes_p() { + let p = array![0., 0.]; + let q = array![0.5, 0.5]; + assert_eq!(p.cross_entropy(&q).unwrap(), 0.); } } From 3d7929bcc518557052985f87e1710023a98d0b04 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sun, 27 Jan 2019 16:52:29 +0000 Subject: [PATCH 11/47] Added test vs SciPy --- src/entropy.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/entropy.rs b/src/entropy.rs index bbfd353e..3ea43c6d 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -206,4 +206,27 @@ mod tests { let q = array![0.5, 0.5]; assert_eq!(p.cross_entropy(&q).unwrap(), 0.); } + + #[test] + fn test_cross_entropy() { + // Arrays of probability values - normalized and positive. 
+ let p: Array1 = array![ + 0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, + 0.05782189, 0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, + 0.01959158, 0.05020174, 0.03801479, 0.00092234, 0.08515856, 0.00580683, + 0.0156542, 0.0860375, 0.0724246, 0.00727477, 0.01004402, 0.01854399, + 0.03504082, + ]; + let q: Array1 = array![ + 0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, + 0.05604812, 0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, + 0.0625685, 0.07381292, 0.05489067, 0.01385491, 0.03639174, 0.00511611, + 0.05700415, 0.05183825, 0.06703064, 0.01813342, 0.0007763, 0.0735472, + 0.05857833, + ]; + // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q) + let expected_cross_entropy = 3.385347705020779; + + assert_abs_diff_eq!(p.cross_entropy(&q).unwrap(), expected_cross_entropy, epsilon = 1e-6); + } } From 27dbd00086de62a4d30d3738fe267db71a3923b7 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sun, 27 Jan 2019 16:59:11 +0000 Subject: [PATCH 12/47] Added test vs SciPy --- src/entropy.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 3ea43c6d..6b5b785f 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -203,10 +203,17 @@ mod tests { #[test] fn test_cross_entropy_with_zeroes_p() { let p = array![0., 0.]; - let q = array![0.5, 0.5]; + let q = array![0., 0.5]; assert_eq!(p.cross_entropy(&q).unwrap(), 0.); } + #[test] + fn test_cross_entropy_with_zeroes_q() { + let p = array![0.5, 0.5]; + let q = array![0.5, 0.]; + assert_eq!(p.cross_entropy(&q).unwrap(), f64::INFINITY); + } + #[test] fn test_cross_entropy() { // Arrays of probability values - normalized and positive. @@ -218,11 +225,11 @@ mod tests { 0.03504082, ]; let q: Array1 = array![ - 0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, - 0.05604812, 0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, - 0.0625685, 0.07381292, 0.05489067, 0.01385491, 0.03639174, 0.00511611, - 0.05700415, 0.05183825, 0.06703064, 0.01813342, 0.0007763, 0.0735472, - 0.05857833, + 0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, + 0.05604812, 0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, + 0.0625685, 0.07381292, 0.05489067, 0.01385491, 0.03639174, 0.00511611, + 0.05700415, 0.05183825, 0.06703064, 0.01813342, 0.0007763, 0.0735472, + 0.05857833, ]; // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q) let expected_cross_entropy = 3.385347705020779; From 873871b701f3eea9913d8e00268921342b6874f2 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sun, 27 Jan 2019 17:06:10 +0000 Subject: [PATCH 13/47] Added KL divergence --- src/entropy.rs | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/entropy.rs b/src/entropy.rs index 6b5b785f..1a8366f5 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -37,6 +37,36 @@ pub trait EntropyExt where A: Float; + /// Computes the [Kullback-Leibler divergence] *Dₖₗ(p,q)* between two arrays, + /// where `self`=*p*. + /// + /// The Kullback-Leibler divergence is defined as: + /// + /// ```text + /// n + /// Dₖₗ(p,q) = - ∑ pᵢ ln(qᵢ/pᵢ) + /// i=1 + /// ``` + /// + /// If the arrays are empty or their lengths are not equal, `None` is returned. + /// + /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number + /// is a panic cause for `A`. 
+ /// + /// ## Remarks + /// + /// The Kullback-Leibler divergence is a measure used in [Information Theory] + /// to describe the relationship between two probability distribution: it only make sense + /// when each array sums to 1 with entries between 0 and 1 (extremes included). + /// + /// By definition, *pᵢ ln(qᵢ/pᵢ)* is set to 0 if *pᵢ* is 0. + /// + /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence + /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory + fn kullback_leibler_divergence(&self, q: &Self) -> Option + where + A: Float; + /// Computes the [cross entropy] *H(p,q)* between two arrays, /// where `self`=*p*. /// @@ -99,6 +129,26 @@ impl EntropyExt for ArrayBase } } + fn kullback_leibler_divergence(&self, q: &Self) -> Option + where + A: Float + { + if (self.len() == 0) | (self.len() != q.len()) { + None + } else { + let kl_divergence: A = self.iter().zip(q.iter()).map( + |(p, q)| { + if *p == A::zero() { + A::zero() + } else { + *p * (*q / *p).ln() + } + } + ).collect::>().sum(); + Some(-kl_divergence) + } + } + fn cross_entropy(&self, q: &Self) -> Option where A: Float From dc85e9abe830c4489098f7cecf468f172e134ba9 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sun, 27 Jan 2019 17:12:32 +0000 Subject: [PATCH 14/47] Added KL tests --- src/entropy.rs | 62 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 1a8366f5..d86b1195 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -212,34 +212,41 @@ mod tests { } #[test] - fn test_cross_entropy_with_nan_values() { + fn test_cross_entropy_and_kl_with_nan_values() { let a = array![f64::NAN, 1.]; let b = array![2., 1.]; assert!(a.cross_entropy(&b).unwrap().is_nan()); assert!(b.cross_entropy(&a).unwrap().is_nan()); + assert!(a.kullback_leibler_divergence(&b).unwrap().is_nan()); + assert!(b.kullback_leibler_divergence(&a).unwrap().is_nan()); } #[test] - fn test_cross_entropy_with_dimension_mismatch() { + fn test_cross_entropy_and_kl_with_dimension_mismatch() { let p = array![f64::NAN, 1.]; let q = array![2., 1., 5.]; assert!(q.cross_entropy(&p).is_none()); assert!(p.cross_entropy(&q).is_none()); + assert!(q.kullback_leibler_divergence(&p).is_none()); + assert!(p.kullback_leibler_divergence(&q).is_none()); } #[test] - fn test_cross_entropy_with_empty_array_of_floats() { + fn test_cross_entropy_and_kl_with_empty_array_of_floats() { let p: Array1 = array![]; let q: Array1 = array![]; assert!(p.cross_entropy(&q).is_none()); + assert!(p.kullback_leibler_divergence(&q).is_none()); } #[test] - fn test_cross_entropy_with_negative_qs() { + fn test_cross_entropy_and_kl_with_negative_qs() { let p = array![1.]; let q = array![-1.]; let cross_entropy: f64 = p.cross_entropy(&q).unwrap(); + let kl_divergence: f64 = p.kullback_leibler_divergence(&q).unwrap(); assert!(cross_entropy.is_nan()); + assert!(kl_divergence.is_nan()); } #[test] @@ -251,17 +258,27 @@ mod tests { } #[test] - fn test_cross_entropy_with_zeroes_p() { + #[should_panic] + fn test_kl_with_noisy_negative_qs() { + let p = array![n64(1.)]; + let q = array![n64(-1.)]; + p.kullback_leibler_divergence(&q); + } + + #[test] + fn test_cross_entropy_and_kl_with_zeroes_p() { let p = array![0., 0.]; let q = array![0., 0.5]; assert_eq!(p.cross_entropy(&q).unwrap(), 0.); + assert_eq!(p.kullback_leibler_divergence(&q).unwrap(), 0.); } #[test] - fn test_cross_entropy_with_zeroes_q() { + fn 
test_cross_entropy_and_kl_with_zeroes_q() { let p = array![0.5, 0.5]; let q = array![0.5, 0.]; assert_eq!(p.cross_entropy(&q).unwrap(), f64::INFINITY); + assert_eq!(p.kullback_leibler_divergence(&q).unwrap(), f64::INFINITY); } #[test] @@ -286,4 +303,37 @@ mod tests { assert_abs_diff_eq!(p.cross_entropy(&q).unwrap(), expected_cross_entropy, epsilon = 1e-6); } + + #[test] + fn test_kl() { + // Arrays of probability values - normalized and positive. + let p: Array1 = array![ + 0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, + 0.02183501, 0.00137516, 0.02213802, 0.02745017, 0.02163975, + 0.0324602 , 0.03622766, 0.00782343, 0.00222498, 0.03028156, + 0.02346124, 0.00071105, 0.00794496, 0.0127609 , 0.02899124, + 0.01281487, 0.0230803 , 0.01531864, 0.00518158, 0.02233383, + 0.0220279 , 0.03196097, 0.03710063, 0.01817856, 0.03524661, + 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151, + 0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375 , + 0.01988341, 0.02621831, 0.03564644, 0.01389121, 0.03151622, + 0.03195532, 0.00717521, 0.03547256, 0.00371394, 0.01108706, + ]; + let q: Array1 = array![ + 0.02038386, 0.03143914, 0.02630206, 0.0171595 , 0.0067072 , + 0.00911324, 0.02635717, 0.01269113, 0.0302361 , 0.02243133, + 0.01902902, 0.01297185, 0.02118908, 0.03309548, 0.01266687, + 0.0184529 , 0.01830936, 0.03430437, 0.02898924, 0.02238251, + 0.0139771 , 0.01879774, 0.02396583, 0.03019978, 0.01421278, + 0.02078981, 0.03542451, 0.02887438, 0.01261783, 0.01014241, + 0.03263407, 0.0095969 , 0.01923903, 0.0051315 , 0.00924686, + 0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, + 0.03315135, 0.02099325, 0.03251755, 0.00337555, 0.03432165, + 0.01763753, 0.02038337, 0.01923023, 0.01438769, 0.02082707, + ]; + // Computed using scipy.stats.entropy(p, q) + let expected_kl = 0.3555862567800096; + + assert_abs_diff_eq!(p.kullback_leibler_divergence(&q).unwrap(), expected_kl, epsilon = 1e-6); + } } From ca957888b9aa23d8352bae839125132087d4d599 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sun, 27 Jan 2019 17:14:01 +0000 Subject: [PATCH 15/47] Renamed to kl_divergence --- src/entropy.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index d86b1195..cfecd50a 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -63,7 +63,7 @@ pub trait EntropyExt /// /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory - fn kullback_leibler_divergence(&self, q: &Self) -> Option + fn kl_divergence(&self, q: &Self) -> Option where A: Float; @@ -129,7 +129,7 @@ impl EntropyExt for ArrayBase } } - fn kullback_leibler_divergence(&self, q: &Self) -> Option + fn kl_divergence(&self, q: &Self) -> Option where A: Float { @@ -217,8 +217,8 @@ mod tests { let b = array![2., 1.]; assert!(a.cross_entropy(&b).unwrap().is_nan()); assert!(b.cross_entropy(&a).unwrap().is_nan()); - assert!(a.kullback_leibler_divergence(&b).unwrap().is_nan()); - assert!(b.kullback_leibler_divergence(&a).unwrap().is_nan()); + assert!(a.kl_divergence(&b).unwrap().is_nan()); + assert!(b.kl_divergence(&a).unwrap().is_nan()); } #[test] @@ -227,8 +227,8 @@ mod tests { let q = array![2., 1., 5.]; assert!(q.cross_entropy(&p).is_none()); assert!(p.cross_entropy(&q).is_none()); - assert!(q.kullback_leibler_divergence(&p).is_none()); - assert!(p.kullback_leibler_divergence(&q).is_none()); + assert!(q.kl_divergence(&p).is_none()); + 
assert!(p.kl_divergence(&q).is_none()); } #[test] @@ -236,7 +236,7 @@ mod tests { let p: Array1 = array![]; let q: Array1 = array![]; assert!(p.cross_entropy(&q).is_none()); - assert!(p.kullback_leibler_divergence(&q).is_none()); + assert!(p.kl_divergence(&q).is_none()); } #[test] @@ -244,7 +244,7 @@ mod tests { let p = array![1.]; let q = array![-1.]; let cross_entropy: f64 = p.cross_entropy(&q).unwrap(); - let kl_divergence: f64 = p.kullback_leibler_divergence(&q).unwrap(); + let kl_divergence: f64 = p.kl_divergence(&q).unwrap(); assert!(cross_entropy.is_nan()); assert!(kl_divergence.is_nan()); } @@ -262,7 +262,7 @@ mod tests { fn test_kl_with_noisy_negative_qs() { let p = array![n64(1.)]; let q = array![n64(-1.)]; - p.kullback_leibler_divergence(&q); + p.kl_divergence(&q); } #[test] @@ -270,7 +270,7 @@ mod tests { let p = array![0., 0.]; let q = array![0., 0.5]; assert_eq!(p.cross_entropy(&q).unwrap(), 0.); - assert_eq!(p.kullback_leibler_divergence(&q).unwrap(), 0.); + assert_eq!(p.kl_divergence(&q).unwrap(), 0.); } #[test] @@ -278,7 +278,7 @@ mod tests { let p = array![0.5, 0.5]; let q = array![0.5, 0.]; assert_eq!(p.cross_entropy(&q).unwrap(), f64::INFINITY); - assert_eq!(p.kullback_leibler_divergence(&q).unwrap(), f64::INFINITY); + assert_eq!(p.kl_divergence(&q).unwrap(), f64::INFINITY); } #[test] @@ -334,6 +334,6 @@ mod tests { // Computed using scipy.stats.entropy(p, q) let expected_kl = 0.3555862567800096; - assert_abs_diff_eq!(p.kullback_leibler_divergence(&q).unwrap(), expected_kl, epsilon = 1e-6); + assert_abs_diff_eq!(p.kl_divergence(&q).unwrap(), expected_kl, epsilon = 1e-6); } } From d21a0bb6a23aadf354b671c0e22d04bdd9d59e73 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 25 Feb 2019 23:09:56 +0000 Subject: [PATCH 16/47] Update src/entropy.rs Co-Authored-By: LukeMathWalker --- src/entropy.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/entropy.rs b/src/entropy.rs index cfecd50a..9f892089 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -86,7 +86,7 @@ pub trait EntropyExt /// ## Remarks /// /// The cross entropy is a measure used in [Information Theory] - /// to describe the relationship between two probability distribution: it only make sense + /// to describe the relationship between two probability distributions: it only makes sense /// when each array sums to 1 with entries between 0 and 1 (extremes included). /// /// The cross entropy is often used as an objective/loss function in From c1274281e93eeb312f8fb8d88a83a4f0fed6e489 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:30:49 +0000 Subject: [PATCH 17/47] Improved docs on behaviour with not normalised arrays --- src/entropy.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/entropy.rs b/src/entropy.rs index cfecd50a..ace99171 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -28,6 +28,9 @@ pub trait EntropyExt /// to describe a probability distribution: it only make sense /// when the array values sum to 1, with each entry between /// 0 and 1 (extremes included). + /// The array values are **not** normalised by this function before + /// computing the entropy to avoid introducing potentially + /// unnecessary numerical errors (e.g. if the array were to be already normalised). /// /// By definition, *xᵢ ln(xᵢ)* is set to 0 if *xᵢ* is 0. 
/// From 0106d65b55dd25ffa7c7002658ff7858aedd0160 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:31:18 +0000 Subject: [PATCH 18/47] Improved docs on behaviour with not normalised arrays --- src/entropy.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/entropy.rs b/src/entropy.rs index ace99171..a9ffad35 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -28,6 +28,7 @@ pub trait EntropyExt /// to describe a probability distribution: it only make sense /// when the array values sum to 1, with each entry between /// 0 and 1 (extremes included). + /// /// The array values are **not** normalised by this function before /// computing the entropy to avoid introducing potentially /// unnecessary numerical errors (e.g. if the array were to be already normalised). @@ -62,6 +63,10 @@ pub trait EntropyExt /// to describe the relationship between two probability distribution: it only make sense /// when each array sums to 1 with entries between 0 and 1 (extremes included). /// + /// The array values are **not** normalised by this function before + /// computing the entropy to avoid introducing potentially + /// unnecessary numerical errors (e.g. if the array were to be already normalised). + /// /// By definition, *pᵢ ln(qᵢ/pᵢ)* is set to 0 if *pᵢ* is 0. /// /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence From b28f4615e783f894533f4ceebdc50b166c0bab02 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:32:08 +0000 Subject: [PATCH 19/47] Use mapv --- src/entropy.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index a9ffad35..07bd250e 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -124,12 +124,12 @@ impl EntropyExt for ArrayBase if self.len() == 0 { None } else { - let entropy = self.map( + let entropy = self.mapv( |x| { - if *x == A::zero() { + if x == A::zero() { A::zero() } else { - *x * x.ln() + x * x.ln() } } ).sum(); From ddf358bdf43a73fc887e5cc53c0edfb7db542a28 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:33:07 +0000 Subject: [PATCH 20/47] Styling on closures (avoid dereferencing) --- src/entropy.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 07bd250e..12cb01a4 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -145,11 +145,11 @@ impl EntropyExt for ArrayBase None } else { let kl_divergence: A = self.iter().zip(q.iter()).map( - |(p, q)| { - if *p == A::zero() { + |(&p, &q)| { + if p == A::zero() { A::zero() } else { - *p * (*q / *p).ln() + p * (q / p).ln() } } ).collect::>().sum(); From afdcf0635f0aaf2bc72099731001ece79dc0030e Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:39:56 +0000 Subject: [PATCH 21/47] Allow different data ownership to interact in kl_divergence --- src/entropy.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 12cb01a4..81bb7736 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -137,9 +137,10 @@ impl EntropyExt for ArrayBase } } - fn kl_divergence(&self, q: &Self) -> Option - where - A: Float + fn kl_divergence(&self, q: &ArrayBase) -> Option + where + A: Float, + S2: Data, { if (self.len() == 0) | (self.len() != q.len()) { None From 28b4efd75e859aa138681ab90386d286b32ab7ba Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:40:56 +0000 Subject: [PATCH 22/47] Allow different data ownership to interact in 
kl_divergence --- src/entropy.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 81bb7736..fb9cd837 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -71,9 +71,10 @@ pub trait EntropyExt /// /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory - fn kl_divergence(&self, q: &Self) -> Option - where - A: Float; + fn kl_divergence(&self, q: &ArrayBase) -> Option + where + S2: Data, + A: Float; /// Computes the [cross entropy] *H(p,q)* between two arrays, /// where `self`=*p*. From 8c04f9cd88a82caf4faa44b7064b60843d48500c Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:41:48 +0000 Subject: [PATCH 23/47] Allow different data ownership to interact in cross_entropy --- src/entropy.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index fb9cd837..223c96d8 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -107,8 +107,9 @@ pub trait EntropyExt /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression - fn cross_entropy(&self, q: &Self) -> Option + fn cross_entropy(&self, q: &ArrayBase) -> Option where + S2: Data, A: Float; } @@ -159,9 +160,10 @@ impl EntropyExt for ArrayBase } } - fn cross_entropy(&self, q: &Self) -> Option - where - A: Float + fn cross_entropy(&self, q: &ArrayBase) -> Option + where + S2: Data, + A: Float, { if (self.len() == 0) | (self.len() != q.len()) { None From 450cfb436087edabd33c4fc7ccd8854966afd77c Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:43:03 +0000 Subject: [PATCH 24/47] Add a test --- src/entropy.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 223c96d8..54f5710e 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -286,11 +286,11 @@ mod tests { } #[test] - fn test_cross_entropy_and_kl_with_zeroes_q() { + fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership() { let p = array![0.5, 0.5]; - let q = array![0.5, 0.]; - assert_eq!(p.cross_entropy(&q).unwrap(), f64::INFINITY); - assert_eq!(p.kl_divergence(&q).unwrap(), f64::INFINITY); + let mut q = array![0.5, 0.]; + assert_eq!(p.cross_entropy(&q.view_mut()).unwrap(), f64::INFINITY); + assert_eq!(p.kl_divergence(&q.view_mut()).unwrap(), f64::INFINITY); } #[test] From 5d45bdf6f6c1941b4e933c40c2b9cee9a22b47dc Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:43:34 +0000 Subject: [PATCH 25/47] Doc improvement --- src/entropy.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/entropy.rs b/src/entropy.rs index 54f5710e..113534e9 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -98,6 +98,10 @@ pub trait EntropyExt /// to describe the relationship between two probability distribution: it only make sense /// when each array sums to 1 with entries between 0 and 1 (extremes included). /// + /// The array values are **not** normalised by this function before + /// computing the entropy to avoid introducing potentially + /// unnecessary numerical errors (e.g. if the array were to be already normalised). 
+ /// /// The cross entropy is often used as an objective/loss function in /// [optimization problems], including [machine learning]. /// From 5c72f5520bd6783aa05fb8c5803516154eb4f70f Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:50:54 +0000 Subject: [PATCH 26/47] Check the whole shape --- src/entropy.rs | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 113534e9..1f537e04 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -148,7 +148,7 @@ impl EntropyExt for ArrayBase A: Float, S2: Data, { - if (self.len() == 0) | (self.len() != q.len()) { + if (self.len() == 0) | (self.shape() != q.shape()) { None } else { let kl_divergence: A = self.iter().zip(q.iter()).map( @@ -169,7 +169,7 @@ impl EntropyExt for ArrayBase S2: Data, A: Float, { - if (self.len() == 0) | (self.len() != q.len()) { + if (self.len() == 0) | (self.shape() != q.shape()) { None } else { let cross_entropy: A = self.iter().zip(q.iter()).map( @@ -238,7 +238,7 @@ mod tests { } #[test] - fn test_cross_entropy_and_kl_with_dimension_mismatch() { + fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() { let p = array![f64::NAN, 1.]; let q = array![2., 1., 5.]; assert!(q.cross_entropy(&p).is_none()); @@ -247,6 +247,25 @@ mod tests { assert!(p.kl_divergence(&q).is_none()); } + #[test] + fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() { + // p: 3x2, 6 elements + let p = array![ + [f64::NAN, 1.], + [6., 7.], + [10., 20.] + ]; + // q: 2x3, 6 elements + let q = array![ + [2., 1., 5.], + [1., 1., 7.], + ]; + assert!(q.cross_entropy(&p).is_none()); + assert!(p.cross_entropy(&q).is_none()); + assert!(q.kl_divergence(&p).is_none()); + assert!(p.kl_divergence(&q).is_none()); + } + #[test] fn test_cross_entropy_and_kl_with_empty_array_of_floats() { let p: Array1 = array![]; From bb3876350ff62f9a89b015efe088faafc358eebb Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 08:54:04 +0000 Subject: [PATCH 27/47] Fix docs --- src/entropy.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index f2db4516..4976c6ed 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -52,7 +52,7 @@ pub trait EntropyExt /// i=1 /// ``` /// - /// If the arrays are empty or their lengths are not equal, `None` is returned. + /// If the arrays are empty or their shapes are not identical, `None` is returned. /// /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number /// is a panic cause for `A`. @@ -87,7 +87,7 @@ pub trait EntropyExt /// i=1 /// ``` /// - /// If the arrays are empty or their lengths are not equal, `None` is returned. + /// If the arrays are empty or their shapes are not identical, `None` is returned. /// /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number /// is a panic cause for `A`. From c470a3a479a1c853a35259c3668b4f61e0ab17f4 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 09:02:12 +0000 Subject: [PATCH 28/47] Broken usage of Zip --- src/entropy.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 4976c6ed..9ed955a6 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -1,5 +1,5 @@ //! Summary statistics (e.g. mean, variance, etc.). 
-use ndarray::{Array1, ArrayBase, Data, Dimension}; +use ndarray::{Array1, ArrayBase, Data, Dimension, Zip}; use num_traits::Float; /// Extension trait for `ArrayBase` providing methods @@ -151,15 +151,20 @@ impl EntropyExt for ArrayBase if (self.len() == 0) | (self.shape() != q.shape()) { None } else { - let kl_divergence: A = self.iter().zip(q.iter()).map( - |(&p, &q)| { - if p == A::zero() { - A::zero() - } else { - p * (q / p).ln() + let mut temp = ArrayBase::zeros(self.shape()); + Zip::from(&mut temp) + .and(self) + .and(q) + .apply(|result, &p, &q| { + *result = { + if p == A::zero() { + A::zero() + } else { + p * (q / p).ln() + } } - } - ).collect::>().sum(); + }); + let kl_divergence = temp.sum(); Some(-kl_divergence) } } From e4be9b92ed05b73c0a2b1e51b106ef2c19aec22b Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 21:43:24 +0000 Subject: [PATCH 29/47] Fixed zip, mistery --- src/entropy.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 9ed955a6..c49d4702 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -1,5 +1,5 @@ //! Summary statistics (e.g. mean, variance, etc.). -use ndarray::{Array1, ArrayBase, Data, Dimension, Zip}; +use ndarray::{Array, Array1, ArrayBase, DataOwned, Data, Dimension, Zip}; use num_traits::Float; /// Extension trait for `ArrayBase` providing methods @@ -151,7 +151,7 @@ impl EntropyExt for ArrayBase if (self.len() == 0) | (self.shape() != q.shape()) { None } else { - let mut temp = ArrayBase::zeros(self.shape()); + let mut temp = Array::zeros(self.raw_dim()); Zip::from(&mut temp) .and(self) .and(q) From 57537c3d722b9442ff03fbe4373a74b56bc5b9fd Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Tue, 26 Feb 2019 21:49:51 +0000 Subject: [PATCH 30/47] Use Zip for cross_entropy --- src/entropy.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index c49d4702..1d2b0a31 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -1,5 +1,5 @@ //! Summary statistics (e.g. mean, variance, etc.). 
-use ndarray::{Array, Array1, ArrayBase, DataOwned, Data, Dimension, Zip}; +use ndarray::{Array, ArrayBase, Data, Dimension, Zip}; use num_traits::Float; /// Extension trait for `ArrayBase` providing methods @@ -177,15 +177,20 @@ impl EntropyExt for ArrayBase if (self.len() == 0) | (self.shape() != q.shape()) { None } else { - let cross_entropy: A = self.iter().zip(q.iter()).map( - |(p, q)| { - if *p == A::zero() { - A::zero() - } else { - *p * q.ln() + let mut temp = Array::zeros(self.raw_dim()); + Zip::from(&mut temp) + .and(self) + .and(q) + .apply(|result, &p, &q| { + *result = { + if p == A::zero() { + A::zero() + } else { + p * q.ln() + } } - } - ).collect::>().sum(); + }); + let cross_entropy = temp.sum(); Some(-cross_entropy) } } From 80198bce800dc1e83af1f347b9ea50632b678691 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Fri, 8 Mar 2019 16:56:01 +0000 Subject: [PATCH 31/47] Add failure crate as dependency --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index 57795ce8..f4a1fc5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ noisy_float = "0.1.8" num-traits = "0.2" rand = "0.6" itertools = { version = "0.7.0", default-features = false } +failure = "0.1.5" [dev-dependencies] quickcheck = "0.7" From 93371f87f1012ca66daee41632042f0d487634e0 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Fri, 8 Mar 2019 16:57:00 +0000 Subject: [PATCH 32/47] Errors module --- src/errors.rs | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/errors.rs diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 00000000..e69de29b From 5f6a0040ad752543ef41ed02989fb6f5a1985b35 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Fri, 8 Mar 2019 16:58:38 +0000 Subject: [PATCH 33/47] Use failure crate --- src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 1ecedf27..2681a2c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,8 @@ extern crate noisy_float; extern crate num_traits; extern crate rand; extern crate itertools; +#[macro_use] +extern crate failure; #[cfg(test)] extern crate ndarray_rand; From 42c3600b034a1de8ad02a356ce423337d0cc5c4d Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Fri, 8 Mar 2019 17:05:02 +0000 Subject: [PATCH 34/47] Add ShapeMismatch error --- src/errors.rs | 6 ++++++ src/lib.rs | 1 + 2 files changed, 7 insertions(+) diff --git a/src/errors.rs b/src/errors.rs index e69de29b..9de66224 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -0,0 +1,6 @@ +#[derive(Fail, Debug)] +#[fail(display = "Array shapes do not match: {:?} and {:?}.", first_shape, second_shape)] +pub struct ShapeMismatch { + first_shape: Vec, + second_shape: Vec, +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 2681a2c1..6bcfc729 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,4 +53,5 @@ mod sort; mod correlation; mod entropy; mod summary_statistics; +pub mod errors; pub mod histogram; From 05d5c666debab7ed1e8359ddf0e075287b3dba87 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Fri, 8 Mar 2019 17:14:16 +0000 Subject: [PATCH 35/47] Return Result --- src/entropy.rs | 93 ++++++++++++++++++++++++++++---------------------- src/errors.rs | 4 +-- 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 1d2b0a31..48b5d147 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -1,6 +1,7 @@ //! Summary statistics (e.g. mean, variance, etc.). 
use ndarray::{Array, ArrayBase, Data, Dimension, Zip}; use num_traits::Float; +use crate::errors::ShapeMismatch; /// Extension trait for `ArrayBase` providing methods /// to compute information theory quantities @@ -71,7 +72,7 @@ pub trait EntropyExt /// /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory - fn kl_divergence(&self, q: &ArrayBase) -> Option + fn kl_divergence(&self, q: &ArrayBase) -> Result, ShapeMismatch> where S2: Data, A: Float; @@ -111,7 +112,7 @@ pub trait EntropyExt /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression - fn cross_entropy(&self, q: &ArrayBase) -> Option + fn cross_entropy(&self, q: &ArrayBase) -> Result, ShapeMismatch> where S2: Data, A: Float; @@ -143,56 +144,68 @@ impl EntropyExt for ArrayBase } } - fn kl_divergence(&self, q: &ArrayBase) -> Option + fn kl_divergence(&self, q: &ArrayBase) -> Result, ShapeMismatch> where A: Float, S2: Data, { - if (self.len() == 0) | (self.shape() != q.shape()) { - None - } else { - let mut temp = Array::zeros(self.raw_dim()); - Zip::from(&mut temp) - .and(self) - .and(q) - .apply(|result, &p, &q| { - *result = { - if p == A::zero() { - A::zero() - } else { - p * (q / p).ln() - } - } - }); - let kl_divergence = temp.sum(); - Some(-kl_divergence) + if self.len() == 0 { + return Ok(None) + } + if self.shape() != q.shape() { + return Err(ShapeMismatch { + first_shape: self.shape().to_vec(), + second_shape: q.shape().to_vec() + }) } + + let mut temp = Array::zeros(self.raw_dim()); + Zip::from(&mut temp) + .and(self) + .and(q) + .apply(|result, &p, &q| { + *result = { + if p == A::zero() { + A::zero() + } else { + p * (q / p).ln() + } + } + }); + let kl_divergence = temp.sum(); + Ok(Some(-kl_divergence)) } - fn cross_entropy(&self, q: &ArrayBase) -> Option + fn cross_entropy(&self, q: &ArrayBase) -> Result, ShapeMismatch> where S2: Data, A: Float, { - if (self.len() == 0) | (self.shape() != q.shape()) { - None - } else { - let mut temp = Array::zeros(self.raw_dim()); - Zip::from(&mut temp) - .and(self) - .and(q) - .apply(|result, &p, &q| { - *result = { - if p == A::zero() { - A::zero() - } else { - p * q.ln() - } - } - }); - let cross_entropy = temp.sum(); - Some(-cross_entropy) + if self.len() == 0 { + return Ok(None) + } + if self.shape() != q.shape() { + return Err(ShapeMismatch { + first_shape: self.shape().to_vec(), + second_shape: q.shape().to_vec() + }) } + + let mut temp = Array::zeros(self.raw_dim()); + Zip::from(&mut temp) + .and(self) + .and(q) + .apply(|result, &p, &q| { + *result = { + if p == A::zero() { + A::zero() + } else { + p * q.ln() + } + } + }); + let cross_entropy = temp.sum(); + Ok(Some(-cross_entropy)) } } diff --git a/src/errors.rs b/src/errors.rs index 9de66224..a7e67bed 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,6 +1,6 @@ #[derive(Fail, Debug)] #[fail(display = "Array shapes do not match: {:?} and {:?}.", first_shape, second_shape)] pub struct ShapeMismatch { - first_shape: Vec, - second_shape: Vec, + pub first_shape: Vec, + pub second_shape: Vec, } \ No newline at end of file From 99a391ee503428196087a0ec92055f8c8cf86cb1 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Fri, 8 Mar 2019 17:21:27 +0000 Subject: [PATCH 36/47] Fix test suite --- 
src/entropy.rs | 70 ++++++++++++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 31 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 48b5d147..892a549d 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -216,6 +216,7 @@ mod tests { use approx::assert_abs_diff_eq; use noisy_float::types::n64; use ndarray::{array, Array1}; + use errors::ShapeMismatch; #[test] fn test_entropy_with_nan_values() { @@ -251,23 +252,24 @@ mod tests { } #[test] - fn test_cross_entropy_and_kl_with_nan_values() { + fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> { let a = array![f64::NAN, 1.]; let b = array![2., 1.]; - assert!(a.cross_entropy(&b).unwrap().is_nan()); - assert!(b.cross_entropy(&a).unwrap().is_nan()); - assert!(a.kl_divergence(&b).unwrap().is_nan()); - assert!(b.kl_divergence(&a).unwrap().is_nan()); + assert!(a.cross_entropy(&b)?.unwrap().is_nan()); + assert!(b.cross_entropy(&a)?.unwrap().is_nan()); + assert!(a.kl_divergence(&b)?.unwrap().is_nan()); + assert!(b.kl_divergence(&a)?.unwrap().is_nan()); + Ok(()) } #[test] fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() { let p = array![f64::NAN, 1.]; let q = array![2., 1., 5.]; - assert!(q.cross_entropy(&p).is_none()); - assert!(p.cross_entropy(&q).is_none()); - assert!(q.kl_divergence(&p).is_none()); - assert!(p.kl_divergence(&q).is_none()); + assert!(q.cross_entropy(&p).is_err()); + assert!(p.cross_entropy(&q).is_err()); + assert!(q.kl_divergence(&p).is_err()); + assert!(p.kl_divergence(&q).is_err()); } #[test] @@ -283,28 +285,30 @@ mod tests { [2., 1., 5.], [1., 1., 7.], ]; - assert!(q.cross_entropy(&p).is_none()); - assert!(p.cross_entropy(&q).is_none()); - assert!(q.kl_divergence(&p).is_none()); - assert!(p.kl_divergence(&q).is_none()); + assert!(q.cross_entropy(&p).is_err()); + assert!(p.cross_entropy(&q).is_err()); + assert!(q.kl_divergence(&p).is_err()); + assert!(p.kl_divergence(&q).is_err()); } #[test] - fn test_cross_entropy_and_kl_with_empty_array_of_floats() { + fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> { let p: Array1 = array![]; let q: Array1 = array![]; - assert!(p.cross_entropy(&q).is_none()); - assert!(p.kl_divergence(&q).is_none()); + assert!(p.cross_entropy(&q)?.is_none()); + assert!(p.kl_divergence(&q)?.is_none()); + Ok(()) } #[test] - fn test_cross_entropy_and_kl_with_negative_qs() { + fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> { let p = array![1.]; let q = array![-1.]; - let cross_entropy: f64 = p.cross_entropy(&q).unwrap(); - let kl_divergence: f64 = p.kl_divergence(&q).unwrap(); + let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap(); + let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap(); assert!(cross_entropy.is_nan()); assert!(kl_divergence.is_nan()); + Ok(()) } #[test] @@ -312,7 +316,7 @@ mod tests { fn test_cross_entropy_with_noisy_negative_qs() { let p = array![n64(1.)]; let q = array![n64(-1.)]; - p.cross_entropy(&q); + let _ = p.cross_entropy(&q); } #[test] @@ -320,27 +324,29 @@ mod tests { fn test_kl_with_noisy_negative_qs() { let p = array![n64(1.)]; let q = array![n64(-1.)]; - p.kl_divergence(&q); + let _ = p.kl_divergence(&q); } #[test] - fn test_cross_entropy_and_kl_with_zeroes_p() { + fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> { let p = array![0., 0.]; let q = array![0., 0.5]; - assert_eq!(p.cross_entropy(&q).unwrap(), 0.); - assert_eq!(p.kl_divergence(&q).unwrap(), 0.); + 
assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.); + assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.); + Ok(()) } #[test] - fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership() { + fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership() -> Result<(), ShapeMismatch> { let p = array![0.5, 0.5]; let mut q = array![0.5, 0.]; - assert_eq!(p.cross_entropy(&q.view_mut()).unwrap(), f64::INFINITY); - assert_eq!(p.kl_divergence(&q.view_mut()).unwrap(), f64::INFINITY); + assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY); + assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY); + Ok(()) } #[test] - fn test_cross_entropy() { + fn test_cross_entropy() -> Result<(), ShapeMismatch> { // Arrays of probability values - normalized and positive. let p: Array1 = array![ 0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, @@ -359,11 +365,12 @@ mod tests { // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q) let expected_cross_entropy = 3.385347705020779; - assert_abs_diff_eq!(p.cross_entropy(&q).unwrap(), expected_cross_entropy, epsilon = 1e-6); + assert_abs_diff_eq!(p.cross_entropy(&q)?.unwrap(), expected_cross_entropy, epsilon = 1e-6); + Ok(()) } #[test] - fn test_kl() { + fn test_kl() -> Result<(), ShapeMismatch> { // Arrays of probability values - normalized and positive. let p: Array1 = array![ 0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, @@ -392,6 +399,7 @@ mod tests { // Computed using scipy.stats.entropy(p, q) let expected_kl = 0.3555862567800096; - assert_abs_diff_eq!(p.kl_divergence(&q).unwrap(), expected_kl, epsilon = 1e-6); + assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6); + Ok(()) } } From 3a3d1f66fd2714826496e0e849892a66ae30fa2e Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Fri, 8 Mar 2019 17:22:54 +0000 Subject: [PATCH 37/47] Fix docs --- src/entropy.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 892a549d..8139555f 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -53,7 +53,8 @@ pub trait EntropyExt /// i=1 /// ``` /// - /// If the arrays are empty or their shapes are not identical, `None` is returned. + /// If the arrays are empty, Ok(`None`) is returned. + /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned. /// /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number /// is a panic cause for `A`. @@ -88,7 +89,8 @@ pub trait EntropyExt /// i=1 /// ``` /// - /// If the arrays are empty or their shapes are not identical, `None` is returned. + /// If the arrays are empty, Ok(`None`) is returned. + /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned. /// /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number /// is a panic cause for `A`. From ca31af83af11a638249df7189361024b61855d1b Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Fri, 8 Mar 2019 17:26:12 +0000 Subject: [PATCH 38/47] Fix docs --- src/entropy.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 8139555f..15282436 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -56,8 +56,8 @@ pub trait EntropyExt /// If the arrays are empty, Ok(`None`) is returned. /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned. 
/// - /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number - /// is a panic cause for `A`. + /// **Panics** if for a pair of elements *(pᵢ, qᵢ)* from *p* and *q* computing + /// *ln(qᵢ/pᵢ)* is a panic cause for `A`. /// /// ## Remarks /// From e65ef6173f6dd508fcfa1c42c64e2ceef58c8faf Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Fri, 8 Mar 2019 17:29:30 +0000 Subject: [PATCH 39/47] Add docs to error --- src/errors.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/errors.rs b/src/errors.rs index a7e67bed..41f3d796 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,5 +1,8 @@ #[derive(Fail, Debug)] #[fail(display = "Array shapes do not match: {:?} and {:?}.", first_shape, second_shape)] +/// An error used by methods and functions that take two arrays as argument and +/// expect them to have exactly the same shape +/// (e.g. `ShapeMismatch` is raised when `a.shape() == b.shape()` evaluates to `False`). pub struct ShapeMismatch { pub first_shape: Vec, pub second_shape: Vec, From e39025c983e728bd69b2552442af83cf3ba8983a Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 10 Mar 2019 10:24:04 +0000 Subject: [PATCH 40/47] Update src/entropy.rs Co-Authored-By: LukeMathWalker --- src/entropy.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/entropy.rs b/src/entropy.rs index 15282436..166b0c46 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -1,4 +1,4 @@ -//! Summary statistics (e.g. mean, variance, etc.). +//! Information theory (e.g. entropy, KL divergence, etc.). use ndarray::{Array, ArrayBase, Data, Dimension, Zip}; use num_traits::Float; use crate::errors::ShapeMismatch; From 99b999f1b0470c6666ab940ba3f90fec3927a607 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 10 Mar 2019 10:24:16 +0000 Subject: [PATCH 41/47] Update src/entropy.rs Co-Authored-By: LukeMathWalker --- src/entropy.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/entropy.rs b/src/entropy.rs index 166b0c46..e9679b72 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -21,7 +21,7 @@ pub trait EntropyExt /// /// If the array is empty, `None` is returned. /// - /// **Panics** if any element in the array is negative. + /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`). /// /// ## Remarks /// From ac4c1592978329dee4dce7d6d60cd7b31272bf89 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 10 Mar 2019 10:25:06 +0000 Subject: [PATCH 42/47] Update src/entropy.rs Co-Authored-By: LukeMathWalker --- src/entropy.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/entropy.rs b/src/entropy.rs index e9679b72..6de1feca 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -56,7 +56,7 @@ pub trait EntropyExt /// If the arrays are empty, Ok(`None`) is returned. /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned. /// - /// **Panics** if for a pair of elements *(pᵢ, qᵢ)* from *p* and *q* computing + /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing /// *ln(qᵢ/pᵢ)* is a panic cause for `A`. 
From ac4c1592978329dee4dce7d6d60cd7b31272bf89 Mon Sep 17 00:00:00 2001
From: Jim Turner
Date: Sun, 10 Mar 2019 10:25:06 +0000
Subject: [PATCH 42/47] Update src/entropy.rs

Co-Authored-By: LukeMathWalker
---
 src/entropy.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/entropy.rs b/src/entropy.rs
index e9679b72..6de1feca 100644
--- a/src/entropy.rs
+++ b/src/entropy.rs
@@ -56,7 +56,7 @@ pub trait EntropyExt<A, S, D>
     /// If the arrays are empty, `Ok(None)` is returned.
     /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
     ///
-    /// **Panics** if for a pair of elements *(pᵢ, qᵢ)* from *p* and *q* computing
+    /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
     /// *ln(qᵢ/pᵢ)* is a panic cause for `A`.
     ///
     /// ## Remarks

From b429ec714dbed31e6c099e01c1a9bbb37ffdec3b Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 10 Mar 2019 10:32:07 +0000
Subject: [PATCH 43/47] Better semantics

---
 src/entropy.rs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/entropy.rs b/src/entropy.rs
index 6de1feca..0bbc6ba4 100644
--- a/src/entropy.rs
+++ b/src/entropy.rs
@@ -174,8 +174,8 @@ impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
                 }
             }
         });
-        let kl_divergence = temp.sum();
-        Ok(Some(-kl_divergence))
+        let kl_divergence = -temp.sum();
+        Ok(Some(kl_divergence))
     }
 
     fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
     where
         S2: Data<Elem = A>,
         A: Float,
@@ -206,8 +206,8 @@ impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
                 }
             }
         });
-        let cross_entropy = temp.sum();
-        Ok(Some(-cross_entropy))
+        let cross_entropy = -temp.sum();
+        Ok(Some(cross_entropy))
    }
}
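
A worked check of the sign convention this commit is about, with arbitrarily chosen values; after the change, the variable named `kl_divergence` holds the divergence itself, which is non-negative by definition:

```rust
use ndarray::array;
use ndarray_stats::EntropyExt;

fn main() {
    let p = array![0.5, 0.5];
    let q = array![0.25, 0.75];
    // KL(p ‖ q) = Σᵢ pᵢ ln(pᵢ/qᵢ) = −Σᵢ pᵢ ln(qᵢ/pᵢ) ≥ 0
    let kl = p.kl_divergence(&q).unwrap().unwrap();
    let expected = 0.5 * (0.5f64 / 0.25).ln() + 0.5 * (0.5f64 / 0.75).ln();
    assert!(kl >= 0.);
    assert!((kl - expected).abs() < 1e-12);
}
```
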
From e9679fa46e19780ea41c9edd4b1499ea21f72ad7 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 10 Mar 2019 10:39:05 +0000
Subject: [PATCH 44/47] Use Error instead of Fail

---
 Cargo.toml    |  1 -
 src/errors.rs | 16 +++++++++++++---
 src/lib.rs    |  2 --
 3 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index f4a1fc5e..57795ce8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,7 +20,6 @@ noisy_float = "0.1.8"
 num-traits = "0.2"
 rand = "0.6"
 itertools = { version = "0.7.0", default-features = false }
-failure = "0.1.5"
 
 [dev-dependencies]
 quickcheck = "0.7"

diff --git a/src/errors.rs b/src/errors.rs
index 41f3d796..37b158df 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -1,9 +1,19 @@
-#[derive(Fail, Debug)]
-#[fail(display = "Array shapes do not match: {:?} and {:?}.", first_shape, second_shape)]
+use std::error::Error;
+use std::fmt;
+
+#[derive(Debug)]
 /// An error used by methods and functions that take two arrays as arguments and
 /// expect them to have exactly the same shape
 /// (e.g. `ShapeMismatch` is raised when `a.shape() == b.shape()` evaluates to `false`).
 pub struct ShapeMismatch {
     pub first_shape: Vec<usize>,
     pub second_shape: Vec<usize>,
-}
\ No newline at end of file
+}
+
+impl fmt::Display for ShapeMismatch {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "Array shapes do not match: {:?} and {:?}.", self.first_shape, self.second_shape)
+    }
+}
+
+impl Error for ShapeMismatch {}

diff --git a/src/lib.rs b/src/lib.rs
index 6bcfc729..d0761ae4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -29,8 +29,6 @@ extern crate noisy_float;
 extern crate num_traits;
 extern crate rand;
 extern crate itertools;
-#[macro_use]
-extern crate failure;
 
 #[cfg(test)]
 extern crate ndarray_rand;
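
Since `ShapeMismatch` now implements `std::error::Error`, it composes with the standard error machinery. A sketch, with a hypothetical helper `divergence` introduced purely for illustration:

```rust
use ndarray::{array, Array1};
use ndarray_stats::EntropyExt;
use std::error::Error;

// `?` can now box a `ShapeMismatch` into `Box<dyn Error>`
// like any other standard-library-compatible error.
fn divergence(p: &Array1<f64>, q: &Array1<f64>) -> Result<f64, Box<dyn Error>> {
    Ok(p.kl_divergence(q)?.unwrap_or(0.))
}

fn main() -> Result<(), Box<dyn Error>> {
    let p = array![0.5, 0.5];
    let q = array![0.25, 0.75];
    println!("KL(p ‖ q) = {}", divergence(&p, &q)?);
    Ok(())
}
```
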
From d2dfe8f40d48251c5debfbe45f7e8d0b90a9b86e Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 10 Mar 2019 10:39:48 +0000
Subject: [PATCH 45/47] Formatting

---
 src/correlation.rs              |  96 +++++++++------------
 src/entropy.rs                  | 145 +++++++++++++++----------------
 src/errors.rs                   |   6 +-
 src/histogram/bins.rs           |  35 +++-----
 src/histogram/grid.rs           |  29 ++++---
 src/histogram/histograms.rs     |  25 +++---
 src/histogram/mod.rs            |  10 +--
 src/histogram/strategies.rs     |  86 +++++++++----------
 src/lib.rs                      |  23 +++--
 src/quantile.rs                 |   9 +-
 src/summary_statistics/means.rs |  47 ++++++-----
 src/summary_statistics/mod.rs   |  21 +++--
 12 files changed, 257 insertions(+), 275 deletions(-)

diff --git a/src/correlation.rs b/src/correlation.rs
index ec3a1030..200bbc45 100644
--- a/src/correlation.rs
+++ b/src/correlation.rs
@@ -131,14 +131,15 @@ where
     {
         let observation_axis = Axis(1);
         let n_observations = A::from_usize(self.len_of(observation_axis)).unwrap();
-        let dof =
-            if ddof >= n_observations {
-                panic!("`ddof` needs to be strictly smaller than the \
-                        number of observations provided for each \
-                        random variable!")
-            } else {
-                n_observations - ddof
-            };
+        let dof = if ddof >= n_observations {
+            panic!(
+                "`ddof` needs to be strictly smaller than the \
+                 number of observations provided for each \
+                 random variable!"
+            )
+        } else {
+            n_observations - ddof
+        };
         let mean = self.mean_axis(observation_axis);
         let denoised = self - &mean.insert_axis(observation_axis);
         let covariance = denoised.dot(&denoised.t());
@@ -156,7 +157,9 @@ where
         // observation per random variable (or no observations at all)
         let ddof = -A::one();
         let cov = self.cov(ddof);
-        let std = self.std_axis(observation_axis, ddof).insert_axis(observation_axis);
+        let std = self
+            .std_axis(observation_axis, ddof)
+            .insert_axis(observation_axis);
         let std_matrix = std.dot(&std.t());
         // element-wise division
         cov / std_matrix
@@ -167,10 +170,10 @@ where
 mod cov_tests {
     use super::*;
     use ndarray::array;
+    use ndarray_rand::RandomExt;
     use quickcheck::quickcheck;
     use rand;
     use rand::distributions::Uniform;
-    use ndarray_rand::RandomExt;
 
     quickcheck! {
         fn constant_random_variables_have_zero_covariance_matrix(value: f64) -> bool {
@@ -200,10 +203,7 @@ mod cov_tests {
     fn test_invalid_ddof() {
         let n_random_variables = 3;
         let n_observations = 4;
-        let a = Array::random(
-            (n_random_variables, n_observations),
-            Uniform::new(0., 10.)
-        );
+        let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.));
         let invalid_ddof = (n_observations as f64) + rand::random::<f64>().abs();
         a.cov(invalid_ddof);
     }
@@ -235,45 +235,36 @@ mod cov_tests {
     #[test]
     fn test_covariance_for_random_array() {
         let a = array![
-            [ 0.72009497, 0.12568055, 0.55705966, 0.5959984 , 0.69471457],
-            [ 0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245],
-            [ 0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036],
-            [ 0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539],
-            [ 0.26952473, 0.93079841, 0.8080893 , 0.42814155, 0.24642258]
+            [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457],
+            [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245],
+            [0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036],
+            [0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539],
+            [0.26952473, 0.93079841, 0.8080893, 0.42814155, 0.24642258]
         ];
         let numpy_covariance = array![
-            [ 0.05786248, 0.02614063, 0.06446215, 0.01285105, -0.06443992],
-            [ 0.02614063, 0.08733569, 0.02436933, 0.01977437, -0.06715555],
-            [ 0.06446215, 0.02436933, 0.10052129, 0.01393589, -0.06129912],
-            [ 0.01285105, 0.01977437, 0.01393589, 0.00638795, -0.02355557],
-            [-0.06443992, -0.06715555, -0.06129912, -0.02355557, 0.09909855]
+            [0.05786248, 0.02614063, 0.06446215, 0.01285105, -0.06443992],
+            [0.02614063, 0.08733569, 0.02436933, 0.01977437, -0.06715555],
+            [0.06446215, 0.02436933, 0.10052129, 0.01393589, -0.06129912],
+            [0.01285105, 0.01977437, 0.01393589, 0.00638795, -0.02355557],
+            [
+                -0.06443992,
+                -0.06715555,
+                -0.06129912,
+                -0.02355557,
+                0.09909855
+            ]
         ];
         assert_eq!(a.ndim(), 2);
-        assert!(
-            a.cov(1.).all_close(
-                &numpy_covariance,
-                1e-8
-            )
-        );
+        assert!(a.cov(1.).all_close(&numpy_covariance, 1e-8));
     }
 
     #[test]
     #[should_panic]
     // We lose precision, hence the failing assert
     fn test_covariance_for_badly_conditioned_array() {
-        let a: Array2<f64> = array![
-            [ 1e12 + 1., 1e12 - 1.],
-            [ 1e-6 + 1e-12, 1e-6 - 1e-12],
-        ];
-        let expected_covariance = array![
-            [2., 2e-12], [2e-12, 2e-24]
-        ];
-        assert!(
-            a.cov(1.).all_close(
-                &expected_covariance,
-                1e-24
-            )
-        );
+        let a: Array2<f64> = array![[1e12 + 1., 1e12 - 1.], [1e-6 + 1e-12, 1e-6 - 1e-12],];
+        let expected_covariance = array![[2., 2e-12], [2e-12, 2e-24]];
+        assert!(a.cov(1.).all_close(&expected_covariance, 1e-24));
     }
 }
@@ -281,9 +272,9 @@ mod cov_tests {
 mod pearson_correlation_tests {
     use super::*;
     use ndarray::array;
+    use ndarray_rand::RandomExt;
     use quickcheck::quickcheck;
     use rand::distributions::Uniform;
-    use ndarray_rand::RandomExt;
 
     quickcheck! {
         fn output_matrix_is_symmetric(bound: f64) -> bool {
@@ -337,19 +328,14 @@ mod pearson_correlation_tests {
             [0.26979716, 0.20887228, 0.95454999, 0.96290785]
         ];
         let numpy_corrcoeff = array![
-            [ 1.        , 0.38089376, 0.08122504, -0.59931623, 0.1365648 ],
-            [ 0.38089376, 1.        , 0.80918429, -0.52615195, 0.38954398],
-            [ 0.08122504, 0.80918429, 1.        , 0.07134906, -0.17324776],
-            [-0.59931623, -0.52615195, 0.07134906, 1.        , -0.8743213 ],
-            [ 0.1365648 , 0.38954398, -0.17324776, -0.8743213 , 1.        ]
+            [1., 0.38089376, 0.08122504, -0.59931623, 0.1365648],
+            [0.38089376, 1., 0.80918429, -0.52615195, 0.38954398],
+            [0.08122504, 0.80918429, 1., 0.07134906, -0.17324776],
+            [-0.59931623, -0.52615195, 0.07134906, 1., -0.8743213],
+            [0.1365648, 0.38954398, -0.17324776, -0.8743213, 1.]
         ];
         assert_eq!(a.ndim(), 2);
-        assert!(
-            a.pearson_correlation().all_close(
-                &numpy_corrcoeff,
-                1e-7
-            )
-        );
+        assert!(a.pearson_correlation().all_close(&numpy_corrcoeff, 1e-7));
     }
 }

diff --git a/src/entropy.rs b/src/entropy.rs
index 0bbc6ba4..8daac6d8 100644
--- a/src/entropy.rs
+++ b/src/entropy.rs
@@ -1,15 +1,15 @@
 //! Information theory (e.g. entropy, KL divergence, etc.).
+use crate::errors::ShapeMismatch;
 use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
 use num_traits::Float;
-use crate::errors::ShapeMismatch;
 
 /// Extension trait for `ArrayBase` providing methods
 /// to compute information theory quantities
 /// (e.g. entropy, Kullback–Leibler divergence, etc.).
 pub trait EntropyExt<A, S, D>
-    where
-        S: Data<Elem = A>,
-        D: Dimension,
+where
+    S: Data<Elem = A>,
+    D: Dimension,
 {
     /// Computes the [entropy] *S* of the array values, defined as
@@ -75,7 +75,7 @@ pub trait EntropyExt<A, S, D>
     /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
     fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
-        where
-            S2: Data<Elem = A>,
-            A: Float;
+    where
+        S2: Data<Elem = A>,
+        A: Float;
 
     /// Computes the [cross entropy] *H(p,q)* between two arrays,
@@ -116,31 +116,31 @@ pub trait EntropyExt<A, S, D>
     /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
     fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
-        where
-            S2: Data<Elem = A>,
-            A: Float;
+    where
+        S2: Data<Elem = A>,
+        A: Float;
 }
 
 impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
-    where
-        S: Data<Elem = A>,
-        D: Dimension,
+where
+    S: Data<Elem = A>,
+    D: Dimension,
 {
     fn entropy(&self) -> Option<A>
-        where
-            A: Float
+    where
+        A: Float,
     {
         if self.len() == 0 {
             None
         } else {
-            let entropy = self.mapv(
-                |x| {
+            let entropy = self
+                .mapv(|x| {
                     if x == A::zero() {
                         A::zero()
                     } else {
                         x * x.ln()
                     }
-                }
-            ).sum();
+                })
+                .sum();
             Some(-entropy)
         }
     }
@@ -148,16 +148,16 @@ impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
     fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
     where
         A: Float,
         S2: Data<Elem = A>,
     {
         if self.len() == 0 {
-            return Ok(None)
+            return Ok(None);
         }
         if self.shape() != q.shape() {
             return Err(ShapeMismatch {
                 first_shape: self.shape().to_vec(),
-                second_shape: q.shape().to_vec()
-            })
+                second_shape: q.shape().to_vec(),
+            });
         }
 
         let mut temp = Array::zeros(self.raw_dim());
@@ -179,17 +178,17 @@ impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
     fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
     where
         S2: Data<Elem = A>,
         A: Float,
     {
         if self.len() == 0 {
-            return Ok(None)
+            return Ok(None);
         }
         if self.shape() != q.shape() {
             return Err(ShapeMismatch {
                 first_shape: self.shape().to_vec(),
-                second_shape: q.shape().to_vec()
-            })
+                second_shape: q.shape().to_vec(),
+            });
         }
 
         let mut temp = Array::zeros(self.raw_dim());
@@ -214,11 +213,11 @@ impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
 #[cfg(test)]
 mod tests {
     use super::EntropyExt;
-    use std::f64;
     use approx::assert_abs_diff_eq;
-    use noisy_float::types::n64;
-    use ndarray::{array, Array1};
     use errors::ShapeMismatch;
+    use ndarray::{array, Array1};
+    use noisy_float::types::n64;
+    use std::f64;
 
@@ -233,16 +232,14 @@ mod tests {
     #[test]
     fn test_entropy_with_array_of_floats() {
         // Array of probability values - normalized and positive.
         let a: Array1<f64> = array![
             0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, 0.03368976, 0.00065396,
             0.02906146, 0.00063687, 0.01597306, 0.00787625, 0.00208243, 0.01450896, 0.01803418,
             0.02055336, 0.03029759, 0.03323628, 0.01218822, 0.0001873, 0.01734179, 0.03521668,
             0.02564429, 0.02421992, 0.03540229, 0.03497635, 0.03582331, 0.026558, 0.02460495,
             0.02437716, 0.01212838, 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588,
             0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, 0.00976694, 0.02864634,
             0.00802828, 0.03464088, 0.03557152, 0.01398894, 0.01831756, 0.0227171, 0.00736204,
             0.01866295,
         ];
         // Computed using scipy.stats.entropy
         let expected_entropy = 3.721606155686918;
@@ -277,16 +274,9 @@ mod tests {
     fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() {
         // p: 3x2, 6 elements
-        let p = array![
-            [f64::NAN, 1.],
-            [6., 7.],
-            [10., 20.]
-        ];
+        let p = array![[f64::NAN, 1.], [6., 7.], [10., 20.]];
         // q: 2x3, 6 elements
-        let q = array![
-            [2., 1., 5.],
-            [1., 1., 7.],
-        ];
+        let q = array![[2., 1., 5.], [1., 1., 7.],];
         assert!(q.cross_entropy(&p).is_err());
         assert!(p.cross_entropy(&q).is_err());
         assert!(q.kl_divergence(&p).is_err());
         assert!(p.kl_divergence(&q).is_err());
@@ -339,7 +329,8 @@ mod tests {
     }
 
     #[test]
-    fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
+    ) -> Result<(), ShapeMismatch> {
         let p = array![0.5, 0.5];
         let mut q = array![0.5, 0.];
         assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
@@ -352,23 +343,25 @@ mod tests {
     fn test_cross_entropy() -> Result<(), ShapeMismatch> {
         // Arrays of probability values - normalized and positive.
         let p: Array1<f64> = array![
             0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
             0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, 0.01959158, 0.05020174,
             0.03801479, 0.00092234, 0.08515856, 0.00580683, 0.0156542, 0.0860375, 0.0724246,
             0.00727477, 0.01004402, 0.01854399, 0.03504082,
         ];
         let q: Array1<f64> = array![
             0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, 0.05604812,
             0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, 0.0625685, 0.07381292,
             0.05489067, 0.01385491, 0.03639174, 0.00511611, 0.05700415, 0.05183825, 0.06703064,
             0.01813342, 0.0007763, 0.0735472, 0.05857833,
         ];
         // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
         let expected_cross_entropy = 3.385347705020779;
 
-        assert_abs_diff_eq!(p.cross_entropy(&q)?.unwrap(), expected_cross_entropy, epsilon = 1e-6);
+        assert_abs_diff_eq!(
+            p.cross_entropy(&q)?.unwrap(),
+            expected_cross_entropy,
+            epsilon = 1e-6
+        );
         Ok(())
     }
 
@@ -377,34 +370,26 @@ mod tests {
     fn test_kl() -> Result<(), ShapeMismatch> {
         // Arrays of probability values - normalized and positive.
         let p: Array1<f64> = array![
             0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
             0.02213802, 0.02745017, 0.02163975, 0.0324602, 0.03622766, 0.00782343, 0.00222498,
             0.03028156, 0.02346124, 0.00071105, 0.00794496, 0.0127609, 0.02899124, 0.01281487,
             0.0230803, 0.01531864, 0.00518158, 0.02233383, 0.0220279, 0.03196097, 0.03710063,
             0.01817856, 0.03524661, 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151,
             0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375, 0.01988341, 0.02621831,
             0.03564644, 0.01389121, 0.03151622, 0.03195532, 0.00717521, 0.03547256, 0.00371394,
             0.01108706,
         ];
         let q: Array1<f64> = array![
             0.02038386, 0.03143914, 0.02630206, 0.0171595, 0.0067072, 0.00911324, 0.02635717,
             0.01269113, 0.0302361, 0.02243133, 0.01902902, 0.01297185, 0.02118908, 0.03309548,
             0.01266687, 0.0184529, 0.01830936, 0.03430437, 0.02898924, 0.02238251, 0.0139771,
             0.01879774, 0.02396583, 0.03019978, 0.01421278, 0.02078981, 0.03542451, 0.02887438,
             0.01261783, 0.01014241, 0.03263407, 0.0095969, 0.01923903, 0.0051315, 0.00924686,
             0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, 0.03315135, 0.02099325,
             0.03251755, 0.00337555, 0.03432165, 0.01763753, 0.02038337, 0.01923023, 0.01438769,
             0.02082707,
         ];
         // Computed using scipy.stats.entropy(p, q)
         let expected_kl = 0.3555862567800096;

diff --git a/src/errors.rs b/src/errors.rs
index 37b158df..f3e9f77f 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -12,7 +12,11 @@ pub struct ShapeMismatch {
 
 impl fmt::Display for ShapeMismatch {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(f, "Array shapes do not match: {:?} and {:?}.", self.first_shape, self.second_shape)
+        write!(
+            f,
+            "Array shapes do not match: {:?} and {:?}.",
+            self.first_shape, self.second_shape
+        )
     }
 }

diff --git a/src/histogram/bins.rs b/src/histogram/bins.rs
index e0c3a1ea..3a83eae9 100644
--- a/src/histogram/bins.rs
+++ b/src/histogram/bins.rs
@@ -33,7 +33,6 @@ pub struct Edges<A: Ord> {
 }
 
 impl<A: Ord> From<Vec<A>> for Edges<A> {
-
     /// Get an `Edges` instance from a `Vec<A>`:
     /// the vector will be sorted in increasing order
     /// using an unstable sorting algorithm and duplicates
@@ -89,7 +88,7 @@ impl<A: Ord> From<Vec<A>> for Edges<A> {
     }
 }
 
-impl<A: Ord> Index<usize> for Edges<A>{
+impl<A: Ord> Index<usize> for Edges<A> {
     type Output = A;
 
     /// Get the `i`-th edge.
@@ -182,13 +181,11 @@ impl<A: Ord> Edges<A> {
         match self.edges.binary_search(value) {
             Ok(i) if i == n_edges - 1 => None,
             Ok(i) => Some((i, i + 1)),
-            Err(i) => {
-                match i {
-                    0 => None,
-                    j if j == n_edges => None,
-                    j => Some((j - 1, j)),
-                }
-            }
+            Err(i) => match i {
+                0 => None,
+                j if j == n_edges => None,
+                j => Some((j - 1, j)),
+            },
         }
     }
@@ -309,18 +306,14 @@ impl<A: Ord> Bins<A> {
     /// );
     /// ```
     pub fn range_of(&self, value: &A) -> Option<Range<A>>
-        where
-            A: Clone,
+    where
+        A: Clone,
     {
         let edges_indexes = self.edges.indices_of(value);
-        edges_indexes.map(
-            |(left, right)| {
-                Range {
-                    start: self.edges[left].clone(),
-                    end: self.edges[right].clone(),
-                }
-            }
-        )
+        edges_indexes.map(|(left, right)| Range {
+            start: self.edges[left].clone(),
+            end: self.edges[right].clone(),
+        })
     }
 
     /// Get the `i`-th bin.
@@ -341,7 +334,7 @@ impl<A: Ord> Bins<A> {
     /// );
     /// ```
     pub fn index(&self, index: usize) -> Range<A>
-        where
+    where
         A: Clone,
     {
         // It was not possible to implement this functionality
@@ -350,7 +343,7 @@ impl<A: Ord> Bins<A> {
         // Index, in fact, forces you to return a reference.
         Range {
             start: self.edges[index].clone(),
-            end: self.edges[index+1].clone(),
+            end: self.edges[index + 1].clone(),
         }
     }
 }

diff --git a/src/histogram/grid.rs b/src/histogram/grid.rs
index 32a7161b..bfa5afc1 100644
--- a/src/histogram/grid.rs
+++ b/src/histogram/grid.rs
@@ -1,8 +1,8 @@
 use super::bins::Bins;
 use super::strategies::BinsBuildingStrategy;
-use std::ops::Range;
 use itertools::izip;
-use ndarray::{ArrayBase, Data, Ix1, Ix2, Axis};
+use ndarray::{ArrayBase, Axis, Data, Ix1, Ix2};
+use std::ops::Range;
 
 /// A `Grid` is a partition of a rectangular region of an *n*-dimensional
 /// space—e.g. [*a*<sub>0</sub>, *b*<sub>0</sub>) × ⋯ × [*a*<sub>n−1</sub>,
 /// *b*<sub>n−1</sub>)—into a collection of rectangular *n*-dimensional bins.
@@ -72,7 +72,6 @@ pub struct Grid<A: Ord> {
 }
 
 impl<A: Ord> From<Vec<Bins<A>>> for Grid<A> {
-
     /// Get a `Grid` instance from a `Vec<Bins<A>>`.
     ///
     /// The `i`-th element in `Vec<Bins<A>>` represents the 1-dimensional
     /// projection of the bin grid on the `i`-th axis.
@@ -113,9 +112,14 @@ impl<A: Ord> Grid<A> {
     where
         S: Data<Elem = A>,
     {
-        assert_eq!(point.len(), self.ndim(),
-                   "Dimension mismatch: the point has {:?} dimensions, the grid \
-                   expected {:?} dimensions.", point.len(), self.ndim());
+        assert_eq!(
+            point.len(),
+            self.ndim(),
+            "Dimension mismatch: the point has {:?} dimensions, the grid \
+             expected {:?} dimensions.",
+            point.len(),
+            self.ndim()
+        );
         point
             .iter()
             .zip(self.projections.iter())
@@ -132,9 +136,14 @@ impl<A: Ord> Grid<A> {
     /// **Panics** if at least one among `(i_0, ..., i_{n-1})` is out of bounds on the respective
     /// coordinate axis - i.e. if there exists `j` such that `i_j >= self.projections[j].len()`.
     pub fn index(&self, index: &[usize]) -> Vec<Range<A>> {
-        assert_eq!(index.len(), self.ndim(),
-                   "Dimension mismatch: the index has {0:?} dimensions, the grid \
-                   expected {1:?} dimensions.", index.len(), self.ndim());
+        assert_eq!(
+            index.len(),
+            self.ndim(),
+            "Dimension mismatch: the index has {0:?} dimensions, the grid \
+             expected {1:?} dimensions.",
+            index.len(),
+            self.ndim()
+        );
         izip!(&self.projections, index)
             .map(|(bins, &i)| bins.index(i))
             .collect()
     }
@@ -164,7 +173,7 @@ where
     /// [`strategy`]: strategies/index.html
     pub fn from_array<S>(array: &ArrayBase<S, Ix2>) -> Self
     where
-        S: Data<Elem = A>,
+        S: Data<Elem = A>,
     {
         let bin_builders = array
             .axis_iter(Axis(1))

diff --git a/src/histogram/histograms.rs b/src/histogram/histograms.rs
index 825aadb7..9bfe2724 100644
--- a/src/histogram/histograms.rs
+++ b/src/histogram/histograms.rs
@@ -1,7 +1,7 @@
+use super::errors::BinNotFound;
+use super::grid::Grid;
 use ndarray::prelude::*;
 use ndarray::Data;
-use super::grid::Grid;
-use super::errors::BinNotFound;
 
 /// Histogram data structure.
 pub struct Histogram<A: Ord> {
@@ -58,8 +58,8 @@ impl<A: Ord> Histogram<A> {
             Some(bin_index) => {
                 self.counts[&*bin_index] += 1;
                 Ok(())
-            },
-            None => Err(BinNotFound)
+            }
+            None => Err(BinNotFound),
         }
     }
@@ -82,8 +82,8 @@ impl<A: Ord> Histogram<A> {
 /// Extension trait for `ArrayBase` providing methods to compute histograms.
 pub trait HistogramExt<A, S>
-    where
-        S: Data<Elem = A>,
+where
+    S: Data<Elem = A>,
 {
@@ -145,17 +145,16 @@ pub trait HistogramExt<A, S>
     /// # }
     /// ```
     fn histogram(&self, grid: Grid<A>) -> Histogram<A>
-        where
-            A: Ord;
+    where
+        A: Ord;
 }
 
 impl<A, S> HistogramExt<A, S> for ArrayBase<S, Ix2>
-    where
-        S: Data<Elem = A>,
-        A: Ord,
+where
+    S: Data<Elem = A>,
+    A: Ord,
 {
-    fn histogram(&self, grid: Grid<A>) -> Histogram<A>
-    {
+    fn histogram(&self, grid: Grid<A>) -> Histogram<A> {
         let mut histogram = Histogram::new(grid);
         for point in self.axis_iter(Axis(0)) {
             let _ = histogram.add_observation(&point);

diff --git a/src/histogram/mod.rs b/src/histogram/mod.rs
index 9176aee1..3acbf40a 100644
--- a/src/histogram/mod.rs
+++ b/src/histogram/mod.rs
@@ -1,10 +1,10 @@
 //! Histogram functionalities.
-pub use self::histograms::{Histogram, HistogramExt};
-pub use self::bins::{Edges, Bins};
+pub use self::bins::{Bins, Edges};
 pub use self::grid::{Grid, GridBuilder};
+pub use self::histograms::{Histogram, HistogramExt};
 
-mod histograms;
 mod bins;
-pub mod strategies;
-mod grid;
 pub mod errors;
+mod grid;
+mod histograms;
+pub mod strategies;

diff --git a/src/histogram/strategies.rs b/src/histogram/strategies.rs
index f669b35e..eeaee686 100644
--- a/src/histogram/strategies.rs
+++ b/src/histogram/strategies.rs
@@ -18,13 +18,12 @@
 //! [`Grid`]: ../struct.Grid.html
 //! [`GridBuilder`]: ../struct.GridBuilder.html
 //! [`NumPy`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram_bin_edges.html#numpy.histogram_bin_edges
+use super::super::interpolate::Nearest;
+use super::super::{Quantile1dExt, QuantileExt};
+use super::{Bins, Edges};
 use ndarray::prelude::*;
 use ndarray::Data;
 use num_traits::{FromPrimitive, NumOps, Zero};
-use super::super::{QuantileExt, Quantile1dExt};
-use super::super::interpolate::Nearest;
-use super::{Edges, Bins};
-
 
 /// A trait implemented by all strategies to build [`Bins`]
 /// with parameters inferred from observations.
@@ -36,8 +35,7 @@ use super::{Bins, Edges};
 /// [`Bins`]: ../struct.Bins.html
 /// [`Grid`]: ../struct.Grid.html
 /// [`GridBuilder`]: ../struct.GridBuilder.html
-pub trait BinsBuildingStrategy
-{
+pub trait BinsBuildingStrategy {
     type Elem: Ord;
     /// Given some observations in a 1-dimensional array it returns a `BinsBuildingStrategy`
    /// that has learned the required parameter to build a collection of [`Bins`].
     ///
     /// [`Bins`]: ../struct.Bins.html
     fn from_array<S>(array: &ArrayBase<S, Ix1>) -> Self
     where
-        S: Data<Elem = Self::Elem>;
+        S: Data<Elem = Self::Elem>;
 
     /// Returns a [`Bins`] instance, built according to the parameters
     /// inferred from observations in [`from_array`].
@@ -140,21 +138,24 @@ pub struct Auto<T> {
 }
 
 impl<T> EquiSpaced<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
     /// **Panics** if `bin_width<=0`.
-    fn new(bin_width: T, min: T, max: T) -> Self
-    {
+    fn new(bin_width: T, min: T, max: T) -> Self {
         assert!(bin_width > T::zero());
-        Self { bin_width, min, max }
+        Self {
+            bin_width,
+            min,
+            max,
+        }
     }
 
     fn build(&self) -> Bins<T> {
         let n_bins = self.n_bins();
         let mut edges: Vec<T> = vec![];
-        for i in 0..(n_bins+1) {
-            let edge = self.min.clone() + T::from_usize(i).unwrap()*self.bin_width.clone();
+        for i in 0..(n_bins + 1) {
+            let edge = self.min.clone() + T::from_usize(i).unwrap() * self.bin_width.clone();
             edges.push(edge);
         }
         Bins::new(Edges::from(edges))
@@ -167,7 +168,7 @@ impl<T> EquiSpaced<T>
             max_edge = max_edge + self.bin_width.clone();
             n_bins += 1;
         }
-        return n_bins
+        return n_bins;
     }
 
     fn bin_width(&self) -> T {
@@ -176,15 +177,15 @@ impl<T> EquiSpaced<T>
 }
 
 impl<T> BinsBuildingStrategy for Sqrt<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
     type Elem = T;
 
     /// **Panics** if the array is constant or if `a.len()==0`.
     fn from_array<S>(a: &ArrayBase<S, Ix1>) -> Self
     where
-        S: Data<Elem = Self::Elem>
+        S: Data<Elem = Self::Elem>,
     {
         let n_elems = a.len();
         let n_bins = (n_elems as f64).sqrt().round() as usize;
@@ -205,8 +206,8 @@ impl<T> BinsBuildingStrategy for Sqrt<T>
 }
 
 impl<T> Sqrt<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
     /// The bin width (or bin length) according to the fitted strategy.
     pub fn bin_width(&self) -> T {
@@ -215,18 +216,18 @@ impl<T> Sqrt<T>
 }
 
 impl<T> BinsBuildingStrategy for Rice<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
     type Elem = T;
 
     /// **Panics** if the array is constant or if `a.len()==0`.
     fn from_array<S>(a: &ArrayBase<S, Ix1>) -> Self
     where
-        S: Data<Elem = Self::Elem>
+        S: Data<Elem = Self::Elem>,
     {
         let n_elems = a.len();
-        let n_bins = (2. * (n_elems as f64).powf(1./3.)).round() as usize;
+        let n_bins = (2. * (n_elems as f64).powf(1.
/ 3.)).round() as usize;
         let min = a.min().unwrap().clone();
         let max = a.max().unwrap().clone();
         let bin_width = compute_bin_width(min.clone(), max.clone(), n_bins);
@@ -244,8 +245,8 @@ impl<T> BinsBuildingStrategy for Rice<T>
 }
 
 impl<T> Rice<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
     /// The bin width (or bin length) according to the fitted strategy.
     pub fn bin_width(&self) -> T {
@@ -254,15 +255,15 @@ impl<T> Rice<T>
 }
 
 impl<T> BinsBuildingStrategy for Sturges<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
     type Elem = T;
 
     /// **Panics** if the array is constant or if `a.len()==0`.
     fn from_array<S>(a: &ArrayBase<S, Ix1>) -> Self
     where
-        S: Data<Elem = Self::Elem>
+        S: Data<Elem = Self::Elem>,
     {
         let n_elems = a.len();
         let n_bins = (n_elems as f64).log2().round() as usize + 1;
@@ -283,8 +284,8 @@ impl<T> BinsBuildingStrategy for Sturges<T>
 }
 
 impl<T> Sturges<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
     /// The bin width (or bin length) according to the fitted strategy.
     pub fn bin_width(&self) -> T {
@@ -293,15 +294,15 @@ impl<T> Sturges<T>
 }
 
 impl<T> BinsBuildingStrategy for FreedmanDiaconis<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
     type Elem = T;
 
     /// **Panics** if `IQR==0` or if `a.len()==0`.
     fn from_array<S>(a: &ArrayBase<S, Ix1>) -> Self
     where
-        S: Data<Elem = Self::Elem>
+        S: Data<Elem = Self::Elem>,
     {
         let n_points = a.len();
@@ -327,11 +328,10 @@ impl<T> BinsBuildingStrategy for FreedmanDiaconis<T>
 }
 
 impl<T> FreedmanDiaconis<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
-    fn compute_bin_width(n_bins: usize, iqr: T) -> T
-    {
+    fn compute_bin_width(n_bins: usize, iqr: T) -> T {
         let denominator = (n_bins as f64).powf(1. / 3.);
         let bin_width = T::from_usize(2).unwrap() * iqr / T::from_f64(denominator).unwrap();
         bin_width
@@ -344,15 +344,15 @@ impl<T> FreedmanDiaconis<T>
 }
 
 impl<T> BinsBuildingStrategy for Auto<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
     type Elem = T;
 
     /// **Panics** if `IQR==0`, the array is constant, or `a.len()==0`.
     fn from_array<S>(a: &ArrayBase<S, Ix1>) -> Self
     where
-        S: Data<Elem = Self::Elem>
+        S: Data<Elem = Self::Elem>,
     {
         let fd_builder = FreedmanDiaconis::from_array(&a);
         let sturges_builder = Sturges::from_array(&a);
@@ -384,8 +384,8 @@ impl<T> BinsBuildingStrategy for Auto<T>
 }
 
 impl<T> Auto<T>
-    where
-        T: Ord + Clone + FromPrimitive + NumOps + Zero
+where
+    T: Ord + Clone + FromPrimitive + NumOps + Zero,
 {
     /// The bin width (or bin length) according to the fitted strategy.
     pub fn bin_width(&self) -> T {

diff --git a/src/lib.rs b/src/lib.rs
index d0761ae4..28bf81a6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -23,33 +23,32 @@
 //! [`NumPy`]: https://docs.scipy.org/doc/numpy-1.14.1/reference/routines.statistics.html
 //! [`StatsBase.jl`]: https://juliastats.github.io/StatsBase.jl/latest/
-
+extern crate itertools;
 extern crate ndarray;
 extern crate noisy_float;
 extern crate num_traits;
 extern crate rand;
-extern crate itertools;
+#[cfg(test)]
+extern crate approx;
 #[cfg(test)]
 extern crate ndarray_rand;
 #[cfg(test)]
 extern crate quickcheck;
-#[cfg(test)]
-extern crate approx;
 
-pub use maybe_nan::{MaybeNan, MaybeNanExt};
-pub use quantile::{interpolate, QuantileExt, Quantile1dExt};
-pub use sort::Sort1dExt;
 pub use correlation::CorrelationExt;
+pub use entropy::EntropyExt;
 pub use histogram::HistogramExt;
+pub use maybe_nan::{MaybeNan, MaybeNanExt};
+pub use quantile::{interpolate, Quantile1dExt, QuantileExt};
+pub use sort::Sort1dExt;
 pub use summary_statistics::SummaryStatisticsExt;
-pub use entropy::EntropyExt;
 
-mod maybe_nan;
-mod quantile;
-mod sort;
 mod correlation;
 mod entropy;
-mod summary_statistics;
 pub mod errors;
 pub mod histogram;
+mod maybe_nan;
+mod quantile;
+mod sort;
+mod summary_statistics;

diff --git a/src/quantile.rs b/src/quantile.rs
index 9188aa2f..c29ccba2 100644
--- a/src/quantile.rs
+++ b/src/quantile.rs
@@ -378,8 +378,8 @@ where
 
 /// Quantile methods for 1-D arrays.
 pub trait Quantile1dExt<A, S>
-    where
-        S: Data<Elem = A>,
+where
+    S: Data<Elem = A>,
 {
     /// Return the qth quantile of the data.
     ///
@@ -418,8 +418,8 @@ pub trait Quantile1dExt<A, S>
 }
 
 impl<A, S> Quantile1dExt<A, S> for ArrayBase<S, Ix1>
-    where
-        S: Data<Elem = A>,
+where
+    S: Data<Elem = A>,
 {
     fn quantile_mut<I>(&mut self, q: f64) -> Option<A>
     where
@@ -434,4 +434,3 @@ impl<A, S> Quantile1dExt<A, S> for ArrayBase<S, Ix1>
         }
     }
 }
-

diff --git a/src/summary_statistics/means.rs b/src/summary_statistics/means.rs
index 97217c23..f1059efd 100644
--- a/src/summary_statistics/means.rs
+++ b/src/summary_statistics/means.rs
@@ -1,8 +1,7 @@
-use ndarray::{Data, Dimension, ArrayBase};
-use num_traits::{FromPrimitive, Float, Zero};
-use std::ops::{Add, Div};
 use super::SummaryStatisticsExt;
-
+use ndarray::{ArrayBase, Data, Dimension};
+use num_traits::{Float, FromPrimitive, Zero};
+use std::ops::{Add, Div};
 
 impl<A, S, D> SummaryStatisticsExt<A, S, D> for ArrayBase<S, D>
 where
@@ -11,7 +10,7 @@ where
 {
     fn mean(&self) -> Option<A>
     where
-        A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero
+        A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero,
     {
         let n_elements = self.len();
         if n_elements == 0 {
@@ -31,8 +30,8 @@ where
     }
 
     fn geometric_mean(&self) -> Option<A>
-        where
-            A: Float + FromPrimitive,
+    where
+        A: Float + FromPrimitive,
     {
         self.map(|x| x.ln()).mean().map(|x| x.exp())
     }
@@ -41,10 +40,10 @@ where
 #[cfg(test)]
 mod tests {
     use super::SummaryStatisticsExt;
-    use std::f64;
     use approx::abs_diff_eq;
-    use noisy_float::types::N64;
     use ndarray::{array, Array1};
+    use noisy_float::types::N64;
+    use std::f64;
 
     #[test]
     fn test_means_with_nan_values() {
@@ -73,16 +72,14 @@ mod tests {
     #[test]
     fn test_means_with_array_of_floats() {
         let a: Array1<f64> = array![
             0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936,
             0.22879429, 0.43997224, 0.23831807,
             0.02416466, 0.6269962, 0.47420614, 0.56275487, 0.78995021, 0.16060581, 0.64635041,
             0.34876609, 0.78543249, 0.19938356, 0.34429457, 0.88072369, 0.17638164, 0.60819363,
             0.250392, 0.69912532, 0.78855523, 0.79140914, 0.85084218, 0.31839879, 0.63381769,
             0.22421048, 0.70760302, 0.99216018, 0.80199153, 0.19239188, 0.61356023, 0.31505352,
             0.06120481, 0.66417377, 0.63608897, 0.84959691, 0.43599069, 0.77867775, 0.88267754,
             0.83003623, 0.67016118, 0.67547638, 0.65220036, 0.68043427
         ];
         // Computed using NumPy
         let expected_mean = 0.5475494059146699;
@@ -92,7 +89,15 @@ mod tests {
         let expected_geometric_mean = 0.4345897639796527;
 
         abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = f64::EPSILON);
-        abs_diff_eq!(a.harmonic_mean().unwrap(), expected_harmonic_mean, epsilon = f64::EPSILON);
-        abs_diff_eq!(a.geometric_mean().unwrap(), expected_geometric_mean, epsilon = f64::EPSILON);
+        abs_diff_eq!(
+            a.harmonic_mean().unwrap(),
+            expected_harmonic_mean,
+            epsilon = f64::EPSILON
+        );
+        abs_diff_eq!(
+            a.geometric_mean().unwrap(),
+            expected_geometric_mean,
+            epsilon = f64::EPSILON
+        );
     }
 }

diff --git a/src/summary_statistics/mod.rs b/src/summary_statistics/mod.rs
index ae05e709..6aca865f 100644
--- a/src/summary_statistics/mod.rs
+++ b/src/summary_statistics/mod.rs
@@ -1,14 +1,14 @@
 //! Summary statistics (e.g. mean, variance, etc.).
 use ndarray::{Data, Dimension};
-use num_traits::{FromPrimitive, Float, Zero};
+use num_traits::{Float, FromPrimitive, Zero};
 use std::ops::{Add, Div};
 
 /// Extension trait for `ArrayBase` providing methods
 /// to compute several summary statistics (e.g. mean, variance, etc.).
 pub trait SummaryStatisticsExt<A, S, D>
-    where
-        S: Data<Elem = A>,
-        D: Dimension,
+where
+    S: Data<Elem = A>,
+    D: Dimension,
 {
     /// Returns the [`arithmetic mean`] x̅ of all elements in the array:
     ///
@@ -24,8 +24,8 @@ pub trait SummaryStatisticsExt<A, S, D>
     ///
     /// [`arithmetic mean`]: https://en.wikipedia.org/wiki/Arithmetic_mean
     fn mean(&self) -> Option<A>
-        where
-            A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero;
+    where
+        A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero;
 
     /// Returns the [`harmonic mean`] `HM(X)` of all elements in the array:
     ///
@@ -41,8 +41,8 @@ pub trait SummaryStatisticsExt<A, S, D>
     ///
     /// [`harmonic mean`]: https://en.wikipedia.org/wiki/Harmonic_mean
     fn harmonic_mean(&self) -> Option<A>
-        where
-            A: Float + FromPrimitive;
+    where
+        A: Float + FromPrimitive;
 
     /// Returns the [`geometric mean`] `GM(X)` of all elements in the array:
     ///
@@ -58,9 +58,8 @@ pub trait SummaryStatisticsExt<A, S, D>
     ///
     /// [`geometric mean`]: https://en.wikipedia.org/wiki/Geometric_mean
     fn geometric_mean(&self) -> Option<A>
-        where
-            A: Float + FromPrimitive;
-
+    where
+        A: Float + FromPrimitive;
 }
 
 mod means;
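
A quick sketch of the `SummaryStatisticsExt` API whose tests and docs were reformatted above; the array is chosen so the expected values are easy to verify by hand, and the `ndarray_stats` import is assumed:

```rust
use ndarray::array;
use ndarray_stats::SummaryStatisticsExt;

fn main() {
    let a = array![1.0f64, 2.0, 4.0];
    // For positive data: arithmetic ≥ geometric ≥ harmonic mean.
    let m = a.mean().unwrap();           // 7/3
    let g = a.geometric_mean().unwrap(); // (1·2·4)^(1/3) = 2
    let h = a.harmonic_mean().unwrap();  // 3 / (1 + 1/2 + 1/4) = 12/7
    assert!(m >= g && g >= h);
}
```
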
From b8ed3ede99cab4d0fa05d037f6f17feea9bc68e2 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 10 Mar 2019 10:44:20 +0000
Subject: [PATCH 46/47] Fix TOC

---
 src/lib.rs | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 28bf81a6..9cf586f1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,12 +2,13 @@
 //! the *n*-dimensional array data structure provided by [`ndarray`].
 //!
 //! Currently available routines include:
-//! - [`order statistics`] (minimum, maximum, quantiles, etc.);
-//! - [`partitioning`];
-//! - [`correlation analysis`] (covariance, pearson correlation);
-//! - [`histogram computation`].
+//! - [order statistics] (minimum, maximum, quantiles, etc.);
+//! - [partitioning];
+//! - [correlation analysis] (covariance, Pearson correlation);
+//! - [measures from information theory] (entropy, KL divergence, etc.);
+//! - [histogram computation].
 //!
-//! Please feel free to contribute new functionality! A roadmap can be found [`here`].
+//! Please feel free to contribute new functionality! A roadmap can be found [here].
 //!
 //! Our work is inspired by other existing statistical packages such as
 //! [`NumPy`] (Python) and [`StatsBase.jl`] (Julia) - any contribution bringing us closer to
@@ -15,11 +16,12 @@
 //!
 //! [`ndarray-stats`]: https://github.com/jturner314/ndarray-stats/
 //! [`ndarray`]: https://github.com/rust-ndarray/ndarray
-//! [`order statistics`]: trait.QuantileExt.html
-//! [`partitioning`]: trait.Sort1dExt.html
-//! [`correlation analysis`]: trait.CorrelationExt.html
-//! [`histogram computation`]: histogram/index.html
-//! [`here`]: https://github.com/jturner314/ndarray-stats/issues/1
+//! [order statistics]: trait.QuantileExt.html
+//! [partitioning]: trait.Sort1dExt.html
+//! [correlation analysis]: trait.CorrelationExt.html
+//! [measures from information theory]: trait.EntropyExt.html
+//! [histogram computation]: histogram/index.html
+//! [here]: https://github.com/jturner314/ndarray-stats/issues/1
 //! [`NumPy`]: https://docs.scipy.org/doc/numpy-1.14.1/reference/routines.statistics.html
 //! [`StatsBase.jl`]: https://juliastats.github.io/StatsBase.jl/latest/

From c961d9f5889c807045c0cd31a534ed71732b1c55 Mon Sep 17 00:00:00 2001
From: LukeMathWalker
Date: Sun, 10 Mar 2019 10:49:11 +0000
Subject: [PATCH 47/47] Module docstring

---
 src/errors.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/errors.rs b/src/errors.rs
index f3e9f77f..4bbeea46 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -1,3 +1,4 @@
+//! Custom errors returned from our methods and functions.
 use std::error::Error;
 use std::fmt;
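
To close, a minimal end-to-end sketch of the entropy functionality this patch series adds; the distribution is illustrative, chosen so the expected value is exact:

```rust
use ndarray::array;
use ndarray_stats::EntropyExt;

fn main() {
    // A normalized probability distribution over four outcomes.
    let p = array![0.25, 0.25, 0.25, 0.25];
    // The entropy of a uniform distribution over n outcomes is ln(n).
    let s = p.entropy().unwrap();
    assert!((s - 4.0f64.ln()).abs() < 1e-12);
}
```
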