Skip to content

Commit 34c3b6f

Browse files
committed
Add proptest for lobpcg eig and found problematic test case
1 parent f8b1395 commit 34c3b6f

File tree

7 files changed

+77
-38
lines changed

7 files changed

+77
-38
lines changed

Cargo.toml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,3 @@ rand_xoshiro = { version = "0.6" }
2525
[features]
2626
default = ["iterative"]
2727
iterative = ["rand"]
28-
29-
[[test]]
30-
name = "lobpcg"
31-
path = "tests/lobpcg.rs"
32-
required-features = ["iterative"]

src/lobpcg/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ where
4545
/// as it could be of value. If there is no result at all, then the second field is `None`.
4646
/// This happens if the algorithm fails in an early stage, for example if the matrix `A` is not SPD
4747
pub type LobpcgResult<A> = std::result::Result<Lobpcg<A>, (LinalgError, Option<Lobpcg<A>>)>;
48+
49+
#[derive(Debug, Clone, PartialEq)]
4850
pub struct Lobpcg<A> {
49-
eigvals: Array1<A>,
50-
eigvecs: Array2<A>,
51+
pub eigvals: Array1<A>,
52+
pub eigvecs: Array2<A>,
5153
rnorm: Vec<A>,
5254
}

src/lobpcg/svd.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,7 @@ mod tests {
268268
use super::TruncatedSvd;
269269

270270
use approx::assert_abs_diff_eq;
271-
use ndarray::{arr1, arr2, s, Array1, Array2, NdFloat};
272-
use ndarray_rand::{rand_distr::StandardNormal, RandomExt};
271+
use ndarray::{arr1, arr2, Array2, NdFloat};
273272
use rand::distributions::{Distribution, Standard};
274273
use rand::SeedableRng;
275274
use rand_xoshiro::Xoshiro256Plus;

tests/cholesky.rs

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,6 @@ use ndarray_linalg_rs::{cholesky::*, triangular::*};
66

77
mod common;
88

9-
prop_compose! {
10-
fn hpd_arr()
11-
(arr in common::square_arr()) -> Array2<f64> {
12-
let dim = arr.nrows();
13-
let mut mul = arr.t().dot(&arr);
14-
for i in 0..dim {
15-
mul[(i, i)] += 1.0;
16-
}
17-
mul
18-
}
19-
}
20-
219
fn run_cholesky_test(orig: Array2<f64>) {
2210
let chol = orig.cholesky().unwrap();
2311
assert_abs_diff_eq!(chol.dot(&chol.t()), orig, epsilon = 1e-7);
@@ -69,17 +57,17 @@ fn run_invc_test(a: Array2<f64>) {
6957
proptest! {
7058
#![proptest_config(ProptestConfig::with_cases(1000))]
7159
#[test]
72-
fn cholesky_test(arr in hpd_arr()) {
60+
fn cholesky_test(arr in common::hpd_arr()) {
7361
run_cholesky_test(arr)
7462
}
7563

7664
#[test]
77-
fn solvec_test((a, x) in common::system_of_arr(hpd_arr())) {
65+
fn solvec_test((a, x) in common::system_of_arr(common::hpd_arr())) {
7866
run_solvec_test(a, x)
7967
}
8068

8169
#[test]
82-
fn invc_test(arr in hpd_arr()) {
70+
fn invc_test(arr in common::hpd_arr()) {
8371
run_invc_test(arr)
8472
}
8573
}

tests/common.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
use std::ops::RangeInclusive;
44

5+
use approx::assert_abs_diff_eq;
56
use ndarray::prelude::*;
67
use proptest::prelude::*;
78
use proptest_derive::Arbitrary;
@@ -49,6 +50,18 @@ prop_compose! {
4950
}
5051
}
5152

53+
prop_compose! {
54+
pub fn hpd_arr()
55+
(arr in square_arr()) -> Array2<f64> {
56+
let dim = arr.nrows();
57+
let mut mul = arr.t().dot(&arr);
58+
for i in 0..dim {
59+
mul[(i, i)] += 1.0;
60+
}
61+
mul
62+
}
63+
}
64+
5265
prop_compose! {
5366
pub fn rect_arr()(rows in DIM_RANGE, cols in DIM_RANGE)
5467
(arr in matrix(rows, cols)) -> Array2<f64> {
@@ -93,3 +106,12 @@ pub fn system_of_arr(
93106
)
94107
})
95108
}
109+
110+
pub fn check_eigh(arr: &Array2<f64>, vals: &Array1<f64>, vecs: &Array2<f64>) {
111+
// Original array multiplied with eigenvec should equal eigenval times eigenvec
112+
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
113+
let av = arr.dot(&v);
114+
let ev = v.mapv(|x| vals[i] * x);
115+
assert_abs_diff_eq!(av, ev, epsilon = 1e-5);
116+
}
117+
}

tests/eigh.rs

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,6 @@ use ndarray_linalg_rs::eigh::*;
66

77
mod common;
88

9-
fn check_eigh(arr: &Array2<f64>, vals: &Array1<f64>, vecs: &Array2<f64>) {
10-
// Original array multiplied with eigenvec should equal eigenval times eigenvec
11-
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
12-
let av = arr.dot(&v);
13-
let ev = v.mapv(|x| vals[i] * x);
14-
assert_abs_diff_eq!(av, ev, epsilon = 1e-5);
15-
}
16-
}
17-
189
fn run_eigh_test(arr: Array2<f64>) {
1910
let n = arr.nrows();
2011
let d = arr.eigh().unwrap();
@@ -23,7 +14,7 @@ fn run_eigh_test(arr: Array2<f64>) {
2314
// Eigenvecs should be orthogonal
2415
let s = vecs.t().dot(&vecs);
2516
assert_abs_diff_eq!(s, Array2::eye(n), epsilon = 1e-5);
26-
check_eigh(&arr, &vals, &vecs);
17+
common::check_eigh(&arr, &vals, &vecs);
2718

2819
let (evals, evecs) = arr.clone().eigh_into().unwrap();
2920
assert_abs_diff_eq!(evals, vals);
@@ -33,12 +24,12 @@ fn run_eigh_test(arr: Array2<f64>) {
3324

3425
// Check if ascending eigen is actually sorted and valid
3526
let (vals, vecs) = d.clone().sort_eig_asc();
36-
check_eigh(&arr, &vals, &vecs);
27+
common::check_eigh(&arr, &vals, &vecs);
3728
assert!(vals.windows(2).into_iter().all(|w| w[0] <= w[1]));
3829

3930
// Check if descending eigen is actually sorted and valid
4031
let (vals, vecs) = d.sort_eig_desc();
41-
check_eigh(&arr, &vals, &vecs);
32+
common::check_eigh(&arr, &vals, &vecs);
4233
assert!(vals.windows(2).into_iter().all(|w| w[0] >= w[1]));
4334
}
4435

tests/lobpcg.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ use approx::assert_abs_diff_eq;
22
use ndarray::prelude::*;
33
use ndarray_rand::{rand_distr::StandardNormal, RandomExt};
44
use proptest::prelude::*;
5-
6-
use ndarray_linalg_rs::lobpcg::*;
75
use rand::SeedableRng;
86
use rand_xoshiro::Xoshiro256Plus;
97

8+
use ndarray_linalg_rs::eigh::*;
9+
use ndarray_linalg_rs::lobpcg::*;
10+
1011
mod common;
1112

1213
/// Eigenvalue structure in high dimensions
@@ -63,3 +64,44 @@ fn test_marchenko_pastur() {
6364
assert_abs_diff_eq!(mp_law, empirical, epsilon = 0.05);
6465
}
6566
}
67+
68+
fn run_lobpcg_eig_test(arr: Array2<f64>, num: usize, ordering: Order) {
69+
let (eigvals, _) = arr.eigh().unwrap().sort_eig(ordering == Order::Largest);
70+
let res = TruncatedEig::new_with_rng(arr.clone(), ordering, Xoshiro256Plus::seed_from_u64(42))
71+
.precision(1e-3)
72+
.decompose(num)
73+
.unwrap_or_else(|e| e.1.unwrap());
74+
75+
assert_abs_diff_eq!(eigvals.slice(s![..num]), res.eigvals, epsilon = 1e-5);
76+
common::check_eigh(&arr, &res.eigvals, &res.eigvecs);
77+
}
78+
79+
fn generate_order() -> impl Strategy<Value = Order> {
80+
prop_oneof![Just(Order::Largest), Just(Order::Smallest)]
81+
}
82+
83+
prop_compose! {
84+
pub fn hpd_arr_num()(arr in common::hpd_arr())
85+
(num in (1..arr.ncols()), arr in Just(arr)) -> (Array2<f64>, usize) {
86+
(arr, num)
87+
}
88+
}
89+
90+
proptest! {
91+
#![proptest_config(ProptestConfig::with_cases(1000))]
92+
#[test]
93+
fn lobpcg_eig_test((arr, num) in hpd_arr_num(), ordering in generate_order()) {
94+
run_lobpcg_eig_test(arr, num, ordering);
95+
}
96+
}
97+
98+
#[test]
99+
fn problematic_eig_matrix() {
100+
let arr = array![
101+
[1.0, 0.0, 0.0, 0.0],
102+
[0.0, 1.0, 0.0, 0.0],
103+
[0.0, 0.0, 7854.796948298437, 2495.5155877621937],
104+
[0.0, 0.0, 2495.5155877621937, 5995.696530257453]
105+
];
106+
run_lobpcg_eig_test(arr, 3, Order::Largest);
107+
}

0 commit comments

Comments
 (0)