Skip to content

Commit a780910

Browse files
author
Lorenz Schmidt
committed
Use dense eigensolver when eigenvalues exceed 1/5th of problem size
1 parent a0bd205 commit a780910

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

src/lobpcg/eig.rs

+21
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
//!
33
use super::random;
44
use crate::{
5+
eigh::{EigSort, Eigh},
56
lobpcg::{lobpcg, Lobpcg, LobpcgResult},
67
Order, Result,
78
};
@@ -148,6 +149,26 @@ impl<A: NdFloat + Sum, R: Rng> TruncatedEig<A, R> {
148149
let x: Array2<f64> = random((self.problem.len_of(Axis(0)), num), &mut self.rng);
149150
let x = x.mapv(|x| NumCast::from(x).unwrap());
150151

152+
// use dense eigenproblem solver if more than 1/5 eigenvalues requested
153+
if num * 5 > self.problem.nrows() {
154+
let (eigvals, eigvecs) = self
155+
.problem
156+
.eigh()
157+
.map_err(|e| (e, None))?
158+
.sort_eig(self.order);
159+
160+
let (eigvals, eigvecs) = (
161+
eigvals.slice_move(s![..num]),
162+
eigvecs.slice_move(s![.., ..num]),
163+
);
164+
165+
return Ok(Lobpcg {
166+
eigvals,
167+
eigvecs,
168+
rnorm: Vec::new(),
169+
});
170+
}
171+
151172
if let Some(ref preconditioner) = self.preconditioner {
152173
lobpcg(
153174
|y| self.problem.dot(&y),

src/lobpcg/svd.rs

+31-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
///!
33
///! This module computes the k largest/smallest singular values/vectors for a dense matrix.
44
use crate::{
5+
eigh::{EigSort, Eigh},
56
lobpcg::{lobpcg, random, Lobpcg},
67
Order, Result,
78
};
@@ -35,14 +36,14 @@ impl<A: NdFloat + 'static + MagnitudeCorrection> TruncatedSvdResult<A> {
3536
a.sort_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap().reverse());
3637

3738
// calculate cut-off magnitude (borrowed from scipy)
38-
let cutoff = A::epsilon() * // float precision
39-
A::correction() * // correction term (see trait below)
40-
*a[0].1; // max eigenvalue
39+
//let cutoff = A::epsilon() * // float precision
40+
// A::correction() * // correction term (see trait below)
41+
// *a[0].1; // max eigenvalue
4142

4243
// filter low singular values away
4344
let (values, indices): (Vec<A>, Vec<usize>) = a
4445
.into_iter()
45-
.filter(|(_, x)| *x > &cutoff)
46+
//.filter(|(_, x)| *x > &cutoff)
4647
.map(|(a, b)| (b.sqrt(), a))
4748
.unzip();
4849

@@ -188,6 +189,31 @@ impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> {
188189
}
189190

190191
let (n, m) = (self.problem.nrows(), self.problem.ncols());
192+
let ngm = n > m;
193+
194+
// use dense eigenproblem solver if more than 1/5 eigenvalues requested
195+
if num * 5 > n.min(m) {
196+
let problem = if ngm {
197+
self.problem.t().dot(&self.problem)
198+
} else {
199+
self.problem.dot(&self.problem.t())
200+
};
201+
202+
let (eigvals, eigvecs) = problem.eigh()?.sort_eig(self.order);
203+
204+
let (eigvals, eigvecs) = (
205+
eigvals.slice_move(s![..num]),
206+
eigvecs.slice_move(s![..num, ..]),
207+
);
208+
209+
return Ok(TruncatedSvdResult {
210+
eigvals,
211+
eigvecs,
212+
problem: self.problem,
213+
order: self.order,
214+
ngm,
215+
});
216+
}
191217

192218
// generate initial matrix
193219
let x: Array2<f32> = random((usize::min(n, m), num), &mut self.rng);
@@ -234,7 +260,7 @@ impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> {
234260
eigvals,
235261
eigvecs,
236262
order: self.order,
237-
ngm: n > m,
263+
ngm,
238264
}),
239265
Err((err, None)) => Err(err),
240266
}

0 commit comments

Comments
 (0)