|  | 
| 2 | 2 | ///! | 
| 3 | 3 | ///! This module computes the k largest/smallest singular values/vectors for a dense matrix. | 
| 4 | 4 | use crate::{ | 
|  | 5 | +    eigh::{EigSort, Eigh}, | 
| 5 | 6 |     lobpcg::{lobpcg, random, Lobpcg}, | 
| 6 | 7 |     Order, Result, | 
| 7 | 8 | }; | 
| @@ -35,14 +36,14 @@ impl<A: NdFloat + 'static + MagnitudeCorrection> TruncatedSvdResult<A> { | 
| 35 | 36 |             a.sort_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap().reverse()); | 
| 36 | 37 | 
 | 
| 37 | 38 |             // calculate cut-off magnitude (borrowed from scipy) | 
| 38 |  | -            let cutoff = A::epsilon() * // float precision | 
| 39 |  | -                         A::correction() * // correction term (see trait below) | 
| 40 |  | -                         *a[0].1; // max eigenvalue | 
|  | 39 | +            //let cutoff = A::epsilon() * // float precision | 
|  | 40 | +            //             A::correction() * // correction term (see trait below) | 
|  | 41 | +            //             *a[0].1; // max eigenvalue | 
| 41 | 42 | 
 | 
| 42 | 43 |             // filter low singular values away | 
| 43 | 44 |             let (values, indices): (Vec<A>, Vec<usize>) = a | 
| 44 | 45 |                 .into_iter() | 
| 45 |  | -                .filter(|(_, x)| *x > &cutoff) | 
|  | 46 | +                //.filter(|(_, x)| *x > &cutoff) | 
| 46 | 47 |                 .map(|(a, b)| (b.sqrt(), a)) | 
| 47 | 48 |                 .unzip(); | 
| 48 | 49 | 
 | 
| @@ -188,6 +189,31 @@ impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> { | 
| 188 | 189 |         } | 
| 189 | 190 | 
 | 
| 190 | 191 |         let (n, m) = (self.problem.nrows(), self.problem.ncols()); | 
|  | 192 | +        let ngm = n > m; | 
|  | 193 | + | 
|  | 194 | +        // use dense eigenproblem solver if more than 1/5 eigenvalues requested | 
|  | 195 | +        if num * 5 > n.min(m) { | 
|  | 196 | +            let problem = if ngm { | 
|  | 197 | +                self.problem.t().dot(&self.problem) | 
|  | 198 | +            } else { | 
|  | 199 | +                self.problem.dot(&self.problem.t()) | 
|  | 200 | +            }; | 
|  | 201 | + | 
|  | 202 | +            let (eigvals, eigvecs) = problem.eigh()?.sort_eig(self.order); | 
|  | 203 | + | 
|  | 204 | +            let (eigvals, eigvecs) = ( | 
|  | 205 | +                eigvals.slice_move(s![..num]), | 
|  | 206 | +                eigvecs.slice_move(s![..num, ..]), | 
|  | 207 | +            ); | 
|  | 208 | + | 
|  | 209 | +            return Ok(TruncatedSvdResult { | 
|  | 210 | +                eigvals, | 
|  | 211 | +                eigvecs, | 
|  | 212 | +                problem: self.problem, | 
|  | 213 | +                order: self.order, | 
|  | 214 | +                ngm, | 
|  | 215 | +            }); | 
|  | 216 | +        } | 
| 191 | 217 | 
 | 
| 192 | 218 |         // generate initial matrix | 
| 193 | 219 |         let x: Array2<f32> = random((usize::min(n, m), num), &mut self.rng); | 
| @@ -234,7 +260,7 @@ impl<A: NdFloat + Sum, R: Rng> TruncatedSvd<A, R> { | 
| 234 | 260 |                 eigvals, | 
| 235 | 261 |                 eigvecs, | 
| 236 | 262 |                 order: self.order, | 
| 237 |  | -                ngm: n > m, | 
|  | 263 | +                ngm, | 
| 238 | 264 |             }), | 
| 239 | 265 |             Err((err, None)) => Err(err), | 
| 240 | 266 |         } | 
|  | 
0 commit comments