Skip to content

Commit a0bd205

Browse files
committed
Change sorting traits to use Order enum
1 parent 659ecc1 commit a0bd205

File tree

6 files changed

+28
-40
lines changed

6 files changed

+28
-40
lines changed

src/eigh.rs

+13-14
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
use ndarray::{s, Array1, Array2, ArrayBase, Data, DataMut, Ix2, NdFloat};
44

55
use crate::{
6-
check_square, givens::GivensRotation, index::*, tridiagonal::SymmetricTridiagonal, Result,
6+
check_square, givens::GivensRotation, index::*, tridiagonal::SymmetricTridiagonal, Order,
7+
Result,
78
};
89

910
fn symmetric_eig<A: NdFloat, S: DataMut<Elem = A>>(
@@ -272,44 +273,42 @@ impl<A: NdFloat, S: Data<Elem = A>> EigValsh for ArrayBase<S, Ix2> {
272273
///
273274
/// Will panic if shape or layout of inputs differ from eigen output, or if input contains NaN.
274275
pub trait EigSort: Sized {
275-
fn sort_eig(self, descending: bool) -> Self;
276+
fn sort_eig(self, order: Order) -> Self;
276277

277278
/// Sort eigendecomposition by the eigenvalues in ascending order
278279
fn sort_eig_asc(self) -> Self {
279-
self.sort_eig(false)
280+
self.sort_eig(Order::Smallest)
280281
}
281282

282283
/// Sort eigendecomposition by the eigenvalues in descending order
283284
fn sort_eig_desc(self) -> Self {
284-
self.sort_eig(true)
285+
self.sort_eig(Order::Largest)
285286
}
286287
}
287288

288289
/// Implementation on output of `EigValsh` traits
289290
impl<A: NdFloat> EigSort for Array1<A> {
290-
fn sort_eig(mut self, descending: bool) -> Self {
291+
fn sort_eig(mut self, order: Order) -> Self {
291292
// Panics on non-standard layouts, which is fine because our eigenvals have standard layout
292293
let slice = self.as_slice_mut().unwrap();
293294
// Panic only happens with NaN values
294-
if descending {
295-
slice.sort_by(|a, b| b.partial_cmp(a).unwrap());
296-
} else {
297-
slice.sort_by(|a, b| a.partial_cmp(b).unwrap());
295+
match order {
296+
Order::Largest => slice.sort_by(|a, b| b.partial_cmp(a).unwrap()),
297+
Order::Smallest => slice.sort_by(|a, b| a.partial_cmp(b).unwrap()),
298298
}
299299
self
300300
}
301301
}
302302

303303
/// Implementation on output of `Eigh` traits
304304
impl<A: NdFloat> EigSort for (Array1<A>, Array2<A>) {
305-
fn sort_eig(self, descending: bool) -> Self {
305+
fn sort_eig(self, order: Order) -> Self {
306306
let (mut vals, vecs) = self;
307307
let mut value_idx: Vec<_> = vals.iter().copied().enumerate().collect();
308308
// Panic only happens with NaN values
309-
if descending {
310-
value_idx.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
311-
} else {
312-
value_idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
309+
match order {
310+
Order::Largest => value_idx.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()),
311+
Order::Smallest => value_idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()),
313312
}
314313

315314
let mut out = Array2::zeros(vecs.dim());

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ pub(crate) fn check_square<S: RawData>(arr: &ArrayBase<S, Ix2>) -> Result<usize>
7070
}
7171

7272
/// Find largest or smallest eigenvalues
73+
///
74+
/// Corresponds to descending and ascending order
7375
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
7476
pub enum Order {
7577
Largest,

src/lobpcg/algorithm.rs

+2-11
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,16 @@ fn sorted_eig<A: NdFloat>(
3131
size: usize,
3232
order: Order,
3333
) -> Result<(Array1<A>, Array2<A>)> {
34-
let n = a.len_of(Axis(0));
35-
3634
let res = match b {
3735
Some(b) => generalized_eig(a, b)?,
3836
_ => a.eigh_into()?,
3937
};
4038

4139
// sort and ensure that signs are deterministic
42-
let (vals, vecs) = res.sort_eig(false);
40+
let (vals, vecs) = res.sort_eig(order);
4341
let s = vecs.row(0).mapv(|x| x.signum());
4442
let vecs = vecs * s;
45-
46-
Ok(match order {
47-
Order::Largest => (
48-
vals.slice_move(s![n-size..; -1]),
49-
vecs.slice_move(s![.., n-size..; -1]),
50-
),
51-
Order::Smallest => (vals.slice_move(s![..size]), vecs.slice_move(s![.., ..size])),
52-
})
43+
Ok((vals.slice_move(s![..size]), vecs.slice_move(s![.., ..size])))
5344
}
5445

5546
/// Masks a matrix with the given `matrix`

src/svd.rs

+8-9
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use ndarray::{s, Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix2, NdFloat};
88

99
use crate::{
1010
bidiagonal::Bidiagonal, eigh::wilkinson_shift, givens::GivensRotation, index::*, LinalgError,
11-
Result,
11+
Order, Result,
1212
};
1313

1414
fn svd<A: NdFloat, S: DataMut<Elem = A>>(
@@ -482,29 +482,28 @@ impl<A: NdFloat, S: Data<Elem = A>> SVD for ArrayBase<S, Ix2> {
482482
///
483483
/// Will panic if shape of inputs differs from shape of SVD output, or if input contains NaN.
484484
pub trait SvdSort: Sized {
485-
fn sort_svd(self, descending: bool) -> Self;
485+
fn sort_svd(self, order: Order) -> Self;
486486

487487
/// Sort SVD decomposition by the singular values in ascending order
488488
fn sort_svd_asc(self) -> Self {
489-
self.sort_svd(false)
489+
self.sort_svd(Order::Smallest)
490490
}
491491

492492
/// Sort SVD decomposition by the singular values in descending order
493493
fn sort_svd_desc(self) -> Self {
494-
self.sort_svd(true)
494+
self.sort_svd(Order::Largest)
495495
}
496496
}
497497

498498
/// Implemented on the output of the `SVD` traits
499499
impl<A: NdFloat> SvdSort for (Option<Array2<A>>, Array1<A>, Option<Array2<A>>) {
500-
fn sort_svd(self, descending: bool) -> Self {
500+
fn sort_svd(self, order: Order) -> Self {
501501
let (u, mut s, vt) = self;
502502
let mut value_idx: Vec<_> = s.iter().copied().enumerate().collect();
503503
// Panic only happens with NaN values
504-
if descending {
505-
value_idx.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
506-
} else {
507-
value_idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
504+
match order {
505+
Order::Largest => value_idx.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()),
506+
Order::Smallest => value_idx.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()),
508507
}
509508

510509
let apply_ordering = |arr: &Array2<A>, ax, values_idx: &Vec<_>| {

tests/lobpcg.rs

+2-5
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ fn test_marchenko_pastur() {
6767
}
6868

6969
fn run_lobpcg_eig_test(arr: Array2<f64>, num: usize, ordering: Order) {
70-
let (eigvals, _) = arr.eigh().unwrap().sort_eig(ordering == Order::Largest);
70+
let (eigvals, _) = arr.eigh().unwrap().sort_eig(ordering);
7171
let res = TruncatedEig::new_with_rng(arr.clone(), ordering, Xoshiro256Plus::seed_from_u64(42))
7272
.precision(1e-3)
7373
.decompose(num)
@@ -108,10 +108,7 @@ fn problematic_eig_matrix() {
108108
}
109109

110110
fn run_lobpcg_svd_test(arr: Array2<f64>, ordering: Order) {
111-
let (_, s, _) = arr
112-
.svd(false, false)
113-
.unwrap()
114-
.sort_svd(ordering == Order::Largest);
111+
let (_, s, _) = arr.svd(false, false).unwrap().sort_svd(ordering);
115112
let (u, ts, vt) =
116113
TruncatedSvd::new_with_rng(arr.clone(), ordering, Xoshiro256Plus::seed_from_u64(42))
117114
.precision(1e-3)

tests/svd.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ fn run_svd_test(arr: Array2<f64>) {
6161
proptest! {
6262
#![proptest_config(ProptestConfig::with_cases(1000))]
6363
#[test]
64-
fn bidiagonal_test(arr in common::rect_arr()) {
64+
fn svd_test(arr in common::rect_arr()) {
6565
run_svd_test(arr);
6666
}
6767
}

0 commit comments

Comments
 (0)