 //! Information theory (e.g. entropy, KL divergence, etc.).
-use crate::errors::ShapeMismatch;
+use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
 use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
 use num_traits::Float;

     /// i=1
     /// ```
     ///
-    /// If the array is empty, `None` is returned.
+    /// If the array is empty, `Err(EmptyInput)` is returned.
     ///
     /// **Panics** if `ln` of any element in the array panics (which can occur for negative values for some `A`).
     ///

     ///
     /// [entropy]: https://en.wikipedia.org/wiki/Entropy_(information_theory)
     /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
-    fn entropy(&self) -> Option<A>
+    fn entropy(&self) -> Result<A, EmptyInput>
     where
         A: Float;

     /// i=1
     /// ```
     ///
-    /// If the arrays are empty, Ok(`None`) is returned.
-    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
+    /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
+    /// If the array shapes are not identical,
+    /// `Err(MultiInputError::ShapeMismatch)` is returned.
     ///
     /// **Panics** if, for a pair of elements *(pᵢ, qᵢ)* from *p* and *q*, computing
     /// *ln(qᵢ/pᵢ)* is a panic cause for `A`.

     ///
     /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
     /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
-    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
     where
         S2: Data<Elem = A>,
         A: Float;

     /// i=1
     /// ```
     ///
-    /// If the arrays are empty, Ok(`None`) is returned.
-    /// If the array shapes are not identical, `Err(ShapeMismatch)` is returned.
+    /// If the arrays are empty, `Err(MultiInputError::EmptyInput)` is returned.
+    /// If the array shapes are not identical,
+    /// `Err(MultiInputError::ShapeMismatch)` is returned.
     ///
     /// **Panics** if any element in *q* is negative and taking the logarithm of a negative number
     /// is a panic cause for `A`.
@@ -114,7 +116,7 @@ where
     /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory
     /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method
     /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression
-    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
     where
         S2: Data<Elem = A>,
         A: Float;
@@ -125,14 +127,14 @@ where
     S: Data<Elem = A>,
     D: Dimension,
 {
-    fn entropy(&self) -> Option<A>
+    fn entropy(&self) -> Result<A, EmptyInput>
     where
         A: Float,
     {
         if self.len() == 0 {
-            None
+            Err(EmptyInput)
         } else {
-            let entropy = self
+            let entropy = -self
                 .mapv(|x| {
                     if x == A::zero() {
                         A::zero()
@@ -141,23 +143,24 @@ where
                     }
                 })
                 .sum();
-            Some(-entropy)
+            Ok(entropy)
         }
     }
 
-    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn kl_divergence<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
     where
         A: Float,
         S2: Data<Elem = A>,
     {
         if self.len() == 0 {
-            return Ok(None);
+            return Err(MultiInputError::EmptyInput);
         }
         if self.shape() != q.shape() {
             return Err(ShapeMismatch {
                 first_shape: self.shape().to_vec(),
                 second_shape: q.shape().to_vec(),
-            });
+            }
+            .into());
         }
 
         let mut temp = Array::zeros(self.raw_dim());
@@ -174,22 +177,23 @@ where
                 }
             });
         let kl_divergence = -temp.sum();
-        Ok(Some(kl_divergence))
+        Ok(kl_divergence)
     }
 
-    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<Option<A>, ShapeMismatch>
+    fn cross_entropy<S2>(&self, q: &ArrayBase<S2, D>) -> Result<A, MultiInputError>
     where
         S2: Data<Elem = A>,
         A: Float,
     {
         if self.len() == 0 {
-            return Ok(None);
+            return Err(MultiInputError::EmptyInput);
         }
         if self.shape() != q.shape() {
             return Err(ShapeMismatch {
                 first_shape: self.shape().to_vec(),
                 second_shape: q.shape().to_vec(),
-            });
+            }
+            .into());
         }
 
         let mut temp = Array::zeros(self.raw_dim());
@@ -206,15 +210,15 @@ where
                 }
             });
         let cross_entropy = -temp.sum();
-        Ok(Some(cross_entropy))
+        Ok(cross_entropy)
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::EntropyExt;
     use approx::assert_abs_diff_eq;
-    use errors::ShapeMismatch;
+    use errors::{EmptyInput, MultiInputError};
     use ndarray::{array, Array1};
     use noisy_float::types::n64;
     use std::f64;
@@ -228,7 +232,7 @@ mod tests {
     #[test]
     fn test_entropy_with_empty_array_of_floats() {
         let a: Array1<f64> = array![];
-        assert!(a.entropy().is_none());
+        assert_eq!(a.entropy(), Err(EmptyInput));
     }
 
     #[test]
@@ -251,13 +255,13 @@ mod tests {
     }
 
     #[test]
-    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_nan_values() -> Result<(), MultiInputError> {
         let a = array![f64::NAN, 1.];
         let b = array![2., 1.];
-        assert!(a.cross_entropy(&b)?.unwrap().is_nan());
-        assert!(b.cross_entropy(&a)?.unwrap().is_nan());
-        assert!(a.kl_divergence(&b)?.unwrap().is_nan());
-        assert!(b.kl_divergence(&a)?.unwrap().is_nan());
+        assert!(a.cross_entropy(&b)?.is_nan());
+        assert!(b.cross_entropy(&a)?.is_nan());
+        assert!(a.kl_divergence(&b)?.is_nan());
+        assert!(b.kl_divergence(&a)?.is_nan());
         Ok(())
     }
 
@@ -284,20 +288,19 @@ mod tests {
     }
 
     #[test]
-    fn test_cross_entropy_and_kl_with_empty_array_of_floats() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_empty_array_of_floats() {
         let p: Array1<f64> = array![];
         let q: Array1<f64> = array![];
-        assert!(p.cross_entropy(&q)?.is_none());
-        assert!(p.kl_divergence(&q)?.is_none());
-        Ok(())
+        assert!(p.cross_entropy(&q).unwrap_err().is_empty_input());
+        assert!(p.kl_divergence(&q).unwrap_err().is_empty_input());
     }
 
     #[test]
-    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_negative_qs() -> Result<(), MultiInputError> {
         let p = array![1.];
         let q = array![-1.];
-        let cross_entropy: f64 = p.cross_entropy(&q)?.unwrap();
-        let kl_divergence: f64 = p.kl_divergence(&q)?.unwrap();
+        let cross_entropy: f64 = p.cross_entropy(&q)?;
+        let kl_divergence: f64 = p.kl_divergence(&q)?;
         assert!(cross_entropy.is_nan());
         assert!(kl_divergence.is_nan());
         Ok(())
@@ -320,26 +323,26 @@ mod tests {
     }
 
     #[test]
-    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy_and_kl_with_zeroes_p() -> Result<(), MultiInputError> {
         let p = array![0., 0.];
         let q = array![0., 0.5];
-        assert_eq!(p.cross_entropy(&q)?.unwrap(), 0.);
-        assert_eq!(p.kl_divergence(&q)?.unwrap(), 0.);
+        assert_eq!(p.cross_entropy(&q)?, 0.);
+        assert_eq!(p.kl_divergence(&q)?, 0.);
         Ok(())
     }
 
     #[test]
     fn test_cross_entropy_and_kl_with_zeroes_q_and_different_data_ownership(
-    ) -> Result<(), ShapeMismatch> {
+    ) -> Result<(), MultiInputError> {
         let p = array![0.5, 0.5];
         let mut q = array![0.5, 0.];
-        assert_eq!(p.cross_entropy(&q.view_mut())?.unwrap(), f64::INFINITY);
-        assert_eq!(p.kl_divergence(&q.view_mut())?.unwrap(), f64::INFINITY);
+        assert_eq!(p.cross_entropy(&q.view_mut())?, f64::INFINITY);
+        assert_eq!(p.kl_divergence(&q.view_mut())?, f64::INFINITY);
         Ok(())
     }
 
     #[test]
-    fn test_cross_entropy() -> Result<(), ShapeMismatch> {
+    fn test_cross_entropy() -> Result<(), MultiInputError> {
         // Arrays of probability values - normalized and positive.
         let p: Array1<f64> = array![
             0.05340169, 0.02508511, 0.03460454, 0.00352313, 0.07837615, 0.05859495, 0.05782189,
@@ -356,16 +359,12 @@ mod tests {
         // Computed using scipy.stats.entropy(p) + scipy.stats.entropy(p, q)
         let expected_cross_entropy = 3.385347705020779;
 
-        assert_abs_diff_eq!(
-            p.cross_entropy(&q)?.unwrap(),
-            expected_cross_entropy,
-            epsilon = 1e-6
-        );
+        assert_abs_diff_eq!(p.cross_entropy(&q)?, expected_cross_entropy, epsilon = 1e-6);
         Ok(())
     }
 
     #[test]
-    fn test_kl() -> Result<(), ShapeMismatch> {
+    fn test_kl() -> Result<(), MultiInputError> {
         // Arrays of probability values - normalized and positive.
         let p: Array1<f64> = array![
             0.00150472, 0.01388706, 0.03495376, 0.03264211, 0.03067355, 0.02183501, 0.00137516,
@@ -390,7 +389,7 @@ mod tests {
         // Computed using scipy.stats.entropy(p, q)
         let expected_kl = 0.3555862567800096;
 
-        assert_abs_diff_eq!(p.kl_divergence(&q)?.unwrap(), expected_kl, epsilon = 1e-6);
+        assert_abs_diff_eq!(p.kl_divergence(&q)?, expected_kl, epsilon = 1e-6);
         Ok(())
     }
 }
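
For readers following the API change, here is a minimal caller-side sketch (not part of the diff) of how the new signatures are consumed. It assumes the `EntropyExt` trait, the `is_empty_input` helper, and the error types exactly as they appear in the diff above, and that the trait is re-exported from the ndarray-stats crate root; the probability values are illustrative only.

use ndarray::array;
use ndarray_stats::EntropyExt;

fn main() {
    let p = array![0.25, 0.25, 0.25, 0.25];
    let q = array![0.5, 0.25, 0.125, 0.125];

    // `entropy` now returns `Result<A, EmptyInput>` instead of `Option<A>`:
    // an empty array is reported as an error rather than `None`.
    match p.entropy() {
        Ok(h) => println!("H(p) = {}", h),
        Err(err) => eprintln!("entropy failed: {:?}", err),
    }

    // `kl_divergence` and `cross_entropy` now return `Result<A, MultiInputError>`
    // instead of `Result<Option<A>, ShapeMismatch>`, so there is no inner `Option`
    // to unwrap; empty input and shape mismatch are both error variants.
    match p.kl_divergence(&q) {
        Ok(kl) => println!("KL(p || q) = {}", kl),
        Err(err) if err.is_empty_input() => eprintln!("inputs were empty"),
        Err(err) => eprintln!("shape mismatch: {:?}", err),
    }
}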