
Commit 8e0e1cf

Entropy (#24)
Measures from information theory: entropy, cross-entropy, Kullback–Leibler divergence.
1 parent 07110df commit 8e0e1cf
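
The commit adds an `EntropyExt` extension trait that is called directly on `ndarray` arrays. A minimal usage sketch under stated assumptions: the crate name `ndarray_stats` and the re-export of the trait from the crate root are assumptions, not part of this diff; the method names and return types are taken from `src/entropy.rs` below.

```rust
use ndarray::array;
// Assumed re-export of the trait added in this commit.
use ndarray_stats::EntropyExt;

fn main() {
    // Normalised probability distributions: entries in [0, 1] summing to 1.
    let p = array![0.5, 0.5];
    let q = array![0.75, 0.25];

    // Entropy of a uniform distribution over two outcomes is ln(2) ≈ 0.6931.
    let s = p.entropy().unwrap();
    assert!((s - 2_f64.ln()).abs() < 1e-12);

    // Dₖₗ(p, p) is zero; Dₖₗ(p, q) is strictly positive for q ≠ p.
    assert_eq!(p.kl_divergence(&p).unwrap().unwrap(), 0.);
    assert!(p.kl_divergence(&q).unwrap().unwrap() > 0.);

    // Mismatched shapes are reported as Err(ShapeMismatch) rather than a panic.
    assert!(p.cross_entropy(&array![1.]).is_err());
}
```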

File tree

3 files changed: +435 -10 lines changed


src/entropy.rs

+396
@@ -0,0 +1,396 @@
//! Information theory (e.g. entropy, KL divergence, etc.).
use crate::errors::ShapeMismatch;
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
use num_traits::Float;

/// Extension trait for `ArrayBase` providing methods
/// to compute information theory quantities
/// (e.g. entropy, Kullback–Leibler divergence, etc.).
pub trait EntropyExt<A, S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Computes the [entropy] *S* of the array values, defined as
    ///
    /// ```text
    ///       n
    /// S = - ∑ xᵢ ln(xᵢ)
    ///      i=1
    /// ```
    ///
    /// If the array is empty, `None` is returned.
    ///
    /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
    ///
    /// ## Remarks
    ///
    /// The entropy is a measure used in [Information Theory]
    /// to describe a probability distribution: it only makes sense
    /// when the array values sum to 1, with each entry between
    /// 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the entropy to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the array were to be already normalised).
    ///
    /// By definition, *xᵢ ln(xᵢ)* is set to 0 if *xᵢ* is 0.
    ///
    /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
    fn entropy(&self) -> Option<A>
    where
        A: Float;

    /// Computes the [Kullback-Leibler divergence] *Dₖₗ(p,q)* between two arrays,
    /// where `self`=*p*.
    ///
    /// The Kullback-Leibler divergence is defined as:
    ///
    /// ```text
    ///              n
    /// Dₖₗ(p,q) = - ∑ pᵢ ln(qᵢ/pᵢ)
    ///             i=1
    /// ```
    ///
    /// If the arrays are empty, Ok(`None`) is returned.
    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
    ///
    /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
    /// *ln(qᵢ/pᵢ)* is a panic cause for `A`.
    ///
    /// ## Remarks
    ///
    /// The Kullback-Leibler divergence is a measure used in [Information Theory]
    /// to describe the relationship between two probability distributions: it only makes sense
    /// when each array sums to 1 with entries between 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the divergence to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the array were to be already normalised).
    ///
    /// By definition, *pᵢ ln(qᵢ/pᵢ)* is set to 0 if *pᵢ* is 0.
    ///
    /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        S2: Data<Elem = A>,
        A: Float;

    /// Computes the [cross entropy] *H(p,q)* between two arrays,
    /// where `self`=*p*.
    ///
    /// The cross entropy is defined as:
    ///
    /// ```text
    ///            n
    /// H(p,q) = - ∑ pᵢ ln(qᵢ)
    ///           i=1
    /// ```
    ///
    /// If the arrays are empty, Ok(`None`) is returned.
    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
    ///
    /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
    /// is a panic cause for `A`.
    ///
    /// ## Remarks
    ///
    /// The cross entropy is a measure used in [Information Theory]
    /// to describe the relationship between two probability distributions: it only makes sense
    /// when each array sums to 1 with entries between 0 and 1 (extremes included).
    ///
    /// The array values are **not** normalised by this function before
    /// computing the cross entropy to avoid introducing potentially
    /// unnecessary numerical errors (e.g. if the array were to be already normalised).
    ///
    /// The cross entropy is often used as an objective/loss function in
    /// [optimization problems], including [machine learning].
    ///
    /// By definition, *pᵢ ln(qᵢ)* is set to 0 if *pᵢ* is 0.
    ///
    /// [cross entropy]: https://en.wikipedia.org/wiki/Cross-entropy
    /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
    /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
    /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        S2: Data<Elem = A>,
        A: Float;
}

impl<A, S, D> EntropyExt<A, S, D> for ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn entropy(&self) -> Option<A>
    where
        A: Float,
    {
        if self.len() == 0 {
            None
        } else {
            let entropy = self
                .mapv(|x| {
                    if x == A::zero() {
                        A::zero()
                    } else {
                        x * x.ln()
                    }
                })
                .sum();
            Some(-entropy)
        }
    }

    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        A: Float,
        S2: Data<Elem = A>,
    {
        if self.len() == 0 {
            return Ok(None);
        }
        if self.shape() != q.shape() {
            return Err(ShapeMismatch {
                first_shape: self.shape().to_vec(),
                second_shape: q.shape().to_vec(),
            });
        }

        let mut temp = Array::zeros(self.raw_dim());
        Zip::from(&mut temp)
            .and(self)
            .and(q)
            .apply(|result, &p, &q| {
                *result = {
                    if p == A::zero() {
                        A::zero()
                    } else {
                        p * (q / p).ln()
                    }
                }
            });
        let kl_divergence = -temp.sum();
        Ok(Some(kl_divergence))
    }

    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
    where
        S2: Data<Elem = A>,
        A: Float,
    {
        if self.len() == 0 {
            return Ok(None);
        }
        if self.shape() != q.shape() {
            return Err(ShapeMismatch {
                first_shape: self.shape().to_vec(),
                second_shape: q.shape().to_vec(),
            });
        }

        let mut temp = Array::zeros(self.raw_dim());
        Zip::from(&mut temp)
            .and(self)
            .and(q)
            .apply(|result, &p, &q| {
                *result = {
                    if p == A::zero() {
                        A::zero()
                    } else {
                        p * q.ln()
                    }
                }
            });
        let cross_entropy = -temp.sum();
        Ok(Some(cross_entropy))
    }
}

#[cfg(test)]
mod tests {
    use super::EntropyExt;
    use approx::assert_abs_diff_eq;
    use errors::ShapeMismatch;
    use ndarray::{array, Array1};
    use noisy_float::types::n64;
    use std::f64;

    #[test]
    fn test_entropy_with_nan_values() {
        let a = array![f64::NAN, 1.];
        assert!(a.entropy().unwrap().is_nan());
    }

    #[test]
    fn test_entropy_with_empty_array_of_floats() {
        let a: Array1<f64> = array![];
        assert!(a.entropy().is_none());
    }

    #[test]
    fn test_entropy_with_array_of_floats() {
        // Array of probability values - normalized and positive.
        let a: Array1<f64> = array![
            0.03602474, 0.01900344, 0.03510129, 0.03414964, 0.00525311, 0.03368976, 0.00065396,
            0.02906146, 0.00063687, 0.01597306, 0.00787625, 0.00208243, 0.01450896, 0.01803418,
            0.02055336, 0.03029759, 0.03323628, 0.01218822, 0.0001873, 0.01734179, 0.03521668,
            0.02564429, 0.02421992, 0.03540229, 0.03497635, 0.03582331, 0.026558, 0.02460495,
            0.02437716, 0.01212838, 0.00058464, 0.00335236, 0.02146745, 0.00930306, 0.01821588,
            0.02381928, 0.02055073, 0.01483779, 0.02284741, 0.02251385, 0.00976694, 0.02864634,
            0.00802828, 0.03464088, 0.03557152, 0.01398894, 0.01831756, 0.0227171, 0.00736204,
            0.01866295,
        ];
        // Computed using scipy.stats.entropy
        let expected_entropy = 3.721606155686918;

        assert_abs_diff_eq!(a.entropy().unwrap(), expected_entropy, epsilon = 1e-6);
    }

    #[test]
    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
        let a = array![f64::NAN, 1.];
        let b = array![2., 1.];
        assert!(a.cross_entropy(&b)?.unwrap().is_nan());
        assert!(b.cross_entropy(&a)?.unwrap().is_nan());
        assert!(a.kl_divergence(&b)?.unwrap().is_nan());
        assert!(b.kl_divergence(&a)?.unwrap().is_nan());
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_same_n_dimension_but_different_n_elements() {
        let p = array![f64::NAN, 1.];
        let q = array![2., 1., 5.];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_different_shape_but_same_n_elements() {
        // p: 3x2, 6 elements
        let p = array![[f64::NAN, 1.], [6., 7.], [10., 20.]];
        // q: 2x3, 6 elements
        let q = array![[2., 1., 5.], [1., 1., 7.],];
        assert!(q.cross_entropy(&p).is_err());
        assert!(p.cross_entropy(&q).is_err());
        assert!(q.kl_divergence(&p).is_err());
        assert!(p.kl_divergence(&q).is_err());
    }

    #[test]
    fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
        let p: Array1<f64> = array![];
        let q: Array1<f64> = array![];
        assert!(p.cross_entropy(&q)?.is_none());
        assert!(p.kl_divergence(&q)?.is_none());
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
        let p = array![1.];
        let q = array![-1.];
        let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
        let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
        assert!(cross_entropy.is_nan());
        assert!(kl_divergence.is_nan());
        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_cross_entropy_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.cross_entropy(&q);
    }

    #[test]
    #[should_panic]
    fn test_kl_with_noisy_negative_qs() {
        let p = array![n64(1.)];
        let q = array![n64(-1.)];
        let _ = p.kl_divergence(&q);
    }

    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
        let p = array![0., 0.];
        let q = array![0., 0.5];
        assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
        assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
        Ok(())
    }

    #[test]
    fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
    ) -> Result<(), ShapeMismatch> {
        let p = array![0.5, 0.5];
        let mut q = array![0.5, 0.];
        assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
        assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
        Ok(())
    }

    #[test]
    fn test_cross_entropy() -> Result<(), ShapeMismatch> {
        // Arrays of probability values - normalized and positive.
        let p: Array1<f64> = array![
            0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
            0.0471258, 0.05594036, 0.01630048, 0.07085162, 0.05365855, 0.01959158, 0.05020174,
            0.03801479, 0.00092234, 0.08515856, 0.00580683, 0.0156542, 0.0860375, 0.0724246,
            0.00727477, 0.01004402, 0.01854399, 0.03504082,
        ];
        let q: Array1<f64> = array![
            0.06622616, 0.0478948, 0.03227816, 0.06460884, 0.05795974, 0.01377489, 0.05604812,
            0.01202684, 0.01647579, 0.03392697, 0.01656126, 0.00867528, 0.0625685, 0.07381292,
            0.05489067, 0.01385491, 0.03639174, 0.00511611, 0.05700415, 0.05183825, 0.06703064,
            0.01813342, 0.0007763, 0.0735472, 0.05857833,
        ];
        // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
        let expected_cross_entropy = 3.385347705020779;

        assert_abs_diff_eq!(
            p.cross_entropy(&q)?.unwrap(),
            expected_cross_entropy,
            epsilon = 1e-6
        );
        Ok(())
    }

    #[test]
    fn test_kl() -> Result<(), ShapeMismatch> {
        // Arrays of probability values - normalized and positive.
        let p: Array1<f64> = array![
            0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
            0.02213802, 0.02745017, 0.02163975, 0.0324602, 0.03622766, 0.00782343, 0.00222498,
            0.03028156, 0.02346124, 0.00071105, 0.00794496, 0.0127609, 0.02899124, 0.01281487,
            0.0230803, 0.01531864, 0.00518158, 0.02233383, 0.0220279, 0.03196097, 0.03710063,
            0.01817856, 0.03524661, 0.02902393, 0.00853364, 0.01255615, 0.03556958, 0.00400151,
            0.01335932, 0.01864965, 0.02371322, 0.02026543, 0.0035375, 0.01988341, 0.02621831,
            0.03564644, 0.01389121, 0.03151622, 0.03195532, 0.00717521, 0.03547256, 0.00371394,
            0.01108706,
        ];
        let q: Array1<f64> = array![
            0.02038386, 0.03143914, 0.02630206, 0.0171595, 0.0067072, 0.00911324, 0.02635717,
            0.01269113, 0.0302361, 0.02243133, 0.01902902, 0.01297185, 0.02118908, 0.03309548,
            0.01266687, 0.0184529, 0.01830936, 0.03430437, 0.02898924, 0.02238251, 0.0139771,
            0.01879774, 0.02396583, 0.03019978, 0.01421278, 0.02078981, 0.03542451, 0.02887438,
            0.01261783, 0.01014241, 0.03263407, 0.0095969, 0.01923903, 0.0051315, 0.00924686,
            0.00148845, 0.00341391, 0.01480373, 0.01920798, 0.03519871, 0.03315135, 0.02099325,
            0.03251755, 0.00337555, 0.03432165, 0.01763753, 0.02038337, 0.01923023, 0.01438769,
            0.02082707,
        ];
        // Computed using scipy.stats.entropy(p, q)
        let expected_kl = 0.3555862567800096;

        assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
        Ok(())
    }
}
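
The `test_cross_entropy` case above relies on the identity H(p,q) = S(p) + Dₖₗ(p,q), which is why its expected value is computed as `scipy.stats.entropy(p) + scipy.stats.entropy(p, q)`. A small sketch that checks the same identity numerically with the three methods from this diff; as in the earlier sketch, the `ndarray_stats` import path is an assumption rather than part of the commit.

```rust
use ndarray::array;
use ndarray_stats::EntropyExt; // assumed re-export, as above

fn main() {
    let p = array![0.25, 0.75];
    let q = array![0.5, 0.5];

    let s_p = p.entropy().unwrap();
    let kl_pq = p.kl_divergence(&q).unwrap().unwrap();
    let h_pq = p.cross_entropy(&q).unwrap().unwrap();

    // H(p,q) = S(p) + Dₖₗ(p,q):
    //   S(p)      = -(0.25 ln 0.25 + 0.75 ln 0.75) ≈ 0.5623
    //   Dₖₗ(p,q)  = -(0.25 ln 2 + 0.75 ln(2/3))    ≈ 0.1308
    //   H(p,q)    = -(0.25 + 0.75) ln 0.5          = ln 2 ≈ 0.6931
    assert!((h_pq - (s_p + kl_pq)).abs() < 1e-12);
}
```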
