@@ -131,14 +131,15 @@ where
131
131
{
132
132
let observation_axis = Axis ( 1 ) ;
133
133
let n_observations = A :: from_usize ( self . len_of ( observation_axis) ) . unwrap ( ) ;
134
- let dof =
135
- if ddof >= n_observations {
136
- panic ! ( "`ddof` needs to be strictly smaller than the \
137
- number of observations provided for each \
138
- random variable!")
139
- } else {
140
- n_observations - ddof
141
- } ;
134
+ let dof = if ddof >= n_observations {
135
+ panic ! (
136
+ "`ddof` needs to be strictly smaller than the \
137
+ number of observations provided for each \
138
+ random variable!"
139
+ )
140
+ } else {
141
+ n_observations - ddof
142
+ } ;
142
143
let mean = self . mean_axis ( observation_axis) ;
143
144
let denoised = self - & mean. insert_axis ( observation_axis) ;
144
145
let covariance = denoised. dot ( & denoised. t ( ) ) ;
@@ -156,7 +157,9 @@ where
156
157
// observation per random variable (or no observations at all)
157
158
let ddof = -A :: one ( ) ;
158
159
let cov = self . cov ( ddof) ;
159
- let std = self . std_axis ( observation_axis, ddof) . insert_axis ( observation_axis) ;
160
+ let std = self
161
+ . std_axis ( observation_axis, ddof)
162
+ . insert_axis ( observation_axis) ;
160
163
let std_matrix = std. dot ( & std. t ( ) ) ;
161
164
// element-wise division
162
165
cov / std_matrix
@@ -167,10 +170,10 @@ where
167
170
mod cov_tests {
168
171
use super :: * ;
169
172
use ndarray:: array;
173
+ use ndarray_rand:: RandomExt ;
170
174
use quickcheck:: quickcheck;
171
175
use rand;
172
176
use rand:: distributions:: Uniform ;
173
- use ndarray_rand:: RandomExt ;
174
177
175
178
quickcheck ! {
176
179
fn constant_random_variables_have_zero_covariance_matrix( value: f64 ) -> bool {
@@ -200,10 +203,7 @@ mod cov_tests {
200
203
fn test_invalid_ddof ( ) {
201
204
let n_random_variables = 3 ;
202
205
let n_observations = 4 ;
203
- let a = Array :: random (
204
- ( n_random_variables, n_observations) ,
205
- Uniform :: new ( 0. , 10. )
206
- ) ;
206
+ let a = Array :: random ( ( n_random_variables, n_observations) , Uniform :: new ( 0. , 10. ) ) ;
207
207
let invalid_ddof = ( n_observations as f64 ) + rand:: random :: < f64 > ( ) . abs ( ) ;
208
208
a. cov ( invalid_ddof) ;
209
209
}
@@ -235,55 +235,46 @@ mod cov_tests {
235
235
#[ test]
236
236
fn test_covariance_for_random_array ( ) {
237
237
let a = array ! [
238
- [ 0.72009497 , 0.12568055 , 0.55705966 , 0.5959984 , 0.69471457 ] ,
239
- [ 0.56717131 , 0.47619486 , 0.21526298 , 0.88915366 , 0.91971245 ] ,
240
- [ 0.59044195 , 0.10720363 , 0.76573717 , 0.54693675 , 0.95923036 ] ,
241
- [ 0.24102952 , 0.131347 , 0.11118028 , 0.21451351 , 0.30515539 ] ,
242
- [ 0.26952473 , 0.93079841 , 0.8080893 , 0.42814155 , 0.24642258 ]
238
+ [ 0.72009497 , 0.12568055 , 0.55705966 , 0.5959984 , 0.69471457 ] ,
239
+ [ 0.56717131 , 0.47619486 , 0.21526298 , 0.88915366 , 0.91971245 ] ,
240
+ [ 0.59044195 , 0.10720363 , 0.76573717 , 0.54693675 , 0.95923036 ] ,
241
+ [ 0.24102952 , 0.131347 , 0.11118028 , 0.21451351 , 0.30515539 ] ,
242
+ [ 0.26952473 , 0.93079841 , 0.8080893 , 0.42814155 , 0.24642258 ]
243
243
] ;
244
244
let numpy_covariance = array ! [
245
- [ 0.05786248 , 0.02614063 , 0.06446215 , 0.01285105 , -0.06443992 ] ,
246
- [ 0.02614063 , 0.08733569 , 0.02436933 , 0.01977437 , -0.06715555 ] ,
247
- [ 0.06446215 , 0.02436933 , 0.10052129 , 0.01393589 , -0.06129912 ] ,
248
- [ 0.01285105 , 0.01977437 , 0.01393589 , 0.00638795 , -0.02355557 ] ,
249
- [ -0.06443992 , -0.06715555 , -0.06129912 , -0.02355557 , 0.09909855 ]
245
+ [ 0.05786248 , 0.02614063 , 0.06446215 , 0.01285105 , -0.06443992 ] ,
246
+ [ 0.02614063 , 0.08733569 , 0.02436933 , 0.01977437 , -0.06715555 ] ,
247
+ [ 0.06446215 , 0.02436933 , 0.10052129 , 0.01393589 , -0.06129912 ] ,
248
+ [ 0.01285105 , 0.01977437 , 0.01393589 , 0.00638795 , -0.02355557 ] ,
249
+ [
250
+ -0.06443992 ,
251
+ -0.06715555 ,
252
+ -0.06129912 ,
253
+ -0.02355557 ,
254
+ 0.09909855
255
+ ]
250
256
] ;
251
257
assert_eq ! ( a. ndim( ) , 2 ) ;
252
- assert ! (
253
- a. cov( 1. ) . all_close(
254
- & numpy_covariance,
255
- 1e-8
256
- )
257
- ) ;
258
+ assert ! ( a. cov( 1. ) . all_close( & numpy_covariance, 1e-8 ) ) ;
258
259
}
259
260
260
261
#[ test]
261
262
#[ should_panic]
262
263
// We lose precision, hence the failing assert
263
264
fn test_covariance_for_badly_conditioned_array ( ) {
264
- let a: Array2 < f64 > = array ! [
265
- [ 1e12 + 1. , 1e12 - 1. ] ,
266
- [ 1e-6 + 1e-12 , 1e-6 - 1e-12 ] ,
267
- ] ;
268
- let expected_covariance = array ! [
269
- [ 2. , 2e-12 ] , [ 2e-12 , 2e-24 ]
270
- ] ;
271
- assert ! (
272
- a. cov( 1. ) . all_close(
273
- & expected_covariance,
274
- 1e-24
275
- )
276
- ) ;
265
+ let a: Array2 < f64 > = array ! [ [ 1e12 + 1. , 1e12 - 1. ] , [ 1e-6 + 1e-12 , 1e-6 - 1e-12 ] , ] ;
266
+ let expected_covariance = array ! [ [ 2. , 2e-12 ] , [ 2e-12 , 2e-24 ] ] ;
267
+ assert ! ( a. cov( 1. ) . all_close( & expected_covariance, 1e-24 ) ) ;
277
268
}
278
269
}
279
270
280
271
#[ cfg( test) ]
281
272
mod pearson_correlation_tests {
282
273
use super :: * ;
283
274
use ndarray:: array;
275
+ use ndarray_rand:: RandomExt ;
284
276
use quickcheck:: quickcheck;
285
277
use rand:: distributions:: Uniform ;
286
- use ndarray_rand:: RandomExt ;
287
278
288
279
quickcheck ! {
289
280
fn output_matrix_is_symmetric( bound: f64 ) -> bool {
@@ -337,19 +328,14 @@ mod pearson_correlation_tests {
337
328
[ 0.26979716 , 0.20887228 , 0.95454999 , 0.96290785 ]
338
329
] ;
339
330
let numpy_corrcoeff = array ! [
340
- [ 1. , 0.38089376 , 0.08122504 , -0.59931623 , 0.1365648 ] ,
341
- [ 0.38089376 , 1. , 0.80918429 , -0.52615195 , 0.38954398 ] ,
342
- [ 0.08122504 , 0.80918429 , 1. , 0.07134906 , -0.17324776 ] ,
343
- [ -0.59931623 , -0.52615195 , 0.07134906 , 1. , -0.8743213 ] ,
344
- [ 0.1365648 , 0.38954398 , -0.17324776 , -0.8743213 , 1. ]
331
+ [ 1. , 0.38089376 , 0.08122504 , -0.59931623 , 0.1365648 ] ,
332
+ [ 0.38089376 , 1. , 0.80918429 , -0.52615195 , 0.38954398 ] ,
333
+ [ 0.08122504 , 0.80918429 , 1. , 0.07134906 , -0.17324776 ] ,
334
+ [ -0.59931623 , -0.52615195 , 0.07134906 , 1. , -0.8743213 ] ,
335
+ [ 0.1365648 , 0.38954398 , -0.17324776 , -0.8743213 , 1. ]
345
336
] ;
346
337
assert_eq ! ( a. ndim( ) , 2 ) ;
347
- assert ! (
348
- a. pearson_correlation( ) . all_close(
349
- & numpy_corrcoeff,
350
- 1e-7
351
- )
352
- ) ;
338
+ assert ! ( a. pearson_correlation( ) . all_close( & numpy_corrcoeff, 1e-7 ) ) ;
353
339
}
354
340
355
341
}
0 commit comments