Skip to content

Commit 44821c9

Browse files
committed
Bug fix of ldvt
1 parent 9ca911e commit 44821c9

File tree

3 files changed

+18
-13
lines changed

3 files changed

+18
-13
lines changed

src/lapack/svd.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@ macro_rules! impl_svd {
4747
let (jvt, ldvt, mut vt) = if calc_vt {
4848
(FlagSVD::All, n, vec![Self::zero(); (n * n) as usize])
4949
} else {
50-
(FlagSVD::No, 1, Vec::new())
50+
(FlagSVD::No, n, Vec::new())
5151
};
5252
let mut s = vec![Self::Real::zero(); k as usize];
5353
let mut superb = vec![Self::Real::zero(); (k - 1) as usize];
54+
dbg!(ldvt);
5455
let info = $gesvd(
5556
l.lapacke_layout(),
5657
ju as u8,
@@ -70,8 +71,8 @@ macro_rules! impl_svd {
7071
info,
7172
SVDOutput {
7273
s: s,
73-
u: if ldu > 0 { Some(u) } else { None },
74-
vt: if ldvt > 0 { Some(vt) } else { None },
74+
u: if calc_u { Some(u) } else { None },
75+
vt: if calc_vt { Some(vt) } else { None },
7576
},
7677
)
7778
}

src/svd.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,12 @@ where
7575
let l = self.layout()?;
7676
let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? };
7777
let (n, m) = l.size();
78-
let u = svd_res.u.map(|u| into_matrix(l.resized(n, n), u).unwrap());
79-
let vt = svd_res.vt.map(|vt| into_matrix(l.resized(m, m), vt).unwrap());
78+
let u = svd_res
79+
.u
80+
.map(|u| into_matrix(l.resized(n, n), u).expect("Size of U mismatches"));
81+
let vt = svd_res
82+
.vt
83+
.map(|vt| into_matrix(l.resized(m, m), vt).expect("Size of VT mismatches"));
8084
let s = ArrayBase::from_vec(svd_res.s);
8185
Ok((u, s, vt))
8286
}

tests/svd.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ fn test(a: &Array2<f64>) {
1919
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
2020
}
2121

22-
fn test_u(a: &Array2<f64>) {
22+
fn test_no_vt(a: &Array2<f64>) {
2323
let (n, _m) = a.dim();
2424
println!("a = \n{:?}", a);
2525
let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap();
@@ -30,7 +30,7 @@ fn test_u(a: &Array2<f64>) {
3030
assert_eq!(u.dim().1, n);
3131
}
3232

33-
fn test_vt(a: &Array2<f64>) {
33+
fn test_no_u(a: &Array2<f64>) {
3434
let (_n, m) = a.dim();
3535
println!("a = \n{:?}", a);
3636
let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap();
@@ -60,11 +60,11 @@ macro_rules! test_svd_impl {
6060
}
6161

6262
test_svd_impl!(test, 3, 3);
63-
test_svd_impl!(test_u, 3, 3);
64-
test_svd_impl!(test_vt, 3, 3);
63+
test_svd_impl!(test_no_vt, 3, 3);
64+
test_svd_impl!(test_no_u, 3, 3);
6565
test_svd_impl!(test, 4, 3);
66-
test_svd_impl!(test_u, 4, 3);
67-
test_svd_impl!(test_vt, 4, 3);
66+
test_svd_impl!(test_no_vt, 4, 3);
67+
test_svd_impl!(test_no_u, 4, 3);
6868
test_svd_impl!(test, 3, 4);
69-
test_svd_impl!(test_u, 3, 4);
70-
test_svd_impl!(test_vt, 3, 4);
69+
test_svd_impl!(test_no_vt, 3, 4);
70+
test_svd_impl!(test_no_u, 3, 4);

0 commit comments

Comments
 (0)