Skip to content

Commit cf68301

Browse files
authored
Merge pull request #159 from rust-ndarray/linear_operator
Trait for linear operator
2 parents 4f0ef55 + 59133a4 commit cf68301

File tree

5 files changed

+89
-192
lines changed

5 files changed

+89
-192
lines changed

src/diagonal.rs

+10-72
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
33
use ndarray::*;
44

5-
use super::convert::*;
65
use super::operator::*;
6+
use super::types::*;
77

88
/// Vector as a Diagonal matrix
99
pub struct Diagonal<S: Data> {
@@ -30,81 +30,19 @@ impl<A, S: Data<Elem = A>> AsDiagonal<A> for ArrayBase<S, Ix1> {
3030
}
3131
}
3232

33-
impl<A, S, Sr> OperatorInplace<Sr, Ix1> for Diagonal<S>
33+
impl<A, Sa> LinearOperator for Diagonal<Sa>
3434
where
35-
A: LinalgScalar,
36-
S: Data<Elem = A>,
37-
Sr: DataMut<Elem = A>,
35+
A: Scalar,
36+
Sa: Data<Elem = A>,
3837
{
39-
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<Sr, Ix1>) -> &'a mut ArrayBase<Sr, Ix1> {
38+
type Elem = A;
39+
40+
fn apply_mut<S>(&self, a: &mut ArrayBase<S, Ix1>)
41+
where
42+
S: DataMut<Elem = A>,
43+
{
4044
for (val, d) in a.iter_mut().zip(self.diag.iter()) {
4145
*val = *val * *d;
4246
}
43-
a
44-
}
45-
}
46-
47-
impl<A, S, Sr> Operator<A, Sr, Ix1> for Diagonal<S>
48-
where
49-
A: LinalgScalar,
50-
S: Data<Elem = A>,
51-
Sr: Data<Elem = A>,
52-
{
53-
fn op(&self, a: &ArrayBase<Sr, Ix1>) -> Array1<A> {
54-
let mut a = replicate(a);
55-
self.op_inplace(&mut a);
56-
a
57-
}
58-
}
59-
60-
impl<A, S, Sr> OperatorInto<Sr, Ix1> for Diagonal<S>
61-
where
62-
A: LinalgScalar,
63-
S: Data<Elem = A>,
64-
Sr: DataOwned<Elem = A> + DataMut,
65-
{
66-
fn op_into(&self, mut a: ArrayBase<Sr, Ix1>) -> ArrayBase<Sr, Ix1> {
67-
self.op_inplace(&mut a);
68-
a
69-
}
70-
}
71-
72-
impl<A, S, Sr> OperatorInplace<Sr, Ix2> for Diagonal<S>
73-
where
74-
A: LinalgScalar,
75-
S: Data<Elem = A>,
76-
Sr: DataMut<Elem = A>,
77-
{
78-
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<Sr, Ix2>) -> &'a mut ArrayBase<Sr, Ix2> {
79-
let d = &self.diag;
80-
for ((i, _), val) in a.indexed_iter_mut() {
81-
*val = *val * d[i];
82-
}
83-
a
84-
}
85-
}
86-
87-
impl<A, S, Sr> Operator<A, Sr, Ix2> for Diagonal<S>
88-
where
89-
A: LinalgScalar,
90-
S: Data<Elem = A>,
91-
Sr: Data<Elem = A>,
92-
{
93-
fn op(&self, a: &ArrayBase<Sr, Ix2>) -> Array2<A> {
94-
let mut a = replicate(a);
95-
self.op_inplace(&mut a);
96-
a
97-
}
98-
}
99-
100-
impl<A, S, Sr> OperatorInto<Sr, Ix2> for Diagonal<S>
101-
where
102-
A: LinalgScalar,
103-
S: Data<Elem = A>,
104-
Sr: DataOwned<Elem = A> + DataMut,
105-
{
106-
fn op_into(&self, mut a: ArrayBase<Sr, Ix2>) -> ArrayBase<Sr, Ix2> {
107-
self.op_inplace(&mut a);
108-
a
10947
}
11048
}

src/eigh.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use ndarray::*;
55
use crate::diagonal::*;
66
use crate::error::*;
77
use crate::layout::*;
8-
use crate::operator::Operator;
8+
use crate::operator::LinearOperator;
99
use crate::types::*;
1010
use crate::UPLO;
1111

@@ -165,7 +165,7 @@ where
165165
fn ssqrt_into(self, uplo: UPLO) -> Result<Self::Output> {
166166
let (e, v) = self.eigh_into(uplo)?;
167167
let e_sqrt = Array1::from_iter(e.iter().map(|r| Scalar::from_real(r.sqrt())));
168-
let ev = e_sqrt.into_diagonal().op(&v.t());
169-
Ok(v.op(&ev))
168+
let ev = e_sqrt.into_diagonal().apply2(&v.t());
169+
Ok(v.apply2(&ev))
170170
}
171171
}

src/krylov/arnoldi.rs

+11-29
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Arnoldi iteration
22
33
use super::*;
4-
use crate::norm::Norm;
4+
use crate::{norm::Norm, operator::LinearOperator};
55
use num_traits::One;
66
use std::iter::*;
77

@@ -13,7 +13,7 @@ pub struct Arnoldi<A, S, F, Ortho>
1313
where
1414
A: Scalar,
1515
S: DataMut<Elem = A>,
16-
F: Fn(&mut ArrayBase<S, Ix1>),
16+
F: LinearOperator<Elem = A>,
1717
Ortho: Orthogonalizer<Elem = A>,
1818
{
1919
a: F,
@@ -29,7 +29,7 @@ impl<A, S, F, Ortho> Arnoldi<A, S, F, Ortho>
2929
where
3030
A: Scalar + Lapack,
3131
S: DataMut<Elem = A>,
32-
F: Fn(&mut ArrayBase<S, Ix1>),
32+
F: LinearOperator<Elem = A>,
3333
Ortho: Orthogonalizer<Elem = A>,
3434
{
3535
/// Create an Arnoldi iterator from any linear operator `a`
@@ -73,13 +73,13 @@ impl<A, S, F, Ortho> Iterator for Arnoldi<A, S, F, Ortho>
7373
where
7474
A: Scalar + Lapack,
7575
S: DataMut<Elem = A>,
76-
F: Fn(&mut ArrayBase<S, Ix1>),
76+
F: LinearOperator<Elem = A>,
7777
Ortho: Orthogonalizer<Elem = A>,
7878
{
7979
type Item = Array1<A>;
8080

8181
fn next(&mut self) -> Option<Self::Item> {
82-
(self.a)(&mut self.v);
82+
self.a.apply_mut(&mut self.v);
8383
let result = self.ortho.div_append(&mut self.v);
8484
let norm = self.v.norm_l2();
8585
azip!(mut v(&mut self.v) in { *v = v.div_real(norm) });
@@ -96,40 +96,22 @@ where
9696
}
9797
}
9898

99-
/// Interpret a matrix as a linear operator
100-
pub fn mul_mat<A, S1, S2>(a: ArrayBase<S1, Ix2>) -> impl Fn(&mut ArrayBase<S2, Ix1>)
101-
where
102-
A: Scalar,
103-
S1: Data<Elem = A>,
104-
S2: DataMut<Elem = A>,
105-
{
106-
let (n, m) = a.dim();
107-
assert_eq!(n, m, "Input matrix must be square");
108-
move |x| {
109-
assert_eq!(m, x.len(), "Input matrix and vector sizes mismatch");
110-
let ax = a.dot(x);
111-
azip!(mut x(x), ax in { *x = ax });
112-
}
113-
}
114-
11599
/// Utility to execute Arnoldi iteration with Householder reflection
116-
pub fn arnoldi_householder<A, S1, S2>(a: ArrayBase<S1, Ix2>, v: ArrayBase<S2, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
100+
pub fn arnoldi_householder<A, S>(a: impl LinearOperator<Elem = A>, v: ArrayBase<S, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
117101
where
118102
A: Scalar + Lapack,
119-
S1: Data<Elem = A>,
120-
S2: DataMut<Elem = A>,
103+
S: DataMut<Elem = A>,
121104
{
122105
let householder = Householder::new(v.len(), tol);
123-
Arnoldi::new(mul_mat(a), v, householder).complete()
106+
Arnoldi::new(a, v, householder).complete()
124107
}
125108

126109
/// Utility to execute Arnoldi iteration with modified Gram-Schmit orthogonalizer
127-
pub fn arnoldi_mgs<A, S1, S2>(a: ArrayBase<S1, Ix2>, v: ArrayBase<S2, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
110+
pub fn arnoldi_mgs<A, S>(a: impl LinearOperator<Elem = A>, v: ArrayBase<S, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
128111
where
129112
A: Scalar + Lapack,
130-
S1: Data<Elem = A>,
131-
S2: DataMut<Elem = A>,
113+
S: DataMut<Elem = A>,
132114
{
133115
let mgs = MGS::new(v.len(), tol);
134-
Arnoldi::new(mul_mat(a), v, mgs).complete()
116+
Arnoldi::new(a, v, mgs).complete()
135117
}

src/operator.rs

+62-85
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,82 @@
1-
//! Linear Operator
1+
//! Linear operator algebra
22
3+
use crate::generate::hstack;
4+
use crate::types::*;
35
use ndarray::*;
46

5-
use super::types::*;
7+
/// Abstracted linear operator as an action to vector (`ArrayBase<S, Ix1>`) and matrix
8+
/// (`ArrayBase<S, Ix2`)
9+
pub trait LinearOperator {
10+
type Elem: Scalar;
611

7-
pub trait Operator<A, S, D>
8-
where
9-
S: Data<Elem = A>,
10-
D: Dimension,
11-
{
12-
fn op(&self, a: &ArrayBase<S, D>) -> Array<A, D>;
13-
}
14-
15-
pub trait OperatorInto<S, D>
16-
where
17-
S: DataMut,
18-
D: Dimension,
19-
{
20-
fn op_into(&self, a: ArrayBase<S, D>) -> ArrayBase<S, D>;
21-
}
22-
23-
pub trait OperatorInplace<S, D>
24-
where
25-
S: DataMut,
26-
D: Dimension,
27-
{
28-
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D>;
29-
}
12+
/// Apply operator out-place
13+
fn apply<S>(&self, a: &ArrayBase<S, Ix1>) -> Array1<S::Elem>
14+
where
15+
S: Data<Elem = Self::Elem>,
16+
{
17+
let mut a = a.to_owned();
18+
self.apply_mut(&mut a);
19+
a
20+
}
3021

31-
impl<T, A, S, D> Operator<A, S, D> for T
32-
where
33-
A: Scalar + Lapack,
34-
S: Data<Elem = A>,
35-
D: Dimension,
36-
T: linalg::Dot<ArrayBase<S, D>, Output = Array<A, D>>,
37-
{
38-
fn op(&self, rhs: &ArrayBase<S, D>) -> Array<A, D> {
39-
self.dot(rhs)
22+
/// Apply operator in-place
23+
fn apply_mut<S>(&self, a: &mut ArrayBase<S, Ix1>)
24+
where
25+
S: DataMut<Elem = Self::Elem>,
26+
{
27+
let b = self.apply(a);
28+
azip!(mut a(a), b in { *a = b });
4029
}
41-
}
4230

43-
pub trait OperatorMulti<A, S, D>
44-
where
45-
S: Data<Elem = A>,
46-
D: Dimension,
47-
{
48-
fn op_multi(&self, a: &ArrayBase<S, D>) -> Array<A, D>;
49-
}
31+
/// Apply operator with move
32+
fn apply_into<S>(&self, mut a: ArrayBase<S, Ix1>) -> ArrayBase<S, Ix1>
33+
where
34+
S: DataOwned<Elem = Self::Elem> + DataMut,
35+
{
36+
self.apply_mut(&mut a);
37+
a
38+
}
5039

51-
impl<T, A, S, D> OperatorMulti<A, S, D> for T
52-
where
53-
A: Scalar + Lapack,
54-
S: DataMut<Elem = A>,
55-
D: Dimension + RemoveAxis,
56-
for<'a> T: OperatorInplace<ViewRepr<&'a mut A>, D::Smaller>,
57-
{
58-
fn op_multi(&self, a: &ArrayBase<S, D>) -> Array<A, D> {
59-
let a = a.to_owned();
60-
self.op_multi_into(a)
40+
/// Apply operator to matrix out-place
41+
fn apply2<S>(&self, a: &ArrayBase<S, Ix2>) -> Array2<S::Elem>
42+
where
43+
S: Data<Elem = Self::Elem>,
44+
{
45+
let cols: Vec<_> = a.axis_iter(Axis(1)).map(|col| self.apply(&col)).collect();
46+
hstack(&cols).unwrap()
6147
}
62-
}
6348

64-
pub trait OperatorMultiInto<S, D>
65-
where
66-
S: DataMut,
67-
D: Dimension,
68-
{
69-
fn op_multi_into(&self, a: ArrayBase<S, D>) -> ArrayBase<S, D>;
70-
}
49+
/// Apply operator to matrix in-place
50+
fn apply2_mut<S>(&self, a: &mut ArrayBase<S, Ix2>)
51+
where
52+
S: DataMut<Elem = Self::Elem>,
53+
{
54+
for mut col in a.axis_iter_mut(Axis(1)) {
55+
self.apply_mut(&mut col)
56+
}
57+
}
7158

72-
impl<T, A, S, D> OperatorMultiInto<S, D> for T
73-
where
74-
S: DataMut<Elem = A>,
75-
D: Dimension + RemoveAxis,
76-
for<'a> T: OperatorInplace<ViewRepr<&'a mut A>, D::Smaller>,
77-
{
78-
fn op_multi_into(&self, mut a: ArrayBase<S, D>) -> ArrayBase<S, D> {
79-
self.op_multi_inplace(&mut a);
59+
/// Apply operator to matrix with move
60+
fn apply2_into<S>(&self, mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
61+
where
62+
S: DataOwned<Elem = Self::Elem> + DataMut,
63+
{
64+
self.apply2_mut(&mut a);
8065
a
8166
}
8267
}
8368

84-
pub trait OperatorMultiInplace<S, D>
69+
impl<A, Sa> LinearOperator for ArrayBase<Sa, Ix2>
8570
where
86-
S: DataMut,
87-
D: Dimension,
71+
A: Scalar,
72+
Sa: Data<Elem = A>,
8873
{
89-
fn op_multi_inplace<'a>(&self, a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D>;
90-
}
74+
type Elem = A;
9175

92-
impl<T, A, S, D> OperatorMultiInplace<S, D> for T
93-
where
94-
S: DataMut<Elem = A>,
95-
D: Dimension + RemoveAxis,
96-
for<'a> T: OperatorInplace<ViewRepr<&'a mut A>, D::Smaller>,
97-
{
98-
fn op_multi_inplace<'a>(&self, a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D> {
99-
let n = a.ndim();
100-
for mut col in a.axis_iter_mut(Axis(n - 1)) {
101-
self.op_inplace(&mut col);
102-
}
103-
a
76+
fn apply<S>(&self, a: &ArrayBase<S, Ix1>) -> Array1<A>
77+
where
78+
S: Data<Elem = A>,
79+
{
80+
self.dot(a)
10481
}
10582
}

0 commit comments

Comments
 (0)