|
2 | 2 | ///!
|
3 | 3 | ///! This module computes the k largest/smallest singular values/vectors for a dense matrix.
|
4 | 4 | use crate::{
|
| 5 | + eigh::{EigSort, Eigh}, |
5 | 6 | lobpcg::{lobpcg, random, Lobpcg},
|
6 | 7 | Order, Result,
|
7 | 8 | };
|
@@ -35,14 +36,14 @@ impl<A: NdFloat + 'static + MagnitudeCorrection> TruncatedSvdResult<A> {
|
35 | 36 | a.sort_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap().reverse());
|
36 | 37 |
|
37 | 38 | // calculate cut-off magnitude (borrowed from scipy)
|
38 |
| - let cutoff = A::epsilon() * // float precision |
39 |
| - A::correction() * // correction term (see trait below) |
40 |
| - *a[0].1; // max eigenvalue |
| 39 | + //let cutoff = A::epsilon() * // float precision |
| 40 | + // A::correction() * // correction term (see trait below) |
| 41 | + // *a[0].1; // max eigenvalue |
41 | 42 |
|
42 | 43 | // filter low singular values away
|
43 | 44 | let (values, indices): (Vec<A>, Vec<usize>) = a
|
44 | 45 | .into_iter()
|
45 |
| - .filter(|(_, x)| *x > &cutoff) |
| 46 | + //.filter(|(_, x)| *x > &cutoff) |
46 | 47 | .map(|(a, b)| (b.sqrt(), a))
|
47 | 48 | .unzip();
|
48 | 49 |
|
@@ -188,6 +189,31 @@ impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> {
|
188 | 189 | }
|
189 | 190 |
|
190 | 191 | let (n, m) = (self.problem.nrows(), self.problem.ncols());
|
| 192 | + let ngm = n > m; |
| 193 | + |
| 194 | + // use dense eigenproblem solver if more than 1/5 eigenvalues requested |
| 195 | + if num * 5 > n.min(m) { |
| 196 | + let problem = if ngm { |
| 197 | + self.problem.t().dot(&self.problem) |
| 198 | + } else { |
| 199 | + self.problem.dot(&self.problem.t()) |
| 200 | + }; |
| 201 | + |
| 202 | + let (eigvals, eigvecs) = problem.eigh()?.sort_eig(self.order); |
| 203 | + |
| 204 | + let (eigvals, eigvecs) = ( |
| 205 | + eigvals.slice_move(s![..num]), |
| 206 | + eigvecs.slice_move(s![..num, ..]), |
| 207 | + ); |
| 208 | + |
| 209 | + return Ok(TruncatedSvdResult { |
| 210 | + eigvals, |
| 211 | + eigvecs, |
| 212 | + problem: self.problem, |
| 213 | + order: self.order, |
| 214 | + ngm, |
| 215 | + }); |
| 216 | + } |
191 | 217 |
|
192 | 218 | // generate initial matrix
|
193 | 219 | let x: Array2<f32> = random((usize::min(n, m), num), &mut self.rng);
|
@@ -234,7 +260,7 @@ impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> {
|
234 | 260 | eigvals,
|
235 | 261 | eigvecs,
|
236 | 262 | order: self.order,
|
237 |
| - ngm: n > m, |
| 263 | + ngm, |
238 | 264 | }),
|
239 | 265 | Err((err, None)) => Err(err),
|
240 | 266 | }
|
|
0 commit comments