diff --git a/src/lib.rs b/src/lib.rs index 0aaf6527..429183dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,6 +46,7 @@ pub mod generate; pub mod inner; pub mod lapack; pub mod layout; +pub mod mgs; pub mod norm; pub mod operator; pub mod opnorm; diff --git a/src/mgs.rs b/src/mgs.rs new file mode 100644 index 00000000..9e25385c --- /dev/null +++ b/src/mgs.rs @@ -0,0 +1,187 @@ +//! Modified Gram-Schmit orthogonalizer + +use crate::{generate::*, inner::*, norm::Norm, types::*}; +use ndarray::*; + +/// Iterative orthogonalizer using modified Gram-Schmit procedure +#[derive(Debug, Clone)] +pub struct MGS { + /// Dimension of base space + dimension: usize, + /// Basis of spanned space + q: Vec>, +} + +/// Q-matrix +/// +/// - Maybe **NOT** square +/// - Unitary for existing columns +/// +pub type Q = Array2; + +/// R-matrix +/// +/// - Maybe **NOT** square +/// - Upper triangle +/// +pub type R = Array2; + +impl MGS { + /// Create an empty orthogonalizer + pub fn new(dimension: usize) -> Self { + Self { + dimension, + q: Vec::new(), + } + } + + /// Dimension of input array + pub fn dim(&self) -> usize { + self.dimension + } + + /// Number of cached basis + /// + /// ```rust + /// # use ndarray::*; + /// # use ndarray_linalg::{mgs::*, *}; + /// const N: usize = 3; + /// let mut mgs = MGS::::new(N); + /// assert_eq!(mgs.dim(), N); + /// assert_eq!(mgs.len(), 0); + /// + /// mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap(); + /// assert_eq!(mgs.len(), 1); + /// ``` + pub fn len(&self) -> usize { + self.q.len() + } + + /// Orthogonalize given vector using current basis + /// + /// Panic + /// ------- + /// - if the size of the input array mismatches to the dimension + /// + pub fn orthogonalize(&self, a: &mut ArrayBase) -> Array1 + where + A: Lapack, + S: DataMut, + { + assert_eq!(a.len(), self.dim()); + let mut coef = Array1::zeros(self.len() + 1); + for i in 0..self.len() { + let q = &self.q[i]; + let c = q.inner(&a); + azip!(mut a (&mut *a), q (q) in { *a = *a - c * q } ); + coef[i] = c; + } + let nrm = a.norm_l2(); + coef[self.len()] = A::from_real(nrm); + coef + } + + /// Add new vector if the residual is larger than relative tolerance + /// + /// ```rust + /// # use ndarray::*; + /// # use ndarray_linalg::{mgs::*, *}; + /// let mut mgs = MGS::new(3); + /// let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap(); + /// close_l2(&coef, &array![1.0], 1e-9); + /// + /// let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap(); + /// close_l2(&coef, &array![1.0, 1.0], 1e-9); + /// + /// // Fail if the vector is linearly dependent + /// assert!(mgs.append(array![1.0, 2.0, 0.0], 1e-9).is_err()); + /// + /// // You can get coefficients of dependent vector + /// if let Err(coef) = mgs.append(array![1.0, 2.0, 0.0], 1e-9) { + /// close_l2(&coef, &array![2.0, 1.0, 0.0], 1e-9); + /// } + /// ``` + /// + /// Panic + /// ------- + /// - if the size of the input array mismatches to the dimension + /// + pub fn append(&mut self, a: ArrayBase, rtol: A::Real) -> Result, Array1> + where + A: Lapack, + S: Data, + { + let mut a = a.into_owned(); + let coef = self.orthogonalize(&mut a); + let nrm = coef[coef.len() - 1].re(); + if nrm < rtol { + // Linearly dependent + return Err(coef); + } + azip!(mut a in { *a = *a / A::from_real(nrm) }); + self.q.push(a); + Ok(coef) + } + + /// Get orthogonal basis as Q matrix + pub fn get_q(&self) -> Q { + hstack(&self.q).unwrap() + } +} + +/// Strategy for linearly dependent vectors appearing in iterative QR decomposition +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum Strategy { + /// Terminate iteration if dependent vector comes + Terminate, + + /// Skip dependent vector + Skip, + + /// Orthogonalize dependent vector without adding to Q, + /// i.e. R must be non-square like following: + /// + /// ```text + /// x x x x x + /// 0 x x x x + /// 0 0 0 x x + /// 0 0 0 0 x + /// ``` + Full, +} + +/// Online QR decomposition of vectors using modified Gram-Schmit algorithm +pub fn mgs( + iter: impl Iterator>, + dim: usize, + rtol: A::Real, + strategy: Strategy, +) -> (Q, R) +where + A: Scalar + Lapack, + S: Data, +{ + let mut ortho = MGS::new(dim); + let mut coefs = Vec::new(); + for a in iter { + match ortho.append(a, rtol) { + Ok(coef) => coefs.push(coef), + Err(coef) => match strategy { + Strategy::Terminate => break, + Strategy::Skip => continue, + Strategy::Full => coefs.push(coef), + }, + } + } + let n = ortho.len(); + let m = coefs.len(); + let mut r = Array2::zeros((n, m).f()); + for j in 0..m { + for i in 0..n { + if i < coefs[j].len() { + r[(i, j)] = coefs[j][i]; + } + } + } + (ortho.get_q(), r) +} diff --git a/tests/mgs.rs b/tests/mgs.rs new file mode 100644 index 00000000..2212d98f --- /dev/null +++ b/tests/mgs.rs @@ -0,0 +1,83 @@ +use ndarray::*; +use ndarray_linalg::{mgs::*, *}; + +fn qr_full() { + const N: usize = 5; + let rtol: A::Real = A::real(1e-9); + + let a: Array2 = random((N, N)); + let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate); + assert_close_l2!(&q.dot(&r), &a, rtol); + + let qc: Array2 = conjugate(&q); + assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol); +} + +#[test] +fn qr_full_real() { + qr_full::(); +} + +#[test] +fn qr_full_complex() { + qr_full::(); +} + +fn qr() { + const N: usize = 4; + let rtol: A::Real = A::real(1e-9); + + let a: Array2 = random((N, N / 2)); + let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate); + assert_close_l2!(&q.dot(&r), &a, rtol); + + let qc: Array2 = conjugate(&q); + assert_close_l2!(&qc.dot(&q), &Array::eye(N / 2), rtol); +} + +#[test] +fn qr_real() { + qr::(); +} + +#[test] +fn qr_complex() { + qr::(); +} + +fn qr_over() { + const N: usize = 4; + let rtol: A::Real = A::real(1e-9); + + let a: Array2 = random((N, N * 2)); + + // Terminate + let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate); + let a_sub = a.slice(s![.., 0..N]); + assert_close_l2!(&q.dot(&r), &a_sub, rtol); + let qc: Array2 = conjugate(&q); + assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol); + + // Skip + let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Skip); + let a_sub = a.slice(s![.., 0..N]); + assert_close_l2!(&q.dot(&r), &a_sub, rtol); + let qc: Array2 = conjugate(&q); + assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol); + + // Full + let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Full); + assert_close_l2!(&q.dot(&r), &a, rtol); + let qc: Array2 = conjugate(&q); + assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol); +} + +#[test] +fn qr_over_real() { + qr_over::(); +} + +#[test] +fn qr_over_complex() { + qr_over::(); +}