Skip to content

Commit 2924f2e

Browse files
Andrewjturner314
Andrew
authored andcommitted
Support zero-length axis in .map_axis/_mut() (#612)
1 parent 47b2691 commit 2924f2e

File tree

3 files changed

+47
-15
lines changed

3 files changed

+47
-15
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ serde = { version = "1.0", optional = true }
4747
defmac = "0.2"
4848
quickcheck = { version = "0.7.2", default-features = false }
4949
rawpointer = "0.1"
50+
itertools = { version = "0.7.0", default-features = false, features = ["use_std"] }
5051
approx = "0.3"
5152

5253
[features]

src/impl_methods.rs

+24-14
Original file line numberDiff line numberDiff line change
@@ -2102,13 +2102,18 @@ where
21022102
{
21032103
let view_len = self.len_of(axis);
21042104
let view_stride = self.strides.axis(axis);
2105-
// use the 0th subview as a map to each 1d array view extended from
2106-
// the 0th element.
2107-
self.index_axis(axis, 0).map(|first_elt| {
2108-
unsafe {
2109-
mapping(ArrayView::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
2110-
}
2111-
})
2105+
if view_len == 0 {
2106+
let new_dim = self.dim.remove_axis(axis);
2107+
Array::from_shape_fn(new_dim, move |_| mapping(ArrayView::from(&[])))
2108+
} else {
2109+
// use the 0th subview as a map to each 1d array view extended from
2110+
// the 0th element.
2111+
self.index_axis(axis, 0).map(|first_elt| {
2112+
unsafe {
2113+
mapping(ArrayView::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
2114+
}
2115+
})
2116+
}
21122117
}
21132118

21142119
/// Reduce the values along an axis into just one value, producing a new
@@ -2130,12 +2135,17 @@ where
21302135
{
21312136
let view_len = self.len_of(axis);
21322137
let view_stride = self.strides.axis(axis);
2133-
// use the 0th subview as a map to each 1d array view extended from
2134-
// the 0th element.
2135-
self.index_axis_mut(axis, 0).map_mut(|first_elt: &mut A| {
2136-
unsafe {
2137-
mapping(ArrayViewMut::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
2138-
}
2139-
})
2138+
if view_len == 0 {
2139+
let new_dim = self.dim.remove_axis(axis);
2140+
Array::from_shape_fn(new_dim, move |_| mapping(ArrayViewMut::from(&mut [])))
2141+
} else {
2142+
// use the 0th subview as a map to each 1d array view extended from
2143+
// the 0th element.
2144+
self.index_axis_mut(axis, 0).map_mut(|first_elt| {
2145+
unsafe {
2146+
mapping(ArrayViewMut::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
2147+
}
2148+
})
2149+
}
21402150
}
21412151
}

tests/array.rs

+22-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use ndarray::{
1313
};
1414
use ndarray::indices;
1515
use defmac::defmac;
16-
use itertools::{enumerate, zip};
16+
use itertools::{enumerate, zip, Itertools};
1717

1818
macro_rules! assert_panics {
1919
($body:expr) => {
@@ -1833,6 +1833,27 @@ fn test_map_axis() {
18331833
let c = a.map_axis(Axis(1), |view| view.sum());
18341834
let answer2 = arr1(&[6, 15, 24, 33]);
18351835
assert_eq!(c, answer2);
1836+
1837+
// Test zero-length axis case
1838+
let arr = Array3::<f32>::zeros((3, 0, 4));
1839+
let mut counter = 0;
1840+
let result = arr.map_axis(Axis(1), |x| {
1841+
assert_eq!(x.shape(), &[0]);
1842+
counter += 1;
1843+
counter
1844+
});
1845+
assert_eq!(result.shape(), &[3, 4]);
1846+
itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4);
1847+
1848+
let mut arr = Array3::<f32>::zeros((3, 0, 4));
1849+
let mut counter = 0;
1850+
let result = arr.map_axis_mut(Axis(1), |x| {
1851+
assert_eq!(x.shape(), &[0]);
1852+
counter += 1;
1853+
counter
1854+
});
1855+
assert_eq!(result.shape(), &[3, 4]);
1856+
itertools::assert_equal(result.iter().cloned().sorted(), 1..=3 * 4);
18361857
}
18371858

18381859
#[test]

0 commit comments

Comments
 (0)