Skip to content

Commit bc795b8

Browse files
authored
Merge pull request #669 from jturner314/fix-axis-iter
Fix axis iterators
2 parents ce80d38 + 1443df8 commit bc795b8

File tree

3 files changed

+264
-44
lines changed

3 files changed

+264
-44
lines changed

src/impl_methods.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,7 @@ where
10051005
/// The last view may have less elements if `size` does not divide
10061006
/// the axis' dimension.
10071007
///
1008-
/// **Panics** if `axis` is out of bounds.
1008+
/// **Panics** if `axis` is out of bounds or if `size` is zero.
10091009
///
10101010
/// ```
10111011
/// use ndarray::Array;
@@ -1036,7 +1036,7 @@ where
10361036
///
10371037
/// Iterator element is `ArrayViewMut<A, D>`
10381038
///
1039-
/// **Panics** if `axis` is out of bounds.
1039+
/// **Panics** if `axis` is out of bounds or if `size` is zero.
10401040
pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<'_, A, D>
10411041
where
10421042
S: DataMut,

src/iterators/mod.rs

+90-41
Original file line numberDiff line numberDiff line change
@@ -738,11 +738,19 @@ where
738738

739739
#[derive(Debug)]
740740
pub struct AxisIterCore<A, D> {
741+
/// Index along the axis of the value of `.next()`, relative to the start
742+
/// of the axis.
741743
index: Ix,
742-
len: Ix,
744+
/// (Exclusive) upper bound on `index`. Initially, this is equal to the
745+
/// length of the axis.
746+
end: Ix,
747+
/// Stride along the axis (offset between consecutive pointers).
743748
stride: Ixs,
749+
/// Shape of the iterator's items.
744750
inner_dim: D,
751+
/// Strides of the iterator's items.
745752
inner_strides: D,
753+
/// Pointer corresponding to `index == 0`.
746754
ptr: *mut A,
747755
}
748756

@@ -751,7 +759,7 @@ clone_bounds!(
751759
AxisIterCore[A, D] {
752760
@copy {
753761
index,
754-
len,
762+
end,
755763
stride,
756764
ptr,
757765
}
@@ -767,54 +775,53 @@ impl<A, D: Dimension> AxisIterCore<A, D> {
767775
Di: RemoveAxis<Smaller = D>,
768776
S: Data<Elem = A>,
769777
{
770-
let shape = v.shape()[axis.index()];
771-
let stride = v.strides()[axis.index()];
772778
AxisIterCore {
773779
index: 0,
774-
len: shape,
775-
stride,
780+
end: v.len_of(axis),
781+
stride: v.stride_of(axis),
776782
inner_dim: v.dim.remove_axis(axis),
777783
inner_strides: v.strides.remove_axis(axis),
778784
ptr: v.ptr,
779785
}
780786
}
781787

788+
#[inline]
782789
unsafe fn offset(&self, index: usize) -> *mut A {
783790
debug_assert!(
784-
index <= self.len,
785-
"index={}, len={}, stride={}",
791+
index < self.end,
792+
"index={}, end={}, stride={}",
786793
index,
787-
self.len,
794+
self.end,
788795
self.stride
789796
);
790797
self.ptr.offset(index as isize * self.stride)
791798
}
792799

793-
/// Split the iterator at index, yielding two disjoint iterators.
800+
/// Splits the iterator at `index`, yielding two disjoint iterators.
794801
///
795-
/// **Panics** if `index` is strictly greater than the iterator's length
802+
/// `index` is relative to the current state of the iterator (which is not
803+
/// necessarily the start of the axis).
804+
///
805+
/// **Panics** if `index` is strictly greater than the iterator's remaining
806+
/// length.
796807
fn split_at(self, index: usize) -> (Self, Self) {
797-
assert!(index <= self.len);
798-
let right_ptr = if index != self.len {
799-
unsafe { self.offset(index) }
800-
} else {
801-
self.ptr
802-
};
808+
assert!(index <= self.len());
809+
let mid = self.index + index;
803810
let left = AxisIterCore {
804-
index: 0,
805-
len: index,
811+
index: self.index,
812+
end: mid,
806813
stride: self.stride,
807814
inner_dim: self.inner_dim.clone(),
808815
inner_strides: self.inner_strides.clone(),
809816
ptr: self.ptr,
810817
};
811818
let right = AxisIterCore {
812-
index: 0,
813-
len: self.len - index,
819+
index: mid,
820+
end: self.end,
814821
stride: self.stride,
815822
inner_dim: self.inner_dim,
816823
inner_strides: self.inner_strides,
817-
ptr: right_ptr,
824+
ptr: self.ptr,
818825
};
819826
(left, right)
820827
}
@@ -827,7 +834,7 @@ where
827834
type Item = *mut A;
828835

829836
fn next(&mut self) -> Option<Self::Item> {
830-
if self.index >= self.len {
837+
if self.index >= self.end {
831838
None
832839
} else {
833840
let ptr = unsafe { self.offset(self.index) };
@@ -837,7 +844,7 @@ where
837844
}
838845

839846
fn size_hint(&self) -> (usize, Option<usize>) {
840-
let len = self.len - self.index;
847+
let len = self.len();
841848
(len, Some(len))
842849
}
843850
}
@@ -847,16 +854,25 @@ where
847854
D: Dimension,
848855
{
849856
fn next_back(&mut self) -> Option<Self::Item> {
850-
if self.index >= self.len {
857+
if self.index >= self.end {
851858
None
852859
} else {
853-
self.len -= 1;
854-
let ptr = unsafe { self.offset(self.len) };
860+
let ptr = unsafe { self.offset(self.end - 1) };
861+
self.end -= 1;
855862
Some(ptr)
856863
}
857864
}
858865
}
859866

867+
impl<A, D> ExactSizeIterator for AxisIterCore<A, D>
868+
where
869+
D: Dimension,
870+
{
871+
fn len(&self) -> usize {
872+
self.end - self.index
873+
}
874+
}
875+
860876
/// An iterator that traverses over an axis and
861877
/// and yields each subview.
862878
///
@@ -899,9 +915,13 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> {
899915
}
900916
}
901917

902-
/// Split the iterator at index, yielding two disjoint iterators.
918+
/// Splits the iterator at `index`, yielding two disjoint iterators.
903919
///
904-
/// **Panics** if `index` is strictly greater than the iterator's length
920+
/// `index` is relative to the current state of the iterator (which is not
921+
/// necessarily the start of the axis).
922+
///
923+
/// **Panics** if `index` is strictly greater than the iterator's remaining
924+
/// length.
905925
pub fn split_at(self, index: usize) -> (Self, Self) {
906926
let (left, right) = self.iter.split_at(index);
907927
(
@@ -946,7 +966,7 @@ where
946966
D: Dimension,
947967
{
948968
fn len(&self) -> usize {
949-
self.size_hint().0
969+
self.iter.len()
950970
}
951971
}
952972

@@ -981,9 +1001,13 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> {
9811001
}
9821002
}
9831003

984-
/// Split the iterator at index, yielding two disjoint iterators.
1004+
/// Splits the iterator at `index`, yielding two disjoint iterators.
9851005
///
986-
/// **Panics** if `index` is strictly greater than the iterator's length
1006+
/// `index` is relative to the current state of the iterator (which is not
1007+
/// necessarily the start of the axis).
1008+
///
1009+
/// **Panics** if `index` is strictly greater than the iterator's remaining
1010+
/// length.
9871011
pub fn split_at(self, index: usize) -> (Self, Self) {
9881012
let (left, right) = self.iter.split_at(index);
9891013
(
@@ -1028,7 +1052,7 @@ where
10281052
D: Dimension,
10291053
{
10301054
fn len(&self) -> usize {
1031-
self.size_hint().0
1055+
self.iter.len()
10321056
}
10331057
}
10341058

@@ -1048,7 +1072,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
10481072
}
10491073
#[doc(hidden)]
10501074
fn as_ptr(&self) -> Self::Ptr {
1051-
self.iter.ptr
1075+
if self.len() > 0 {
1076+
// `self.iter.index` is guaranteed to be in-bounds if any of the
1077+
// iterator remains (i.e. if `self.len() > 0`).
1078+
unsafe { self.iter.offset(self.iter.index) }
1079+
} else {
1080+
// In this case, `self.iter.index` may be past the end, so we must
1081+
// not call `.offset()`. It's okay to return a dangling pointer
1082+
// because it will never be used in the length 0 case.
1083+
std::ptr::NonNull::dangling().as_ptr()
1084+
}
10521085
}
10531086

10541087
fn contiguous_stride(&self) -> isize {
@@ -1065,7 +1098,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
10651098
}
10661099
#[doc(hidden)]
10671100
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
1068-
self.iter.ptr.offset(self.iter.stride * i[0] as isize)
1101+
self.iter.offset(self.iter.index + i[0])
10691102
}
10701103

10711104
#[doc(hidden)]
@@ -1096,7 +1129,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
10961129
}
10971130
#[doc(hidden)]
10981131
fn as_ptr(&self) -> Self::Ptr {
1099-
self.iter.ptr
1132+
if self.len() > 0 {
1133+
// `self.iter.index` is guaranteed to be in-bounds if any of the
1134+
// iterator remains (i.e. if `self.len() > 0`).
1135+
unsafe { self.iter.offset(self.iter.index) }
1136+
} else {
1137+
// In this case, `self.iter.index` may be past the end, so we must
1138+
// not call `.offset()`. It's okay to return a dangling pointer
1139+
// because it will never be used in the length 0 case.
1140+
std::ptr::NonNull::dangling().as_ptr()
1141+
}
11001142
}
11011143

11021144
fn contiguous_stride(&self) -> isize {
@@ -1113,7 +1155,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
11131155
}
11141156
#[doc(hidden)]
11151157
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
1116-
self.iter.ptr.offset(self.iter.stride * i[0] as isize)
1158+
self.iter.offset(self.iter.index + i[0])
11171159
}
11181160

11191161
#[doc(hidden)]
@@ -1164,21 +1206,28 @@ clone_bounds!(
11641206
///
11651207
/// Returns an axis iterator with the correct stride to move between chunks,
11661208
/// the number of chunks, and the shape of the last chunk.
1209+
///
1210+
/// **Panics** if `size == 0`.
11671211
fn chunk_iter_parts<A, D: Dimension>(
11681212
v: ArrayView<'_, A, D>,
11691213
axis: Axis,
11701214
size: usize,
11711215
) -> (AxisIterCore<A, D>, usize, D) {
1216+
assert_ne!(size, 0, "Chunk size must be nonzero.");
11721217
let axis_len = v.len_of(axis);
1173-
let size = if size > axis_len { axis_len } else { size };
11741218
let n_whole_chunks = axis_len / size;
11751219
let chunk_remainder = axis_len % size;
11761220
let iter_len = if chunk_remainder == 0 {
11771221
n_whole_chunks
11781222
} else {
11791223
n_whole_chunks + 1
11801224
};
1181-
let stride = v.stride_of(axis) * size as isize;
1225+
let stride = if n_whole_chunks == 0 {
1226+
// This case avoids potential overflow when `size > axis_len`.
1227+
0
1228+
} else {
1229+
v.stride_of(axis) * size as isize
1230+
};
11821231

11831232
let axis = axis.index();
11841233
let mut inner_dim = v.dim.clone();
@@ -1193,7 +1242,7 @@ fn chunk_iter_parts<A, D: Dimension>(
11931242

11941243
let iter = AxisIterCore {
11951244
index: 0,
1196-
len: iter_len,
1245+
end: iter_len,
11971246
stride,
11981247
inner_dim,
11991248
inner_strides: v.strides,
@@ -1270,7 +1319,7 @@ macro_rules! chunk_iter_impl {
12701319
D: Dimension,
12711320
{
12721321
fn next_back(&mut self) -> Option<Self::Item> {
1273-
let is_uneven = self.iter.len > self.n_whole_chunks;
1322+
let is_uneven = self.iter.end > self.n_whole_chunks;
12741323
let res = self.iter.next_back();
12751324
self.get_subview(res, is_uneven)
12761325
}

0 commit comments

Comments
 (0)