Skip to content

Commit e59ede9

Browse files
authored
Merge pull request #13 from termoshtt/qr
QR decomposition
2 parents 29fe98a + bc556cf commit e59ede9

File tree

6 files changed

+353
-4
lines changed

6 files changed

+353
-4
lines changed

src/hermite.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use matrix::Matrix;
88
use square::SquareMatrix;
99
use error::LinalgError;
1010
use eigh::ImplEigh;
11+
use qr::ImplQR;
1112
use svd::ImplSVD;
1213
use norm::ImplNorm;
1314
use solve::ImplSolve;
@@ -21,7 +22,7 @@ pub trait HermiteMatrix: SquareMatrix + Matrix {
2122
}
2223

2324
impl<A> HermiteMatrix for Array<A, (Ix, Ix)>
24-
where A: ImplSVD + ImplNorm + ImplSolve + ImplEigh + LinalgScalar + Float
25+
where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + ImplEigh + LinalgScalar + Float
2526
{
2627
fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> {
2728
try!(self.check_square());

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub mod matrix;
99
pub mod square;
1010
pub mod hermite;
1111

12+
pub mod qr;
1213
pub mod svd;
1314
pub mod eigh;
1415
pub mod norm;

src/matrix.rs

+39-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
//! Define trait for general matrix
22
3+
use std::cmp::min;
34
use ndarray::prelude::*;
45
use ndarray::LinalgScalar;
56

67
use error::LapackError;
8+
use qr::ImplQR;
79
use svd::ImplSVD;
810
use norm::ImplNorm;
911

@@ -21,11 +23,11 @@ pub trait Matrix: Sized {
2123
fn norm_f(&self) -> Self::Scalar;
2224
/// singular-value decomposition (SVD)
2325
fn svd(self) -> Result<(Self, Self::Vector, Self), LapackError>;
24-
// fn qr(self) -> (Self, Self);
26+
fn qr(self) -> Result<(Self, Self), LapackError>;
2527
}
2628

2729
impl<A> Matrix for Array<A, (Ix, Ix)>
28-
where A: ImplSVD + ImplNorm + LinalgScalar
30+
where A: ImplQR + ImplSVD + ImplNorm + LinalgScalar
2931
{
3032
type Scalar = A;
3133
type Vector = Array<A, Ix>;
@@ -74,4 +76,39 @@ impl<A> Matrix for Array<A, (Ix, Ix)>
7476
Ok((ua, sv, va))
7577
}
7678
}
79+
fn qr(self) -> Result<(Self, Self), LapackError> {
80+
let (n, m) = self.size();
81+
let strides = self.strides();
82+
let k = min(n, m);
83+
let (q, r) = if strides[0] < strides[1] {
84+
try!(ImplQR::qr(m, n, self.clone().into_raw_vec()))
85+
} else {
86+
try!(ImplQR::lq(n, m, self.clone().into_raw_vec()))
87+
};
88+
let (qa, ra) = if strides[0] < strides[1] {
89+
(Array::from_vec(q).into_shape((m, n)).unwrap().reversed_axes(),
90+
Array::from_vec(r).into_shape((m, n)).unwrap().reversed_axes())
91+
} else {
92+
(Array::from_vec(q).into_shape((n, m)).unwrap(),
93+
Array::from_vec(r).into_shape((n, m)).unwrap())
94+
};
95+
let qm = if m > k {
96+
let (qsl, _) = qa.view().split_at(Axis(1), k);
97+
qsl.to_owned()
98+
} else {
99+
qa
100+
};
101+
let mut rm = if n > k {
102+
let (rsl, _) = ra.view().split_at(Axis(0), k);
103+
rsl.to_owned()
104+
} else {
105+
ra
106+
};
107+
for ((i, j), val) in rm.indexed_iter_mut() {
108+
if i > j {
109+
*val = A::zero();
110+
}
111+
}
112+
Ok((qm, rm))
113+
}
77114
}

src/qr.rs

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//! Implement QR decomposition
2+
3+
extern crate lapack;
4+
5+
use std::cmp::min;
6+
use self::lapack::fortran::*;
7+
use num_traits::Zero;
8+
9+
use error::LapackError;
10+
11+
pub trait ImplQR: Sized {
12+
fn qr(n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
13+
fn lq(n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
14+
}
15+
16+
macro_rules! impl_qr {
17+
($geqrf:path, $orgqr:path, $gelqf:path, $orglq:path) => {
18+
// XXX These codes are most same, but the argument of $orgqr and $orglq are different!
19+
fn qr(n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
20+
let n = n as i32;
21+
let m = m as i32;
22+
let mut info = 0;
23+
let k = min(m, n);
24+
let lda = m;
25+
let lw_default = 1000;
26+
let mut tau = vec![Self::zero(); k as usize];
27+
let mut work = vec![Self::zero(); lw_default];
28+
// estimate lwork
29+
$geqrf(m, n, &mut a, lda, &mut tau, &mut work, -1, &mut info);
30+
let lwork_r = work[0] as i32;
31+
if lwork_r > lw_default as i32 {
32+
work = vec![Self::zero(); lwork_r as usize];
33+
}
34+
// calc R
35+
$geqrf(m, n, &mut a, lda, &mut tau, &mut work, lwork_r, &mut info);
36+
if info != 0 {
37+
return Err(From::from(info));
38+
}
39+
let r = a.clone();
40+
// re-estimate lwork
41+
$orgqr(m, k, k, &mut a, lda, &mut tau, &mut work, -1, &mut info);
42+
let lwork_q = work[0] as i32;
43+
if lwork_q > lwork_r {
44+
work = vec![Self::zero(); lwork_q as usize];
45+
}
46+
// calc Q
47+
$orgqr(m,
48+
k,
49+
k,
50+
&mut a,
51+
lda,
52+
&mut tau,
53+
&mut work,
54+
lwork_q,
55+
&mut info);
56+
if info == 0 {
57+
Ok((a, r))
58+
} else {
59+
Err(From::from(info))
60+
}
61+
}
62+
fn lq(n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
63+
let n = n as i32;
64+
let m = m as i32;
65+
let mut info = 0;
66+
let k = min(m, n);
67+
let lda = m;
68+
let lw_default = 1000;
69+
let mut tau = vec![Self::zero(); k as usize];
70+
let mut work = vec![Self::zero(); lw_default];
71+
// estimate lwork
72+
$gelqf(m, n, &mut a, lda, &mut tau, &mut work, -1, &mut info);
73+
let lwork_r = work[0] as i32;
74+
if lwork_r > lw_default as i32 {
75+
work = vec![Self::zero(); lwork_r as usize];
76+
}
77+
// calc R
78+
$gelqf(m, n, &mut a, lda, &mut tau, &mut work, lwork_r, &mut info);
79+
if info != 0 {
80+
return Err(From::from(info));
81+
}
82+
let r = a.clone();
83+
// re-estimate lwork
84+
$orglq(k, n, k, &mut a, lda, &mut tau, &mut work, -1, &mut info);
85+
let lwork_q = work[0] as i32;
86+
if lwork_q > lwork_r {
87+
work = vec![Self::zero(); lwork_q as usize];
88+
}
89+
// calc Q
90+
$orglq(k,
91+
n,
92+
k,
93+
&mut a,
94+
lda,
95+
&mut tau,
96+
&mut work,
97+
lwork_q,
98+
&mut info);
99+
if info == 0 {
100+
Ok((a, r))
101+
} else {
102+
Err(From::from(info))
103+
}
104+
}
105+
}} // endmacro
106+
107+
impl ImplQR for f64 {
108+
impl_qr!(dgeqrf, dorgqr, dgelqf, dorglq);
109+
}
110+
111+
impl ImplQR for f32 {
112+
impl_qr!(sgeqrf, sorgqr, sgelqf, sorglq);
113+
}

src/square.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use num_traits::float::Float;
66

77
use matrix::Matrix;
88
use error::{LinalgError, NotSquareError};
9+
use qr::ImplQR;
910
use svd::ImplSVD;
1011
use norm::ImplNorm;
1112
use solve::ImplSolve;
@@ -37,7 +38,7 @@ pub trait SquareMatrix: Matrix {
3738
}
3839

3940
impl<A> SquareMatrix for Array<A, (Ix, Ix)>
40-
where A: ImplSVD + ImplNorm + ImplSolve + LinalgScalar + Float
41+
where A: ImplQR + ImplNorm + ImplSVD + ImplSolve + LinalgScalar + Float
4142
{
4243
fn inv(self) -> Result<Self, LinalgError> {
4344
try!(self.check_square());

0 commit comments

Comments
 (0)