@@ -2,7 +2,8 @@ use ndarray::*;
2
2
use ndarray_linalg:: * ;
3
3
use std:: cmp:: min;
4
4
5
- fn test ( a : & Array2 < f64 > , n : usize , m : usize ) {
5
+ fn test ( a : & Array2 < f64 > ) {
6
+ let ( n, m) = a. dim ( ) ;
6
7
let answer = a. clone ( ) ;
7
8
println ! ( "a = \n {:?}" , a) ;
8
9
let ( u, s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( true , true ) . unwrap ( ) ;
@@ -18,38 +19,62 @@ fn test(a: &Array2<f64>, n: usize, m: usize) {
18
19
assert_close_l2 ! ( & u. dot( & sm) . dot( & vt) , & answer, 1e-7 ) ;
19
20
}
20
21
21
- #[ test]
22
- fn svd_square ( ) {
23
- let a = random ( ( 3 , 3 ) ) ;
24
- test ( & a, 3 , 3 ) ;
22
+ fn test_no_vt ( a : & Array2 < f64 > ) {
23
+ let ( n, _m) = a. dim ( ) ;
24
+ println ! ( "a = \n {:?}" , a) ;
25
+ let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( true , false ) . unwrap ( ) ;
26
+ assert ! ( u. is_some( ) ) ;
27
+ assert ! ( vt. is_none( ) ) ;
28
+ let u = u. unwrap ( ) ;
29
+ assert_eq ! ( u. dim( ) . 0 , n) ;
30
+ assert_eq ! ( u. dim( ) . 1 , n) ;
25
31
}
26
32
27
- #[ test]
28
- fn svd_square_t ( ) {
29
- let a = random ( ( 3 , 3 ) . f ( ) ) ;
30
- test ( & a, 3 , 3 ) ;
33
+ fn test_no_u ( a : & Array2 < f64 > ) {
34
+ let ( _n, m) = a. dim ( ) ;
35
+ println ! ( "a = \n {:?}" , a) ;
36
+ let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( false , true ) . unwrap ( ) ;
37
+ assert ! ( u. is_none( ) ) ;
38
+ assert ! ( vt. is_some( ) ) ;
39
+ let vt = vt. unwrap ( ) ;
40
+ assert_eq ! ( vt. dim( ) . 0 , m) ;
41
+ assert_eq ! ( vt. dim( ) . 1 , m) ;
31
42
}
32
43
33
- #[ test]
34
- fn svd_3x4 ( ) {
35
- let a = random ( ( 3 , 4 ) ) ;
36
- test ( & a, 3 , 4 ) ;
44
+ fn test_diag_only ( a : & Array2 < f64 > ) {
45
+ println ! ( "a = \n {:?}" , a) ;
46
+ let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( false , false ) . unwrap ( ) ;
47
+ assert ! ( u. is_none( ) ) ;
48
+ assert ! ( vt. is_none( ) ) ;
37
49
}
38
50
39
- #[ test]
40
- fn svd_3x4_t ( ) {
41
- let a = random ( ( 3 , 4 ) . f ( ) ) ;
42
- test ( & a, 3 , 4 ) ;
43
- }
51
+ macro_rules! test_svd_impl {
52
+ ( $test: ident, $n: expr, $m: expr) => {
53
+ paste:: item! {
54
+ #[ test]
55
+ fn [ <svd_ $test _ $n x $m>] ( ) {
56
+ let a = random( ( $n, $m) ) ;
57
+ $test( & a) ;
58
+ }
44
59
45
- #[ test]
46
- fn svd_4x3 ( ) {
47
- let a = random ( ( 4 , 3 ) ) ;
48
- test ( & a, 4 , 3 ) ;
60
+ #[ test]
61
+ fn [ <svd_ $test _ $n x $m _t>] ( ) {
62
+ let a = random( ( $n, $m) . f( ) ) ;
63
+ $test( & a) ;
64
+ }
65
+ }
66
+ } ;
49
67
}
50
68
51
- #[ test]
52
- fn svd_4x3_t ( ) {
53
- let a = random ( ( 4 , 3 ) . f ( ) ) ;
54
- test ( & a, 4 , 3 ) ;
55
- }
69
+ test_svd_impl ! ( test, 3 , 3 ) ;
70
+ test_svd_impl ! ( test_no_vt, 3 , 3 ) ;
71
+ test_svd_impl ! ( test_no_u, 3 , 3 ) ;
72
+ test_svd_impl ! ( test_diag_only, 3 , 3 ) ;
73
+ test_svd_impl ! ( test, 4 , 3 ) ;
74
+ test_svd_impl ! ( test_no_vt, 4 , 3 ) ;
75
+ test_svd_impl ! ( test_no_u, 4 , 3 ) ;
76
+ test_svd_impl ! ( test_diag_only, 4 , 3 ) ;
77
+ test_svd_impl ! ( test, 3 , 4 ) ;
78
+ test_svd_impl ! ( test_no_vt, 3 , 4 ) ;
79
+ test_svd_impl ! ( test_no_u, 3 , 4 ) ;
80
+ test_svd_impl ! ( test_diag_only, 3 , 4 ) ;
0 commit comments