Skip to content

Commit 7685117

Browse files
authored
Merge pull request #164 from nlhepler/master
Implement SVD by divide-and-conquer
2 parents afe38d4 + 3e23773 commit 7685117

File tree

5 files changed

+258
-1
lines changed

5 files changed

+258
-1
lines changed

src/lapack/mod.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub mod qr;
77
pub mod solve;
88
pub mod solveh;
99
pub mod svd;
10+
pub mod svddc;
1011
pub mod triangular;
1112

1213
pub use self::cholesky::*;
@@ -16,6 +17,7 @@ pub use self::qr::*;
1617
pub use self::solve::*;
1718
pub use self::solveh::*;
1819
pub use self::svd::*;
20+
pub use self::svddc::*;
1921
pub use self::triangular::*;
2022

2123
use super::error::*;
@@ -24,7 +26,7 @@ use super::types::*;
2426
pub type Pivot = Vec<i32>;
2527

2628
/// Trait for primitive types which implements LAPACK subroutines
27-
pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ {}
29+
pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ {}
2830

2931
impl Lapack for f32 {}
3032
impl Lapack for f64 {}

src/lapack/svddc.rs

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
use lapacke;
2+
use num_traits::Zero;
3+
4+
use crate::error::*;
5+
use crate::layout::MatrixLayout;
6+
use crate::types::*;
7+
use crate::svddc::UVTFlag;
8+
9+
use super::{SVDOutput, into_result};
10+
11+
pub trait SVDDC_: Scalar {
12+
unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>>;
13+
}
14+
15+
macro_rules! impl_svdd {
16+
($scalar:ty, $gesdd:path) => {
17+
impl SVDDC_ for $scalar {
18+
unsafe fn svddc(
19+
l: MatrixLayout,
20+
jobz: UVTFlag,
21+
mut a: &mut [Self],
22+
) -> Result<SVDOutput<Self>> {
23+
let (m, n) = l.size();
24+
let k = m.min(n);
25+
let lda = l.lda();
26+
let (ucol, vtrow) = match jobz {
27+
UVTFlag::Full => (m, n),
28+
UVTFlag::Some => (k, k),
29+
UVTFlag::None => (1, 1),
30+
};
31+
let mut s = vec![Self::Real::zero(); k.max(1) as usize];
32+
let mut u = vec![Self::zero(); (m * ucol).max(1) as usize];
33+
let ldu = l.resized(m, ucol).lda();
34+
let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize];
35+
let ldvt = l.resized(vtrow, n).lda();
36+
let info = $gesdd(
37+
l.lapacke_layout(),
38+
jobz as u8,
39+
m,
40+
n,
41+
&mut a,
42+
lda,
43+
&mut s,
44+
&mut u,
45+
ldu,
46+
&mut vt,
47+
ldvt,
48+
);
49+
into_result(
50+
info,
51+
SVDOutput {
52+
s: s,
53+
u: if jobz == UVTFlag::None { None } else { Some(u) },
54+
vt: if jobz == UVTFlag::None {
55+
None
56+
} else {
57+
Some(vt)
58+
},
59+
},
60+
)
61+
}
62+
}
63+
};
64+
}
65+
66+
impl_svdd!(f32, lapacke::sgesdd);
67+
impl_svdd!(f64, lapacke::dgesdd);
68+
impl_svdd!(c32, lapacke::cgesdd);
69+
impl_svdd!(c64, lapacke::zgesdd);

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ pub mod qr;
5757
pub mod solve;
5858
pub mod solveh;
5959
pub mod svd;
60+
pub mod svddc;
6061
pub mod trace;
6162
pub mod triangular;
6263
pub mod types;
@@ -76,6 +77,7 @@ pub use qr::*;
7677
pub use solve::*;
7778
pub use solveh::*;
7879
pub use svd::*;
80+
pub use svddc::*;
7981
pub use trace::*;
8082
pub use triangular::*;
8183
pub use types::*;

src/svddc.rs

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//! Singular-value decomposition (SVD) by divide-and-conquer (?gesdd)
2+
3+
use ndarray::*;
4+
5+
use super::convert::*;
6+
use super::error::*;
7+
use super::layout::*;
8+
use super::types::*;
9+
10+
#[derive(Clone, Copy, Eq, PartialEq)]
11+
#[repr(u8)]
12+
pub enum UVTFlag {
13+
Full = b'A',
14+
Some = b'S',
15+
None = b'N',
16+
}
17+
18+
/// Singular-value decomposition of matrix (copying) by divide-and-conquer
19+
pub trait SVDDC {
20+
type U;
21+
type VT;
22+
type Sigma;
23+
fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)>;
24+
}
25+
26+
/// Singular-value decomposition of matrix by divide-and-conquer
27+
pub trait SVDDCInto {
28+
type U;
29+
type VT;
30+
type Sigma;
31+
fn svddc_into(
32+
self,
33+
uvt_flag: UVTFlag,
34+
) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)>;
35+
}
36+
37+
/// Singular-value decomposition of matrix reference by divide-and-conquer
38+
pub trait SVDDCInplace {
39+
type U;
40+
type VT;
41+
type Sigma;
42+
fn svddc_inplace(
43+
&mut self,
44+
uvt_flag: UVTFlag,
45+
) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)>;
46+
}
47+
48+
impl<A, S> SVDDC for ArrayBase<S, Ix2>
49+
where
50+
A: Scalar + Lapack,
51+
S: DataMut<Elem = A>,
52+
{
53+
type U = Array2<A>;
54+
type VT = Array2<A>;
55+
type Sigma = Array1<A::Real>;
56+
57+
fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)> {
58+
self.to_owned().svddc_into(uvt_flag)
59+
}
60+
}
61+
62+
impl<A, S> SVDDCInto for ArrayBase<S, Ix2>
63+
where
64+
A: Scalar + Lapack,
65+
S: DataMut<Elem = A>,
66+
{
67+
type U = Array2<A>;
68+
type VT = Array2<A>;
69+
type Sigma = Array1<A::Real>;
70+
71+
fn svddc_into(
72+
mut self,
73+
uvt_flag: UVTFlag,
74+
) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)> {
75+
self.svddc_inplace(uvt_flag)
76+
}
77+
}
78+
79+
impl<A, S> SVDDCInplace for ArrayBase<S, Ix2>
80+
where
81+
A: Scalar + Lapack,
82+
S: DataMut<Elem = A>,
83+
{
84+
type U = Array2<A>;
85+
type VT = Array2<A>;
86+
type Sigma = Array1<A::Real>;
87+
88+
fn svddc_inplace(
89+
&mut self,
90+
uvt_flag: UVTFlag,
91+
) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)> {
92+
let l = self.layout()?;
93+
let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? };
94+
let (m, n) = l.size();
95+
let k = m.min(n);
96+
let (ldu, tdu, ldvt, tdvt) = match uvt_flag {
97+
UVTFlag::Full => (m, m, n, n),
98+
UVTFlag::Some => (m, k, k, n),
99+
UVTFlag::None => (1, 1, 1, 1),
100+
};
101+
let u = svd_res
102+
.u
103+
.map(|u| into_matrix(l.resized(ldu, tdu), u).expect("Size of U mismatches"));
104+
let vt = svd_res
105+
.vt
106+
.map(|vt| into_matrix(l.resized(ldvt, tdvt), vt).expect("Size of VT mismatches"));
107+
let s = ArrayBase::from_vec(svd_res.s);
108+
Ok((u, s, vt))
109+
}
110+
}

tests/svddc.rs

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
use ndarray::*;
2+
use ndarray_linalg::*;
3+
4+
fn test(a: &Array2<f64>, flag: UVTFlag) {
5+
let (n, m) = a.dim();
6+
let k = n.min(m);
7+
let answer = a.clone();
8+
println!("a = \n{:?}", a);
9+
let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap();
10+
let mut sm = match flag {
11+
UVTFlag::Full => Array::zeros((n, m)),
12+
UVTFlag::Some => Array::zeros((k, k)),
13+
UVTFlag::None => {
14+
assert!(u.is_none());
15+
assert!(vt.is_none());
16+
return;
17+
},
18+
};
19+
let u: Array2<_> = u.unwrap();
20+
let vt: Array2<_> = vt.unwrap();
21+
println!("u = \n{:?}", &u);
22+
println!("s = \n{:?}", &s);
23+
println!("v = \n{:?}", &vt);
24+
for i in 0..k {
25+
sm[(i, i)] = s[i];
26+
}
27+
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
28+
}
29+
30+
macro_rules! test_svd_impl {
31+
($n:expr, $m:expr) => {
32+
paste::item! {
33+
#[test]
34+
fn [<svddc_full_ $n x $m>]() {
35+
let a = random(($n, $m));
36+
test(&a, UVTFlag::Full);
37+
}
38+
39+
#[test]
40+
fn [<svddc_some_ $n x $m>]() {
41+
let a = random(($n, $m));
42+
test(&a, UVTFlag::Some);
43+
}
44+
45+
#[test]
46+
fn [<svddc_none_ $n x $m>]() {
47+
let a = random(($n, $m));
48+
test(&a, UVTFlag::None);
49+
}
50+
51+
#[test]
52+
fn [<svddc_full_ $n x $m _t>]() {
53+
let a = random(($n, $m).f());
54+
test(&a, UVTFlag::Full);
55+
}
56+
57+
#[test]
58+
fn [<svddc_some_ $n x $m _t>]() {
59+
let a = random(($n, $m).f());
60+
test(&a, UVTFlag::Some);
61+
}
62+
63+
#[test]
64+
fn [<svddc_none_ $n x $m _t>]() {
65+
let a = random(($n, $m).f());
66+
test(&a, UVTFlag::None);
67+
}
68+
}
69+
};
70+
}
71+
72+
test_svd_impl!(3, 3);
73+
test_svd_impl!(4, 3);
74+
test_svd_impl!(3, 4);

0 commit comments

Comments
 (0)