Skip to content

Commit afe38d4

Browse files
authored
Merge pull request #163 from rust-ndarray/bugfix_svd_162
Fix SVD without u or vt
2 parents cf68301 + 4d06adb commit afe38d4

File tree

4 files changed

+66
-33
lines changed

4 files changed

+66
-33
lines changed

Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,6 @@ version = "0.6"
4646
default-features = false
4747
features = ["static"]
4848
optional = true
49+
50+
[dev-dependencies]
51+
paste = "0.1"

src/lapack/svd.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,16 @@ macro_rules! impl_svd {
4242
let (ju, ldu, mut u) = if calc_u {
4343
(FlagSVD::All, m, vec![Self::zero(); (m * m) as usize])
4444
} else {
45-
(FlagSVD::No, 0, Vec::new())
45+
(FlagSVD::No, 1, Vec::new())
4646
};
4747
let (jvt, ldvt, mut vt) = if calc_vt {
4848
(FlagSVD::All, n, vec![Self::zero(); (n * n) as usize])
4949
} else {
50-
(FlagSVD::No, 0, Vec::new())
50+
(FlagSVD::No, n, Vec::new())
5151
};
5252
let mut s = vec![Self::Real::zero(); k as usize];
5353
let mut superb = vec![Self::Real::zero(); (k - 1) as usize];
54+
dbg!(ldvt);
5455
let info = $gesvd(
5556
l.lapacke_layout(),
5657
ju as u8,
@@ -70,8 +71,8 @@ macro_rules! impl_svd {
7071
info,
7172
SVDOutput {
7273
s: s,
73-
u: if ldu > 0 { Some(u) } else { None },
74-
vt: if ldvt > 0 { Some(vt) } else { None },
74+
u: if calc_u { Some(u) } else { None },
75+
vt: if calc_vt { Some(vt) } else { None },
7576
},
7677
)
7778
}

src/svd.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,12 @@ where
7575
let l = self.layout()?;
7676
let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? };
7777
let (n, m) = l.size();
78-
let u = svd_res.u.map(|u| into_matrix(l.resized(n, n), u).unwrap());
79-
let vt = svd_res.vt.map(|vt| into_matrix(l.resized(m, m), vt).unwrap());
78+
let u = svd_res
79+
.u
80+
.map(|u| into_matrix(l.resized(n, n), u).expect("Size of U mismatches"));
81+
let vt = svd_res
82+
.vt
83+
.map(|vt| into_matrix(l.resized(m, m), vt).expect("Size of VT mismatches"));
8084
let s = ArrayBase::from_vec(svd_res.s);
8185
Ok((u, s, vt))
8286
}

tests/svd.rs

+52-27
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ use ndarray::*;
22
use ndarray_linalg::*;
33
use std::cmp::min;
44

5-
fn test(a: &Array2<f64>, n: usize, m: usize) {
5+
fn test(a: &Array2<f64>) {
6+
let (n, m) = a.dim();
67
let answer = a.clone();
78
println!("a = \n{:?}", a);
89
let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap();
@@ -18,38 +19,62 @@ fn test(a: &Array2<f64>, n: usize, m: usize) {
1819
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
1920
}
2021

21-
#[test]
22-
fn svd_square() {
23-
let a = random((3, 3));
24-
test(&a, 3, 3);
22+
fn test_no_vt(a: &Array2<f64>) {
23+
let (n, _m) = a.dim();
24+
println!("a = \n{:?}", a);
25+
let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap();
26+
assert!(u.is_some());
27+
assert!(vt.is_none());
28+
let u = u.unwrap();
29+
assert_eq!(u.dim().0, n);
30+
assert_eq!(u.dim().1, n);
2531
}
2632

27-
#[test]
28-
fn svd_square_t() {
29-
let a = random((3, 3).f());
30-
test(&a, 3, 3);
33+
fn test_no_u(a: &Array2<f64>) {
34+
let (_n, m) = a.dim();
35+
println!("a = \n{:?}", a);
36+
let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap();
37+
assert!(u.is_none());
38+
assert!(vt.is_some());
39+
let vt = vt.unwrap();
40+
assert_eq!(vt.dim().0, m);
41+
assert_eq!(vt.dim().1, m);
3142
}
3243

33-
#[test]
34-
fn svd_3x4() {
35-
let a = random((3, 4));
36-
test(&a, 3, 4);
44+
fn test_diag_only(a: &Array2<f64>) {
45+
println!("a = \n{:?}", a);
46+
let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, false).unwrap();
47+
assert!(u.is_none());
48+
assert!(vt.is_none());
3749
}
3850

39-
#[test]
40-
fn svd_3x4_t() {
41-
let a = random((3, 4).f());
42-
test(&a, 3, 4);
43-
}
51+
macro_rules! test_svd_impl {
52+
($test:ident, $n:expr, $m:expr) => {
53+
paste::item! {
54+
#[test]
55+
fn [<svd_ $test _ $n x $m>]() {
56+
let a = random(($n, $m));
57+
$test(&a);
58+
}
4459

45-
#[test]
46-
fn svd_4x3() {
47-
let a = random((4, 3));
48-
test(&a, 4, 3);
60+
#[test]
61+
fn [<svd_ $test _ $n x $m _t>]() {
62+
let a = random(($n, $m).f());
63+
$test(&a);
64+
}
65+
}
66+
};
4967
}
5068

51-
#[test]
52-
fn svd_4x3_t() {
53-
let a = random((4, 3).f());
54-
test(&a, 4, 3);
55-
}
69+
test_svd_impl!(test, 3, 3);
70+
test_svd_impl!(test_no_vt, 3, 3);
71+
test_svd_impl!(test_no_u, 3, 3);
72+
test_svd_impl!(test_diag_only, 3, 3);
73+
test_svd_impl!(test, 4, 3);
74+
test_svd_impl!(test_no_vt, 4, 3);
75+
test_svd_impl!(test_no_u, 4, 3);
76+
test_svd_impl!(test_diag_only, 4, 3);
77+
test_svd_impl!(test, 3, 4);
78+
test_svd_impl!(test_no_vt, 3, 4);
79+
test_svd_impl!(test_no_u, 3, 4);
80+
test_svd_impl!(test_diag_only, 3, 4);

0 commit comments

Comments
 (0)