Skip to content

Commit 2d9b15b

Browse files
committed
Add .sum(Axis) benchmarks
1 parent 4162bab commit 2d9b15b

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

benches/bench1.rs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use rblas::matrix::Matrix;
1212
use ndarray::{
1313
OwnedArray,
1414
Axis,
15+
Ix,
1516
};
1617
use ndarray::{arr0, arr1, arr2};
1718

@@ -642,9 +643,33 @@ fn dot_extended(bench: &mut test::Bencher) {
642643
})
643644
}
644645

646+
const MEAN_SUM_N: usize = 127;
647+
648+
fn range_mat(m: Ix, n: Ix) -> OwnedArray<f32, (Ix, Ix)> {
649+
assert!(m * n != 0);
650+
OwnedArray::linspace(0., (m * n - 1) as f32, m * n).into_shape((m, n)).unwrap()
651+
}
652+
645653
#[bench]
646-
fn means(bench: &mut test::Bencher) {
647-
let a = OwnedArray::from_iter(0..100_000i64);
648-
let a = a.into_shape((100, 1000)).unwrap();
654+
fn mean_axis0(bench: &mut test::Bencher) {
655+
let a = range_mat(MEAN_SUM_N, MEAN_SUM_N);
649656
bench.iter(|| a.mean(Axis(0)));
650657
}
658+
659+
#[bench]
660+
fn mean_axis1(bench: &mut test::Bencher) {
661+
let a = range_mat(MEAN_SUM_N, MEAN_SUM_N);
662+
bench.iter(|| a.mean(Axis(1)));
663+
}
664+
665+
#[bench]
666+
fn sum_axis0(bench: &mut test::Bencher) {
667+
let a = range_mat(MEAN_SUM_N, MEAN_SUM_N);
668+
bench.iter(|| a.sum(Axis(0)));
669+
}
670+
671+
#[bench]
672+
fn sum_axis1(bench: &mut test::Bencher) {
673+
let a = range_mat(MEAN_SUM_N, MEAN_SUM_N);
674+
bench.iter(|| a.sum(Axis(1)));
675+
}

0 commit comments

Comments
 (0)