//! Information theory (e.g. entropy, KL divergence, etc.).
use crate::errors::ShapeMismatch;
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
use num_traits::Float;

/// Extension trait for `ArrayBase` providing methods
/// to compute information theory quantities
/// (e.g. entropy, Kullback–Leibler divergence, etc.).
pub trait EntropyExt<A, S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Computes the [entropy] *S* of the array values, defined as
    ///
    /// ```text
    ///       n
    /// S = - ∑ xᵢ ln(xᵢ)
    ///      i=1
    /// ```
    ///
    /// If the array is empty, `None` is returned.
    ///
    /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
    ///
    /// ## Remarks
    ///
    /// The entropy is a measure used in [Information Theory]
    /// to describe a probability distribution: it only makes sense
    /// when the array values sum to 1, with each entry between
    /// 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the entropy, to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the array is already normalised).
    ///
    /// By definition, *xᵢ ln(xᵢ)* is set to 0 if *xᵢ* is 0.
    ///
    /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
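    ///
    /// ## Example
    ///
    /// A minimal usage sketch (it assumes this trait is re-exported at the
    /// crate root as `ndarray_stats::EntropyExt`): the entropy of a uniform
    /// two-outcome distribution is *ln(2)*.
    ///
    /// ```
    /// use ndarray::array;
    /// // Import path assumed from the crate-root re-export.
    /// use ndarray_stats::EntropyExt;
    ///
    /// let uniform = array![0.5, 0.5];
    /// let entropy = uniform.entropy().unwrap();
    /// assert!((entropy - 2_f64.ln()).abs() < 1e-12);
    /// ```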
    fn entropy(&self) -> Option<A>
    where
        A: Float;

    /// Computes the [Kullback-Leibler divergence] *Dₖₗ(p,q)* between two arrays,
    /// where `self`=*p*.
    ///
    /// The Kullback-Leibler divergence is defined as:
    ///
    /// ```text
    ///              n
    /// Dₖₗ(p,q) = - ∑ pᵢ ln(qᵢ/pᵢ)
    ///             i=1
    /// ```
    ///
    /// If the arrays are empty, `Ok(None)` is returned.
    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
    ///
    /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
    /// *ln(qᵢ/pᵢ)* is a panic cause for `A`.
    ///
    /// ## Remarks
    ///
    /// The Kullback-Leibler divergence is a measure used in [Information Theory]
    /// to describe the relationship between two probability distributions: it only makes sense
    /// when each array sums to 1 with entries between 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the divergence, to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the arrays are already normalised).
    ///
    /// By definition, *pᵢ ln(qᵢ/pᵢ)* is set to 0 if *pᵢ* is 0.
    ///
    /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
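    ///
    /// ## Example
    ///
    /// A minimal usage sketch (it assumes this trait is re-exported at the
    /// crate root as `ndarray_stats::EntropyExt`): the divergence between a
    /// uniform and a skewed two-outcome distribution, checked against the
    /// closed-form sum.
    ///
    /// ```
    /// use ndarray::array;
    /// // Import path assumed from the crate-root re-export.
    /// use ndarray_stats::EntropyExt;
    ///
    /// let p = array![0.5, 0.5];
    /// let q = array![0.75, 0.25];
    /// // Dₖₗ(p,q) = 0.5 ln(0.5/0.75) + 0.5 ln(0.5/0.25)
    /// let expected = 0.5 * (0.5_f64 / 0.75).ln() + 0.5 * (0.5_f64 / 0.25).ln();
    /// let kl = p.kl_divergence(&q).unwrap().unwrap();
    /// assert!((kl - expected).abs() < 1e-12);
    /// ```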
    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        S2: Data<Elem = A>,
        A: Float;

    /// Computes the [cross entropy] *H(p,q)* between two arrays,
    /// where `self`=*p*.
    ///
    /// The cross entropy is defined as:
    ///
    /// ```text
    ///            n
    /// H(p,q) = - ∑ pᵢ ln(qᵢ)
    ///           i=1
    /// ```
    ///
    /// If the arrays are empty, `Ok(None)` is returned.
    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
    ///
    /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
    /// is a panic cause for `A`.
    ///
    /// ## Remarks
    ///
    /// The cross entropy is a measure used in [Information Theory]
    /// to describe the relationship between two probability distributions: it only makes sense
    /// when each array sums to 1 with entries between 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the cross entropy, to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the arrays are already normalised).
    ///
    /// The cross entropy is often used as an objective/loss function in
    /// [optimization problems], including [machine learning].
    ///
    /// By definition, *pᵢ ln(qᵢ)* is set to 0 if *pᵢ* is 0.
    ///
    /// [cross entropy]: https://en.wikipedia.org/wiki/Cross-entropy
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
    /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
    /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
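    ///
    /// ## Example
    ///
    /// A minimal usage sketch (it assumes this trait is re-exported at the
    /// crate root as `ndarray_stats::EntropyExt`): the cross entropy between a
    /// uniform and a skewed two-outcome distribution, checked against the
    /// closed-form sum.
    ///
    /// ```
    /// use ndarray::array;
    /// // Import path assumed from the crate-root re-export.
    /// use ndarray_stats::EntropyExt;
    ///
    /// let p = array![0.5, 0.5];
    /// let q = array![0.75, 0.25];
    /// // H(p,q) = -(0.5 ln(0.75) + 0.5 ln(0.25))
    /// let expected = -(0.5 * 0.75_f64.ln() + 0.5 * 0.25_f64.ln());
    /// let h = p.cross_entropy(&q).unwrap().unwrap();
    /// assert!((h - expected).abs() < 1e-12);
    /// ```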
    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        S2: Data<Elem = A>,
        A: Float;
}

impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn entropy(&self) -> Option<A>
    where
        A: Float,
    {
        if self.len() == 0 {
            None
        } else {
            // Sum xᵢ ln(xᵢ), mapping 0 ln(0) to 0, then negate.
            let entropy = self
                .mapv(|x| {
                    if x == A::zero() {
                        A::zero()
                    } else {
                        x * x.ln()
                    }
                })
                .sum();
            Some(-entropy)
        }
    }

    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        A: Float,
        S2: Data<Elem = A>,
    {
        if self.len() == 0 {
            return Ok(None);
        }
        if self.shape() != q.shape() {
            return Err(ShapeMismatch {
                first_shape: self.shape().to_vec(),
                second_shape: q.shape().to_vec(),
            });
        }

        // Accumulate pᵢ ln(qᵢ/pᵢ) element-wise, mapping the pᵢ = 0 case to 0;
        // the sum is negated below to match the definition.
        let mut temp = Array::zeros(self.raw_dim());
        Zip::from(&mut temp)
            .and(self)
            .and(q)
            .apply(|result, &p, &q| {
                *result = {
                    if p == A::zero() {
                        A::zero()
                    } else {
                        p * (q / p).ln()
                    }
                }
            });
        let kl_divergence = -temp.sum();
        Ok(Some(kl_divergence))
    }

    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        S2: Data<Elem = A>,
        A: Float,
    {
        if self.len() == 0 {
            return Ok(None);
        }
        if self.shape() != q.shape() {
            return Err(ShapeMismatch {
                first_shape: self.shape().to_vec(),
                second_shape: q.shape().to_vec(),
            });
        }

        // Accumulate pᵢ ln(qᵢ) element-wise, mapping the pᵢ = 0 case to 0;
        // the sum is negated below to match the definition.
        let mut temp = Array::zeros(self.raw_dim());
        Zip::from(&mut temp)
            .and(self)
            .and(q)
            .apply(|result, &p, &q| {
                *result = {
                    if p == A::zero() {
                        A::zero()
                    } else {
                        p * q.ln()
                    }
                }
            });
        let cross_entropy = -temp.sum();
        Ok(Some(cross_entropy))
    }
}

#[cfg(test)]
mod tests {
    use super::EntropyExt;
    use approx::assert_abs_diff_eq;
    use crate::errors::ShapeMismatch;
    use ndarray::{array, Array1};
    use noisy_float::types::n64;
    use std::f64;

    #[test]
    fn test_entropy_with_nan_values() {
        let a = array![f64::NAN, 1.];
        assert!(a.entropy().unwrap().is_nan());
    }

    #[test]
    fn test_entropy_with_empty_array_of_floats() {
        let a: Array1<f64> = array![];
        assert!(a.entropy().is_none());
    }

    #[test]
    fn test_entropy_with_array_of_floats() {
        // Array of probability values - normalized and positive.
        let a: Array1<f64> = array![
            0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, 0.03368976, 0.00065396,
            0.02906146, 0.00063687, 0.01597306, 0.00787625, 0.00208243, 0.01450896, 0.01803418,
            0.02055336, 0.03029759, 0.03323628, 0.01218822, 0.0001873, 0.01734179, 0.03521668,
            0.02564429, 0.02421992, 0.03540229, 0.03497635, 0.03582331, 0.026558, 0.02460495,
            0.02437716, 0.01212838, 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588,
            0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, 0.00976694, 0.02864634,
            0.00802828, 0.03464088, 0.03557152, 0.01398894, 0.01831756, 0.0227171, 0.00736204,
            0.01866295,
        ];
        // Computed using scipy.stats.entropy
        let expected_entropy = 3.721606155686918;

        assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6);
    }

    #[test]
    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
        let a = array![f64::NAN, 1.];
        let b = array![2., 1.];
        assert!(a.cross_entropy(&b)?.unwrap().is_nan());
        assert!(b.cross_entropy(&a)?.unwrap().is_nan());
        assert!(a.kl_divergence(&b)?.unwrap().is_nan());
        assert!(b.kl_divergence(&a)?.unwrap().is_nan());
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() {
        let p = array![f64::NAN, 1.];
        let q = array![2., 1., 5.];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() {
        // p: 3x2, 6 elements
        let p = array![[f64::NAN, 1.], [6., 7.], [10., 20.]];
        // q: 2x3, 6 elements
        let q = array![[2., 1., 5.], [1., 1., 7.],];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
        let p: Array1<f64> = array![];
        let q: Array1<f64> = array![];
        assert!(p.cross_entropy(&q)?.is_none());
        assert!(p.kl_divergence(&q)?.is_none());
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
        let p = array![1.];
        let q = array![-1.];
        let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
        let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
        assert!(cross_entropy.is_nan());
        assert!(kl_divergence.is_nan());
        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_cross_entropy_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.cross_entropy(&q);
    }

    #[test]
    #[should_panic]
    fn test_kl_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.kl_divergence(&q);
    }

    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
        let p = array![0., 0.];
        let q = array![0., 0.5];
        assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
        assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
    ) -> Result<(), ShapeMismatch> {
        let p = array![0.5, 0.5];
        let mut q = array![0.5, 0.];
        assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
        assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
        Ok(())
    }

    #[test]
    fn test_cross_entropy() -> Result<(), ShapeMismatch> {
        // Arrays of probability values - normalized and positive.
        let p: Array1<f64> = array![
            0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
            0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, 0.01959158, 0.05020174,
            0.03801479, 0.00092234, 0.08515856, 0.00580683, 0.0156542, 0.0860375, 0.0724246,
            0.00727477, 0.01004402, 0.01854399, 0.03504082,
        ];
        let q: Array1<f64> = array![
            0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, 0.05604812,
            0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, 0.0625685, 0.07381292,
            0.05489067, 0.01385491, 0.03639174, 0.00511611, 0.05700415, 0.05183825, 0.06703064,
            0.01813342, 0.0007763, 0.0735472, 0.05857833,
        ];
        // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
        let expected_cross_entropy = 3.385347705020779;

        assert_abs_diff_eq!(
            p.cross_entropy(&q)?.unwrap(),
            expected_cross_entropy,
            epsilon = 1e-6
        );
        Ok(())
    }

    #[test]
    fn test_kl() -> Result<(), ShapeMismatch> {
        // Arrays of probability values - normalized and positive.
        let p: Array1<f64> = array![
            0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
            0.02213802, 0.02745017, 0.02163975, 0.0324602, 0.03622766, 0.00782343, 0.00222498,
            0.03028156, 0.02346124, 0.00071105, 0.00794496, 0.0127609, 0.02899124, 0.01281487,
            0.0230803, 0.01531864, 0.00518158, 0.02233383, 0.0220279, 0.03196097, 0.03710063,
            0.01817856, 0.03524661, 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151,
            0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375, 0.01988341, 0.02621831,
            0.03564644, 0.01389121, 0.03151622, 0.03195532, 0.00717521, 0.03547256, 0.00371394,
            0.01108706,
        ];
        let q: Array1<f64> = array![
            0.02038386, 0.03143914, 0.02630206, 0.0171595, 0.0067072, 0.00911324, 0.02635717,
            0.01269113, 0.0302361, 0.02243133, 0.01902902, 0.01297185, 0.02118908, 0.03309548,
            0.01266687, 0.0184529, 0.01830936, 0.03430437, 0.02898924, 0.02238251, 0.0139771,
            0.01879774, 0.02396583, 0.03019978, 0.01421278, 0.02078981, 0.03542451, 0.02887438,
            0.01261783, 0.01014241, 0.03263407, 0.0095969, 0.01923903, 0.0051315, 0.00924686,
            0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, 0.03315135, 0.02099325,
            0.03251755, 0.00337555, 0.03432165, 0.01763753, 0.02038337, 0.01923023, 0.01438769,
            0.02082707,
        ];
        // Computed using scipy.stats.entropy(p, q)
        let expected_kl = 0.3555862567800096;

        assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
        Ok(())
    }
}