Skip to content

Commit b75d438

Browse files
committed
FEAT: New method names sum → sum_axis, mean → mean_axis
Old names are deprecated. This standardizes on a *_axis convention for these methods (more: map_axis, fold_axis). This makes room for .scalar_sum to take the name sum at a later date, and it makes room for introducing both std and std_axis for standard dev.
1 parent 5c5bdac commit b75d438

File tree

7 files changed

+42
-24
lines changed

7 files changed

+42
-24
lines changed

benches/bench1.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -961,23 +961,23 @@ fn range_mat(m: Ix, n: Ix) -> Array2<f32> {
961961
#[bench]
962962
fn mean_axis0(bench: &mut test::Bencher) {
963963
let a = range_mat(MEAN_SUM_N, MEAN_SUM_N);
964-
bench.iter(|| a.mean(Axis(0)));
964+
bench.iter(|| a.mean_axis(Axis(0)));
965965
}
966966

967967
#[bench]
968968
fn mean_axis1(bench: &mut test::Bencher) {
969969
let a = range_mat(MEAN_SUM_N, MEAN_SUM_N);
970-
bench.iter(|| a.mean(Axis(1)));
970+
bench.iter(|| a.mean_axis(Axis(1)));
971971
}
972972

973973
#[bench]
974974
fn sum_axis0(bench: &mut test::Bencher) {
975975
let a = range_mat(MEAN_SUM_N, MEAN_SUM_N);
976-
bench.iter(|| a.sum(Axis(0)));
976+
bench.iter(|| a.sum_axis(Axis(0)));
977977
}
978978

979979
#[bench]
980980
fn sum_axis1(bench: &mut test::Bencher) {
981981
let a = range_mat(MEAN_SUM_N, MEAN_SUM_N);
982-
bench.iter(|| a.sum(Axis(1)));
982+
bench.iter(|| a.sum_axis(Axis(1)));
983983
}

examples/column_standardize.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ fn main() {
2525
[ 2., 2., 2.]];
2626

2727
println!("{:8.4}", data);
28-
println!("{:8.4} (Mean axis=0)", data.mean(Axis(0)));
28+
println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0)));
2929

30-
data -= &data.mean(Axis(0));
30+
data -= &data.mean_axis(Axis(0));
3131
println!("{:8.4}", data);
3232

3333
data /= &std(&data, Axis(0));

src/numeric/impl_numeric.rs

+26-8
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@ impl<A, S, D> ArrayBase<S, D>
5858
/// let a = arr2(&[[1., 2.],
5959
/// [3., 4.]]);
6060
/// assert!(
61-
/// a.sum(Axis(0)) == aview1(&[4., 6.]) &&
62-
/// a.sum(Axis(1)) == aview1(&[3., 7.]) &&
61+
/// a.sum_axis(Axis(0)) == aview1(&[4., 6.]) &&
62+
/// a.sum_axis(Axis(1)) == aview1(&[3., 7.]) &&
6363
///
64-
/// a.sum(Axis(0)).sum(Axis(0)) == aview0(&10.)
64+
/// a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&10.)
6565
/// );
6666
/// ```
6767
///
6868
/// **Panics** if `axis` is out of bounds.
69-
pub fn sum(&self, axis: Axis) -> Array<A, D::Smaller>
69+
pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
7070
where A: Clone + Zero + Add<Output=A>,
7171
D: RemoveAxis,
7272
{
@@ -88,6 +88,15 @@ impl<A, S, D> ArrayBase<S, D>
8888
res
8989
}
9090

91+
/// Old name for `sum_axis`.
92+
#[deprecated(note="Use new name .sum_axis()")]
93+
pub fn sum(&self, axis: Axis) -> Array<A, D::Smaller>
94+
where A: Clone + Zero + Add<Output=A>,
95+
D: RemoveAxis,
96+
{
97+
self.sum_axis(axis)
98+
}
99+
91100
/// Return mean along `axis`.
92101
///
93102
/// **Panics** if `axis` is out of bounds.
@@ -98,23 +107,32 @@ impl<A, S, D> ArrayBase<S, D>
98107
/// let a = arr2(&[[1., 2.],
99108
/// [3., 4.]]);
100109
/// assert!(
101-
/// a.mean(Axis(0)) == aview1(&[2.0, 3.0]) &&
102-
/// a.mean(Axis(1)) == aview1(&[1.5, 3.5])
110+
/// a.mean_axis(Axis(0)) == aview1(&[2.0, 3.0]) &&
111+
/// a.mean_axis(Axis(1)) == aview1(&[1.5, 3.5])
103112
/// );
104113
/// ```
105-
pub fn mean(&self, axis: Axis) -> Array<A, D::Smaller>
114+
pub fn mean_axis(&self, axis: Axis) -> Array<A, D::Smaller>
106115
where A: LinalgScalar,
107116
D: RemoveAxis,
108117
{
109118
let n = self.shape().axis(axis);
110-
let sum = self.sum(axis);
119+
let sum = self.sum_axis(axis);
111120
let mut cnt = A::one();
112121
for _ in 1..n {
113122
cnt = cnt + A::one();
114123
}
115124
sum / &aview0(&cnt)
116125
}
117126

127+
/// Old name for `mean_axis`.
128+
#[deprecated(note="Use new name .mean_axis()")]
129+
pub fn mean(&self, axis: Axis) -> Array<A, D::Smaller>
130+
where A: LinalgScalar,
131+
D: RemoveAxis,
132+
{
133+
self.mean_axis(axis)
134+
}
135+
118136
/// Return `true` if the arrays' elementwise differences are all within
119137
/// the given absolute tolerance, `false` otherwise.
120138
///

src/zip/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,8 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> {
444444
/// *e = row.scalar_sum();
445445
/// });
446446
///
447-
/// // Check the result against the built in `.sum()` along axis 1.
448-
/// assert_eq!(e, a.sum(Axis(1)));
447+
/// // Check the result against the built in `.sum_axis()` along axis 1.
448+
/// assert_eq!(e, a.sum_axis(Axis(1)));
449449
/// ```
450450
#[derive(Debug, Clone)]
451451
pub struct Zip<Parts, D> {

tests/array.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -340,12 +340,12 @@ fn assign()
340340
fn sum_mean()
341341
{
342342
let a = arr2(&[[1., 2.], [3., 4.]]);
343-
assert_eq!(a.sum(Axis(0)), arr1(&[4., 6.]));
344-
assert_eq!(a.sum(Axis(1)), arr1(&[3., 7.]));
345-
assert_eq!(a.mean(Axis(0)), arr1(&[2., 3.]));
346-
assert_eq!(a.mean(Axis(1)), arr1(&[1.5, 3.5]));
347-
assert_eq!(a.sum(Axis(1)).sum(Axis(0)), arr0(10.));
348-
assert_eq!(a.view().mean(Axis(1)), aview1(&[1.5, 3.5]));
343+
assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.]));
344+
assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.]));
345+
assert_eq!(a.mean_axis(Axis(0)), arr1(&[2., 3.]));
346+
assert_eq!(a.mean_axis(Axis(1)), arr1(&[1.5, 3.5]));
347+
assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.));
348+
assert_eq!(a.view().mean_axis(Axis(1)), aview1(&[1.5, 3.5]));
349349
assert_eq!(a.scalar_sum(), 10.);
350350
}
351351

tests/azip.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ fn test_azip2_sum() {
5252
let ax = Axis(i);
5353
let mut b = Array::zeros(c.len_of(ax));
5454
azip!(mut b, ref c (c.axis_iter(ax)) in { *b = c.scalar_sum() });
55-
assert!(b.all_close(&c.sum(Axis(1 - i)), 1e-6));
55+
assert!(b.all_close(&c.sum_axis(Axis(1 - i)), 1e-6));
5656
}
5757
}
5858

tests/complex.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ fn complex_mat_mul()
2222
let r = a.dot(&e);
2323
println!("{}", a);
2424
assert_eq!(r, a);
25-
assert_eq!(a.mean(Axis(0)), arr1(&[c(1.5, 1.), c(2.5, 0.)]));
25+
assert_eq!(a.mean_axis(Axis(0)), arr1(&[c(1.5, 1.), c(2.5, 0.)]));
2626
}

0 commit comments

Comments
 (0)