Skip to content

Commit 659ecc1

Browse files
committed
Add proptests for lobpcg SVD and found problematic case
1 parent 34c3b6f commit 659ecc1

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

tests/lobpcg.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use rand_xoshiro::Xoshiro256Plus;
77

88
use ndarray_linalg_rs::eigh::*;
99
use ndarray_linalg_rs::lobpcg::*;
10+
use ndarray_linalg_rs::svd::*;
1011

1112
mod common;
1213

@@ -105,3 +106,83 @@ fn problematic_eig_matrix() {
105106
];
106107
run_lobpcg_eig_test(arr, 3, Order::Largest);
107108
}
109+
110+
fn run_lobpcg_svd_test(arr: Array2<f64>, ordering: Order) {
111+
let (_, s, _) = arr
112+
.svd(false, false)
113+
.unwrap()
114+
.sort_svd(ordering == Order::Largest);
115+
let (u, ts, vt) =
116+
TruncatedSvd::new_with_rng(arr.clone(), ordering, Xoshiro256Plus::seed_from_u64(42))
117+
.precision(1e-3)
118+
.maxiter(10)
119+
.decompose(arr.ncols())
120+
.unwrap()
121+
.values_vectors();
122+
123+
assert_abs_diff_eq!(s, ts, epsilon = 1e-5);
124+
assert_abs_diff_eq!(u.dot(&Array2::from_diag(&ts)).dot(&vt), arr, epsilon = 1e-5);
125+
}
126+
127+
proptest! {
128+
#![proptest_config(ProptestConfig::with_cases(256))]
129+
#[test]
130+
fn lobpcg_svd_test(arr in common::hpd_arr(), ordering in generate_order()) {
131+
run_lobpcg_svd_test(arr, ordering);
132+
}
133+
}
134+
135+
#[test]
136+
fn problematic_svd_matrix() {
137+
let arr = array![
138+
[
139+
18703.111084031745,
140+
5398.592802934647,
141+
-2798.4524863262,
142+
3142.0598040221316,
143+
10654.718971270437,
144+
2928.7057369452755
145+
],
146+
[
147+
5398.592802934647,
148+
35574.82803149514,
149+
-29613.112978401838,
150+
-12632.782177317926,
151+
-16546.07166801079,
152+
-13607.176833471722
153+
],
154+
[
155+
-2798.4524863262,
156+
-29613.112978401838,
157+
29022.408309489085,
158+
8718.392706824303,
159+
12376.7396224986,
160+
17995.47911319261
161+
],
162+
[
163+
3142.0598040221316,
164+
-12632.782177317926,
165+
8718.392706824303,
166+
22884.5878990548,
167+
-598.390397885349,
168+
-8629.726579767677
169+
],
170+
[
171+
10654.718971270437,
172+
-16546.07166801079,
173+
12376.7396224986,
174+
-598.390397885349,
175+
27757.334483403938,
176+
15535.051898142627
177+
],
178+
[
179+
2928.7057369452755,
180+
-13607.176833471722,
181+
17995.47911319261,
182+
-8629.726579767677,
183+
15535.051898142627,
184+
31748.677025662313
185+
]
186+
];
187+
run_lobpcg_svd_test(arr, Order::Largest);
188+
}

0 commit comments

Comments
 (0)