1
1
extern crate approx;
2
2
use std:: f64;
3
- use ndarray:: { Array1 , array } ;
3
+ use ndarray:: { array , Axis , aview1 , aview2 , aview0 , arr0 , arr1 , arr2 , Array , Array1 , Array2 , Array3 } ;
4
4
use approx:: abs_diff_eq;
5
5
6
6
#[ test]
@@ -32,4 +32,172 @@ fn test_mean_with_array_of_floats() {
32
32
// Computed using NumPy
33
33
let expected_mean = 0.5475494059146699 ;
34
34
abs_diff_eq ! ( a. mean( ) . unwrap( ) , expected_mean, epsilon = f64 :: EPSILON ) ;
35
- }
35
+ }
36
+
37
+ #[ test]
38
+ fn sum_mean ( )
39
+ {
40
+ let a = arr2 ( & [ [ 1. , 2. ] , [ 3. , 4. ] ] ) ;
41
+ assert_eq ! ( a. sum_axis( Axis ( 0 ) ) , arr1( & [ 4. , 6. ] ) ) ;
42
+ assert_eq ! ( a. sum_axis( Axis ( 1 ) ) , arr1( & [ 3. , 7. ] ) ) ;
43
+ assert_eq ! ( a. mean_axis( Axis ( 0 ) ) , Some ( arr1( & [ 2. , 3. ] ) ) ) ;
44
+ assert_eq ! ( a. mean_axis( Axis ( 1 ) ) , Some ( arr1( & [ 1.5 , 3.5 ] ) ) ) ;
45
+ assert_eq ! ( a. sum_axis( Axis ( 1 ) ) . sum_axis( Axis ( 0 ) ) , arr0( 10. ) ) ;
46
+ assert_eq ! ( a. view( ) . mean_axis( Axis ( 1 ) ) . unwrap( ) , aview1( & [ 1.5 , 3.5 ] ) ) ;
47
+ assert_eq ! ( a. sum( ) , 10. ) ;
48
+ }
49
+
50
+ #[ test]
51
+ fn sum_mean_empty ( ) {
52
+ assert_eq ! ( Array3 :: <f32 >:: ones( ( 2 , 0 , 3 ) ) . sum( ) , 0. ) ;
53
+ assert_eq ! ( Array1 :: <f32 >:: ones( 0 ) . sum_axis( Axis ( 0 ) ) , arr0( 0. ) ) ;
54
+ assert_eq ! (
55
+ Array3 :: <f32 >:: ones( ( 2 , 0 , 3 ) ) . sum_axis( Axis ( 1 ) ) ,
56
+ Array :: zeros( ( 2 , 3 ) ) ,
57
+ ) ;
58
+ let a = Array1 :: < f32 > :: ones ( 0 ) . mean_axis ( Axis ( 0 ) ) ;
59
+ assert_eq ! ( a, None ) ;
60
+ let a = Array3 :: < f32 > :: ones ( ( 2 , 0 , 3 ) ) . mean_axis ( Axis ( 1 ) ) ;
61
+ assert_eq ! ( a, None ) ;
62
+ }
63
+
64
+ #[ test]
65
+ fn var_axis ( ) {
66
+ let a = array ! [
67
+ [
68
+ [ -9.76 , -0.38 , 1.59 , 6.23 ] ,
69
+ [ -8.57 , -9.27 , 5.76 , 6.01 ] ,
70
+ [ -9.54 , 5.09 , 3.21 , 6.56 ] ,
71
+ ] ,
72
+ [
73
+ [ 8.23 , -9.63 , 3.76 , -3.48 ] ,
74
+ [ -5.46 , 5.86 , -2.81 , 1.35 ] ,
75
+ [ -1.08 , 4.66 , 8.34 , -0.73 ] ,
76
+ ] ,
77
+ ] ;
78
+ assert ! ( a. var_axis( Axis ( 0 ) , 1.5 ) . all_close(
79
+ & aview2( & [
80
+ [ 3.236401e+02 , 8.556250e+01 , 4.708900e+00 , 9.428410e+01 ] ,
81
+ [ 9.672100e+00 , 2.289169e+02 , 7.344490e+01 , 2.171560e+01 ] ,
82
+ [ 7.157160e+01 , 1.849000e-01 , 2.631690e+01 , 5.314410e+01 ]
83
+ ] ) ,
84
+ 1e-4 ,
85
+ ) ) ;
86
+ assert ! ( a. var_axis( Axis ( 1 ) , 1.7 ) . all_close(
87
+ & aview2( & [
88
+ [ 0.61676923 , 80.81092308 , 6.79892308 , 0.11789744 ] ,
89
+ [ 75.19912821 , 114.25235897 , 48.32405128 , 9.03020513 ] ,
90
+ ] ) ,
91
+ 1e-8 ,
92
+ ) ) ;
93
+ assert ! ( a. var_axis( Axis ( 2 ) , 2.3 ) . all_close(
94
+ & aview2( & [
95
+ [ 79.64552941 , 129.09663235 , 95.98929412 ] ,
96
+ [ 109.64952941 , 43.28758824 , 36.27439706 ] ,
97
+ ] ) ,
98
+ 1e-8 ,
99
+ ) ) ;
100
+
101
+ let b = array ! [ [ 1.1 , 2.3 , 4.7 ] ] ;
102
+ assert ! ( b. var_axis( Axis ( 0 ) , 0. ) . all_close( & aview1( & [ 0. , 0. , 0. ] ) , 1e-12 ) ) ;
103
+ assert ! ( b. var_axis( Axis ( 1 ) , 0. ) . all_close( & aview1( & [ 2.24 ] ) , 1e-12 ) ) ;
104
+
105
+ let c = array ! [ [ ] , [ ] ] ;
106
+ assert_eq ! ( c. var_axis( Axis ( 0 ) , 0. ) , aview1( & [ ] ) ) ;
107
+
108
+ let d = array ! [ 1.1 , 2.7 , 3.5 , 4.9 ] ;
109
+ assert ! ( d. var_axis( Axis ( 0 ) , 0. ) . all_close( & aview0( & 1.8875 ) , 1e-12 ) ) ;
110
+ }
111
+
112
+ #[ test]
113
+ fn std_axis ( ) {
114
+ let a = array ! [
115
+ [
116
+ [ 0.22935481 , 0.08030619 , 0.60827517 , 0.73684379 ] ,
117
+ [ 0.90339851 , 0.82859436 , 0.64020362 , 0.2774583 ] ,
118
+ [ 0.44485313 , 0.63316367 , 0.11005111 , 0.08656246 ]
119
+ ] ,
120
+ [
121
+ [ 0.28924665 , 0.44082454 , 0.59837736 , 0.41014531 ] ,
122
+ [ 0.08382316 , 0.43259439 , 0.1428889 , 0.44830176 ] ,
123
+ [ 0.51529756 , 0.70111616 , 0.20799415 , 0.91851457 ]
124
+ ] ,
125
+ ] ;
126
+ assert ! ( a. std_axis( Axis ( 0 ) , 1.5 ) . all_close(
127
+ & aview2( & [
128
+ [ 0.05989184 , 0.36051836 , 0.00989781 , 0.32669847 ] ,
129
+ [ 0.81957535 , 0.39599997 , 0.49731472 , 0.17084346 ] ,
130
+ [ 0.07044443 , 0.06795249 , 0.09794304 , 0.83195211 ] ,
131
+ ] ) ,
132
+ 1e-4 ,
133
+ ) ) ;
134
+ assert ! ( a. std_axis( Axis ( 1 ) , 1.7 ) . all_close(
135
+ & aview2( & [
136
+ [ 0.42698655 , 0.48139215 , 0.36874991 , 0.41458724 ] ,
137
+ [ 0.26769097 , 0.18941435 , 0.30555015 , 0.35118674 ] ,
138
+ ] ) ,
139
+ 1e-8 ,
140
+ ) ) ;
141
+ assert ! ( a. std_axis( Axis ( 2 ) , 2.3 ) . all_close(
142
+ & aview2( & [
143
+ [ 0.41117907 , 0.37130425 , 0.35332388 ] ,
144
+ [ 0.16905862 , 0.25304841 , 0.39978276 ] ,
145
+ ] ) ,
146
+ 1e-8 ,
147
+ ) ) ;
148
+
149
+ let b = array ! [ [ 100000. , 1. , 0.01 ] ] ;
150
+ assert ! ( b. std_axis( Axis ( 0 ) , 0. ) . all_close( & aview1( & [ 0. , 0. , 0. ] ) , 1e-12 ) ) ;
151
+ assert ! (
152
+ b. std_axis( Axis ( 1 ) , 0. ) . all_close( & aview1( & [ 47140.214021552769 ] ) , 1e-6 ) ,
153
+ ) ;
154
+
155
+ let c = array ! [ [ ] , [ ] ] ;
156
+ assert_eq ! ( c. std_axis( Axis ( 0 ) , 0. ) , aview1( & [ ] ) ) ;
157
+ }
158
+
159
+ #[ test]
160
+ #[ should_panic]
161
+ fn var_axis_negative_ddof ( ) {
162
+ let a = array ! [ 1. , 2. , 3. ] ;
163
+ a. var_axis ( Axis ( 0 ) , -1. ) ;
164
+ }
165
+
166
+ #[ test]
167
+ #[ should_panic]
168
+ fn var_axis_too_large_ddof ( ) {
169
+ let a = array ! [ 1. , 2. , 3. ] ;
170
+ a. var_axis ( Axis ( 0 ) , 4. ) ;
171
+ }
172
+
173
+ #[ test]
174
+ fn var_axis_nan_ddof ( ) {
175
+ let a = Array2 :: < f64 > :: zeros ( ( 2 , 3 ) ) ;
176
+ let v = a. var_axis ( Axis ( 1 ) , :: std:: f64:: NAN ) ;
177
+ assert_eq ! ( v. shape( ) , & [ 2 ] ) ;
178
+ v. mapv ( |x| assert ! ( x. is_nan( ) ) ) ;
179
+ }
180
+
181
+ #[ test]
182
+ fn var_axis_empty_axis ( ) {
183
+ let a = Array2 :: < f64 > :: zeros ( ( 2 , 0 ) ) ;
184
+ let v = a. var_axis ( Axis ( 1 ) , 0. ) ;
185
+ assert_eq ! ( v. shape( ) , & [ 2 ] ) ;
186
+ v. mapv ( |x| assert ! ( x. is_nan( ) ) ) ;
187
+ }
188
+
189
+ #[ test]
190
+ #[ should_panic]
191
+ fn std_axis_bad_dof ( ) {
192
+ let a = array ! [ 1. , 2. , 3. ] ;
193
+ a. std_axis ( Axis ( 0 ) , 4. ) ;
194
+ }
195
+
196
+ #[ test]
197
+ fn std_axis_empty_axis ( ) {
198
+ let a = Array2 :: < f64 > :: zeros ( ( 2 , 0 ) ) ;
199
+ let v = a. std_axis ( Axis ( 1 ) , 0. ) ;
200
+ assert_eq ! ( v. shape( ) , & [ 2 ] ) ;
201
+ v. mapv ( |x| assert ! ( x. is_nan( ) ) ) ;
202
+ }
203
+
0 commit comments