Skip to content

Commit 47b2691

Browse files
authored
Merge pull request #580 from LukeMathWalker/mean
Mean
2 parents c569f75 + 64b3da7 commit 47b2691

File tree

6 files changed

+250
-181
lines changed

6 files changed

+250
-181
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ serde = { version = "1.0", optional = true }
4747
defmac = "0.2"
4848
quickcheck = { version = "0.7.2", default-features = false }
4949
rawpointer = "0.1"
50+
approx = "0.3"
5051

5152
[features]
5253
# Enable blas usage

examples/column_standardize.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ fn main() {
2323
[ 2., 2., 2.]];
2424

2525
println!("{:8.4}", data);
26-
println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0)));
26+
println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0)).unwrap());
2727

28-
data -= &data.mean_axis(Axis(0));
28+
data -= &data.mean_axis(Axis(0)).unwrap();
2929
println!("{:8.4}", data);
3030

3131
data /= &std(&data, Axis(0));

src/numeric/impl_numeric.rs

+43-9
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,33 @@ impl<A, S, D> ArrayBase<S, D>
4646
sum
4747
}
4848

49+
/// Returns the [arithmetic mean] x̅ of all elements in the array:
50+
///
51+
/// ```text
52+
/// 1 n
53+
/// x̅ = ― ∑ xᵢ
54+
/// n i=1
55+
/// ```
56+
///
57+
/// If the array is empty, `None` is returned.
58+
///
59+
/// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
60+
///
61+
/// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean
62+
pub fn mean(&self) -> Option<A>
63+
where
64+
A: Clone + FromPrimitive + Add<Output=A> + Div<Output=A> + Zero
65+
{
66+
let n_elements = self.len();
67+
if n_elements == 0 {
68+
None
69+
} else {
70+
let n_elements = A::from_usize(n_elements)
71+
.expect("Converting number of elements to `A` must not fail.");
72+
Some(self.sum() / n_elements)
73+
}
74+
}
75+
4976
/// Return the sum of all elements in the array.
5077
///
5178
/// *This method has been renamed to `.sum()` and will be deprecated in the
@@ -123,8 +150,9 @@ impl<A, S, D> ArrayBase<S, D>
123150

124151
/// Return mean along `axis`.
125152
///
126-
/// **Panics** if `axis` is out of bounds, if the length of the axis is
127-
/// zero and division by zero panics for type `A`, or if `A::from_usize()`
153+
/// Return `None` if the length of the axis is zero.
154+
///
155+
/// **Panics** if `axis` is out of bounds or if `A::from_usize()`
128156
/// fails for the axis length.
129157
///
130158
/// ```
@@ -133,19 +161,25 @@ impl<A, S, D> ArrayBase<S, D>
133161
/// let a = arr2(&[[1., 2., 3.],
134162
/// [4., 5., 6.]]);
135163
/// assert!(
136-
/// a.mean_axis(Axis(0)) == aview1(&[2.5, 3.5, 4.5]) &&
137-
/// a.mean_axis(Axis(1)) == aview1(&[2., 5.]) &&
164+
/// a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
165+
/// a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
138166
///
139-
/// a.mean_axis(Axis(0)).mean_axis(Axis(0)) == aview0(&3.5)
167+
/// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
140168
/// );
141169
/// ```
142-
pub fn mean_axis(&self, axis: Axis) -> Array<A, D::Smaller>
170+
pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
143171
where A: Clone + Zero + FromPrimitive + Add<Output=A> + Div<Output=A>,
144172
D: RemoveAxis,
145173
{
146-
let n = A::from_usize(self.len_of(axis)).expect("Converting axis length to `A` must not fail.");
147-
let sum = self.sum_axis(axis);
148-
sum / &aview0(&n)
174+
let axis_length = self.len_of(axis);
175+
if axis_length == 0 {
176+
None
177+
} else {
178+
let axis_length = A::from_usize(axis_length)
179+
.expect("Converting axis length to `A` must not fail.");
180+
let sum = self.sum_axis(axis);
181+
Some(sum / &aview0(&axis_length))
182+
}
149183
}
150184

151185
/// Return variance along `axis`.

tests/array.rs

-169
Original file line numberDiff line numberDiff line change
@@ -925,175 +925,6 @@ fn assign()
925925
assert_eq!(a, arr2(&[[0, 0], [3, 4]]));
926926
}
927927

928-
#[test]
929-
fn sum_mean()
930-
{
931-
let a = arr2(&[[1., 2.], [3., 4.]]);
932-
assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.]));
933-
assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.]));
934-
assert_eq!(a.mean_axis(Axis(0)), arr1(&[2., 3.]));
935-
assert_eq!(a.mean_axis(Axis(1)), arr1(&[1.5, 3.5]));
936-
assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.));
937-
assert_eq!(a.view().mean_axis(Axis(1)), aview1(&[1.5, 3.5]));
938-
assert_eq!(a.sum(), 10.);
939-
}
940-
941-
#[test]
942-
fn sum_mean_empty() {
943-
assert_eq!(Array3::<f32>::ones((2, 0, 3)).sum(), 0.);
944-
assert_eq!(Array1::<f32>::ones(0).sum_axis(Axis(0)), arr0(0.));
945-
assert_eq!(
946-
Array3::<f32>::ones((2, 0, 3)).sum_axis(Axis(1)),
947-
Array::zeros((2, 3)),
948-
);
949-
let a = Array1::<f32>::ones(0).mean_axis(Axis(0));
950-
assert_eq!(a.shape(), &[]);
951-
assert!(a[()].is_nan());
952-
let a = Array3::<f32>::ones((2, 0, 3)).mean_axis(Axis(1));
953-
assert_eq!(a.shape(), &[2, 3]);
954-
a.mapv(|x| assert!(x.is_nan()));
955-
}
956-
957-
#[test]
958-
fn var_axis() {
959-
let a = array![
960-
[
961-
[-9.76, -0.38, 1.59, 6.23],
962-
[-8.57, -9.27, 5.76, 6.01],
963-
[-9.54, 5.09, 3.21, 6.56],
964-
],
965-
[
966-
[ 8.23, -9.63, 3.76, -3.48],
967-
[-5.46, 5.86, -2.81, 1.35],
968-
[-1.08, 4.66, 8.34, -0.73],
969-
],
970-
];
971-
assert!(a.var_axis(Axis(0), 1.5).all_close(
972-
&aview2(&[
973-
[3.236401e+02, 8.556250e+01, 4.708900e+00, 9.428410e+01],
974-
[9.672100e+00, 2.289169e+02, 7.344490e+01, 2.171560e+01],
975-
[7.157160e+01, 1.849000e-01, 2.631690e+01, 5.314410e+01]
976-
]),
977-
1e-4,
978-
));
979-
assert!(a.var_axis(Axis(1), 1.7).all_close(
980-
&aview2(&[
981-
[0.61676923, 80.81092308, 6.79892308, 0.11789744],
982-
[75.19912821, 114.25235897, 48.32405128, 9.03020513],
983-
]),
984-
1e-8,
985-
));
986-
assert!(a.var_axis(Axis(2), 2.3).all_close(
987-
&aview2(&[
988-
[ 79.64552941, 129.09663235, 95.98929412],
989-
[109.64952941, 43.28758824, 36.27439706],
990-
]),
991-
1e-8,
992-
));
993-
994-
let b = array![[1.1, 2.3, 4.7]];
995-
assert!(b.var_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12));
996-
assert!(b.var_axis(Axis(1), 0.).all_close(&aview1(&[2.24]), 1e-12));
997-
998-
let c = array![[], []];
999-
assert_eq!(c.var_axis(Axis(0), 0.), aview1(&[]));
1000-
1001-
let d = array![1.1, 2.7, 3.5, 4.9];
1002-
assert!(d.var_axis(Axis(0), 0.).all_close(&aview0(&1.8875), 1e-12));
1003-
}
1004-
1005-
#[test]
1006-
fn std_axis() {
1007-
let a = array![
1008-
[
1009-
[ 0.22935481, 0.08030619, 0.60827517, 0.73684379],
1010-
[ 0.90339851, 0.82859436, 0.64020362, 0.2774583 ],
1011-
[ 0.44485313, 0.63316367, 0.11005111, 0.08656246]
1012-
],
1013-
[
1014-
[ 0.28924665, 0.44082454, 0.59837736, 0.41014531],
1015-
[ 0.08382316, 0.43259439, 0.1428889 , 0.44830176],
1016-
[ 0.51529756, 0.70111616, 0.20799415, 0.91851457]
1017-
],
1018-
];
1019-
assert!(a.std_axis(Axis(0), 1.5).all_close(
1020-
&aview2(&[
1021-
[ 0.05989184, 0.36051836, 0.00989781, 0.32669847],
1022-
[ 0.81957535, 0.39599997, 0.49731472, 0.17084346],
1023-
[ 0.07044443, 0.06795249, 0.09794304, 0.83195211],
1024-
]),
1025-
1e-4,
1026-
));
1027-
assert!(a.std_axis(Axis(1), 1.7).all_close(
1028-
&aview2(&[
1029-
[ 0.42698655, 0.48139215, 0.36874991, 0.41458724],
1030-
[ 0.26769097, 0.18941435, 0.30555015, 0.35118674],
1031-
]),
1032-
1e-8,
1033-
));
1034-
assert!(a.std_axis(Axis(2), 2.3).all_close(
1035-
&aview2(&[
1036-
[ 0.41117907, 0.37130425, 0.35332388],
1037-
[ 0.16905862, 0.25304841, 0.39978276],
1038-
]),
1039-
1e-8,
1040-
));
1041-
1042-
let b = array![[100000., 1., 0.01]];
1043-
assert!(b.std_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12));
1044-
assert!(
1045-
b.std_axis(Axis(1), 0.).all_close(&aview1(&[47140.214021552769]), 1e-6),
1046-
);
1047-
1048-
let c = array![[], []];
1049-
assert_eq!(c.std_axis(Axis(0), 0.), aview1(&[]));
1050-
}
1051-
1052-
#[test]
1053-
#[should_panic]
1054-
fn var_axis_negative_ddof() {
1055-
let a = array![1., 2., 3.];
1056-
a.var_axis(Axis(0), -1.);
1057-
}
1058-
1059-
#[test]
1060-
#[should_panic]
1061-
fn var_axis_too_large_ddof() {
1062-
let a = array![1., 2., 3.];
1063-
a.var_axis(Axis(0), 4.);
1064-
}
1065-
1066-
#[test]
1067-
fn var_axis_nan_ddof() {
1068-
let a = Array2::<f64>::zeros((2, 3));
1069-
let v = a.var_axis(Axis(1), ::std::f64::NAN);
1070-
assert_eq!(v.shape(), &[2]);
1071-
v.mapv(|x| assert!(x.is_nan()));
1072-
}
1073-
1074-
#[test]
1075-
fn var_axis_empty_axis() {
1076-
let a = Array2::<f64>::zeros((2, 0));
1077-
let v = a.var_axis(Axis(1), 0.);
1078-
assert_eq!(v.shape(), &[2]);
1079-
v.mapv(|x| assert!(x.is_nan()));
1080-
}
1081-
1082-
#[test]
1083-
#[should_panic]
1084-
fn std_axis_bad_dof() {
1085-
let a = array![1., 2., 3.];
1086-
a.std_axis(Axis(0), 4.);
1087-
}
1088-
1089-
#[test]
1090-
fn std_axis_empty_axis() {
1091-
let a = Array2::<f64>::zeros((2, 0));
1092-
let v = a.std_axis(Axis(1), 0.);
1093-
assert_eq!(v.shape(), &[2]);
1094-
v.mapv(|x| assert!(x.is_nan()));
1095-
}
1096-
1097928
#[test]
1098929
fn iter_size_hint()
1099930
{

tests/complex.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ fn complex_mat_mul()
2222
let r = a.dot(&e);
2323
println!("{}", a);
2424
assert_eq!(r, a);
25-
assert_eq!(a.mean_axis(Axis(0)), arr1(&[c(1.5, 1.), c(2.5, 0.)]));
25+
assert_eq!(a.mean_axis(Axis(0)).unwrap(), arr1(&[c(1.5, 1.), c(2.5, 0.)]));
2626
}

0 commit comments

Comments
 (0)