@@ -738,11 +738,19 @@ where
738
738
739
739
#[ derive( Debug ) ]
740
740
pub struct AxisIterCore < A , D > {
741
+ /// Index along the axis of the value of `.next()`, relative to the start
742
+ /// of the axis.
741
743
index : Ix ,
742
- len : Ix ,
744
+ /// (Exclusive) upper bound on `index`. Initially, this is equal to the
745
+ /// length of the axis.
746
+ end : Ix ,
747
+ /// Stride along the axis (offset between consecutive pointers).
743
748
stride : Ixs ,
749
+ /// Shape of the iterator's items.
744
750
inner_dim : D ,
751
+ /// Strides of the iterator's items.
745
752
inner_strides : D ,
753
+ /// Pointer corresponding to `index == 0`.
746
754
ptr : * mut A ,
747
755
}
748
756
@@ -751,7 +759,7 @@ clone_bounds!(
751
759
AxisIterCore [ A , D ] {
752
760
@copy {
753
761
index,
754
- len ,
762
+ end ,
755
763
stride,
756
764
ptr,
757
765
}
@@ -767,54 +775,53 @@ impl<A, D: Dimension> AxisIterCore<A, D> {
767
775
Di : RemoveAxis < Smaller = D > ,
768
776
S : Data < Elem = A > ,
769
777
{
770
- let shape = v. shape ( ) [ axis. index ( ) ] ;
771
- let stride = v. strides ( ) [ axis. index ( ) ] ;
772
778
AxisIterCore {
773
779
index : 0 ,
774
- len : shape ,
775
- stride,
780
+ end : v . len_of ( axis ) ,
781
+ stride : v . stride_of ( axis ) ,
776
782
inner_dim : v. dim . remove_axis ( axis) ,
777
783
inner_strides : v. strides . remove_axis ( axis) ,
778
784
ptr : v. ptr ,
779
785
}
780
786
}
781
787
788
+ #[ inline]
782
789
unsafe fn offset ( & self , index : usize ) -> * mut A {
783
790
debug_assert ! (
784
- index <= self . len ,
785
- "index={}, len ={}, stride={}" ,
791
+ index < self . end ,
792
+ "index={}, end ={}, stride={}" ,
786
793
index,
787
- self . len ,
794
+ self . end ,
788
795
self . stride
789
796
) ;
790
797
self . ptr . offset ( index as isize * self . stride )
791
798
}
792
799
793
- /// Split the iterator at index, yielding two disjoint iterators.
800
+ /// Splits the iterator at ` index` , yielding two disjoint iterators.
794
801
///
795
- /// **Panics** if `index` is strictly greater than the iterator's length
802
+ /// `index` is relative to the current state of the iterator (which is not
803
+ /// necessarily the start of the axis).
804
+ ///
805
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
806
+ /// length.
796
807
fn split_at ( self , index : usize ) -> ( Self , Self ) {
797
- assert ! ( index <= self . len) ;
798
- let right_ptr = if index != self . len {
799
- unsafe { self . offset ( index) }
800
- } else {
801
- self . ptr
802
- } ;
808
+ assert ! ( index <= self . len( ) ) ;
809
+ let mid = self . index + index;
803
810
let left = AxisIterCore {
804
- index : 0 ,
805
- len : index ,
811
+ index : self . index ,
812
+ end : mid ,
806
813
stride : self . stride ,
807
814
inner_dim : self . inner_dim . clone ( ) ,
808
815
inner_strides : self . inner_strides . clone ( ) ,
809
816
ptr : self . ptr ,
810
817
} ;
811
818
let right = AxisIterCore {
812
- index : 0 ,
813
- len : self . len - index ,
819
+ index : mid ,
820
+ end : self . end ,
814
821
stride : self . stride ,
815
822
inner_dim : self . inner_dim ,
816
823
inner_strides : self . inner_strides ,
817
- ptr : right_ptr ,
824
+ ptr : self . ptr ,
818
825
} ;
819
826
( left, right)
820
827
}
@@ -827,7 +834,7 @@ where
827
834
type Item = * mut A ;
828
835
829
836
fn next ( & mut self ) -> Option < Self :: Item > {
830
- if self . index >= self . len {
837
+ if self . index >= self . end {
831
838
None
832
839
} else {
833
840
let ptr = unsafe { self . offset ( self . index ) } ;
@@ -837,7 +844,7 @@ where
837
844
}
838
845
839
846
fn size_hint ( & self ) -> ( usize , Option < usize > ) {
840
- let len = self . len - self . index ;
847
+ let len = self . len ( ) ;
841
848
( len, Some ( len) )
842
849
}
843
850
}
@@ -847,16 +854,25 @@ where
847
854
D : Dimension ,
848
855
{
849
856
fn next_back ( & mut self ) -> Option < Self :: Item > {
850
- if self . index >= self . len {
857
+ if self . index >= self . end {
851
858
None
852
859
} else {
853
- self . len -= 1 ;
854
- let ptr = unsafe { self . offset ( self . len ) } ;
860
+ let ptr = unsafe { self . offset ( self . end - 1 ) } ;
861
+ self . end -= 1 ;
855
862
Some ( ptr)
856
863
}
857
864
}
858
865
}
859
866
867
+ impl < A , D > ExactSizeIterator for AxisIterCore < A , D >
868
+ where
869
+ D : Dimension ,
870
+ {
871
+ fn len ( & self ) -> usize {
872
+ self . end - self . index
873
+ }
874
+ }
875
+
860
876
/// An iterator that traverses over an axis and
861
877
/// and yields each subview.
862
878
///
@@ -899,9 +915,13 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> {
899
915
}
900
916
}
901
917
902
- /// Split the iterator at index, yielding two disjoint iterators.
918
+ /// Splits the iterator at ` index` , yielding two disjoint iterators.
903
919
///
904
- /// **Panics** if `index` is strictly greater than the iterator's length
920
+ /// `index` is relative to the current state of the iterator (which is not
921
+ /// necessarily the start of the axis).
922
+ ///
923
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
924
+ /// length.
905
925
pub fn split_at ( self , index : usize ) -> ( Self , Self ) {
906
926
let ( left, right) = self . iter . split_at ( index) ;
907
927
(
@@ -946,7 +966,7 @@ where
946
966
D : Dimension ,
947
967
{
948
968
fn len ( & self ) -> usize {
949
- self . size_hint ( ) . 0
969
+ self . iter . len ( )
950
970
}
951
971
}
952
972
@@ -981,9 +1001,13 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> {
981
1001
}
982
1002
}
983
1003
984
- /// Split the iterator at index, yielding two disjoint iterators.
1004
+ /// Splits the iterator at ` index` , yielding two disjoint iterators.
985
1005
///
986
- /// **Panics** if `index` is strictly greater than the iterator's length
1006
+ /// `index` is relative to the current state of the iterator (which is not
1007
+ /// necessarily the start of the axis).
1008
+ ///
1009
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
1010
+ /// length.
987
1011
pub fn split_at ( self , index : usize ) -> ( Self , Self ) {
988
1012
let ( left, right) = self . iter . split_at ( index) ;
989
1013
(
@@ -1028,7 +1052,7 @@ where
1028
1052
D : Dimension ,
1029
1053
{
1030
1054
fn len ( & self ) -> usize {
1031
- self . size_hint ( ) . 0
1055
+ self . iter . len ( )
1032
1056
}
1033
1057
}
1034
1058
@@ -1048,7 +1072,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
1048
1072
}
1049
1073
#[ doc( hidden) ]
1050
1074
fn as_ptr ( & self ) -> Self :: Ptr {
1051
- self . iter . ptr
1075
+ if self . len ( ) > 0 {
1076
+ // `self.iter.index` is guaranteed to be in-bounds if any of the
1077
+ // iterator remains (i.e. if `self.len() > 0`).
1078
+ unsafe { self . iter . offset ( self . iter . index ) }
1079
+ } else {
1080
+ // In this case, `self.iter.index` may be past the end, so we must
1081
+ // not call `.offset()`. It's okay to return a dangling pointer
1082
+ // because it will never be used in the length 0 case.
1083
+ std:: ptr:: NonNull :: dangling ( ) . as_ptr ( )
1084
+ }
1052
1085
}
1053
1086
1054
1087
fn contiguous_stride ( & self ) -> isize {
@@ -1065,7 +1098,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
1065
1098
}
1066
1099
#[ doc( hidden) ]
1067
1100
unsafe fn uget_ptr ( & self , i : & Self :: Dim ) -> Self :: Ptr {
1068
- self . iter . ptr . offset ( self . iter . stride * i[ 0 ] as isize )
1101
+ self . iter . offset ( self . iter . index + i[ 0 ] )
1069
1102
}
1070
1103
1071
1104
#[ doc( hidden) ]
@@ -1096,7 +1129,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
1096
1129
}
1097
1130
#[ doc( hidden) ]
1098
1131
fn as_ptr ( & self ) -> Self :: Ptr {
1099
- self . iter . ptr
1132
+ if self . len ( ) > 0 {
1133
+ // `self.iter.index` is guaranteed to be in-bounds if any of the
1134
+ // iterator remains (i.e. if `self.len() > 0`).
1135
+ unsafe { self . iter . offset ( self . iter . index ) }
1136
+ } else {
1137
+ // In this case, `self.iter.index` may be past the end, so we must
1138
+ // not call `.offset()`. It's okay to return a dangling pointer
1139
+ // because it will never be used in the length 0 case.
1140
+ std:: ptr:: NonNull :: dangling ( ) . as_ptr ( )
1141
+ }
1100
1142
}
1101
1143
1102
1144
fn contiguous_stride ( & self ) -> isize {
@@ -1113,7 +1155,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
1113
1155
}
1114
1156
#[ doc( hidden) ]
1115
1157
unsafe fn uget_ptr ( & self , i : & Self :: Dim ) -> Self :: Ptr {
1116
- self . iter . ptr . offset ( self . iter . stride * i[ 0 ] as isize )
1158
+ self . iter . offset ( self . iter . index + i[ 0 ] )
1117
1159
}
1118
1160
1119
1161
#[ doc( hidden) ]
@@ -1164,21 +1206,28 @@ clone_bounds!(
1164
1206
///
1165
1207
/// Returns an axis iterator with the correct stride to move between chunks,
1166
1208
/// the number of chunks, and the shape of the last chunk.
1209
+ ///
1210
+ /// **Panics** if `size == 0`.
1167
1211
fn chunk_iter_parts < A , D : Dimension > (
1168
1212
v : ArrayView < ' _ , A , D > ,
1169
1213
axis : Axis ,
1170
1214
size : usize ,
1171
1215
) -> ( AxisIterCore < A , D > , usize , D ) {
1216
+ assert_ne ! ( size, 0 , "Chunk size must be nonzero." ) ;
1172
1217
let axis_len = v. len_of ( axis) ;
1173
- let size = if size > axis_len { axis_len } else { size } ;
1174
1218
let n_whole_chunks = axis_len / size;
1175
1219
let chunk_remainder = axis_len % size;
1176
1220
let iter_len = if chunk_remainder == 0 {
1177
1221
n_whole_chunks
1178
1222
} else {
1179
1223
n_whole_chunks + 1
1180
1224
} ;
1181
- let stride = v. stride_of ( axis) * size as isize ;
1225
+ let stride = if n_whole_chunks == 0 {
1226
+ // This case avoids potential overflow when `size > axis_len`.
1227
+ 0
1228
+ } else {
1229
+ v. stride_of ( axis) * size as isize
1230
+ } ;
1182
1231
1183
1232
let axis = axis. index ( ) ;
1184
1233
let mut inner_dim = v. dim . clone ( ) ;
@@ -1193,7 +1242,7 @@ fn chunk_iter_parts<A, D: Dimension>(
1193
1242
1194
1243
let iter = AxisIterCore {
1195
1244
index : 0 ,
1196
- len : iter_len,
1245
+ end : iter_len,
1197
1246
stride,
1198
1247
inner_dim,
1199
1248
inner_strides : v. strides ,
@@ -1270,7 +1319,7 @@ macro_rules! chunk_iter_impl {
1270
1319
D : Dimension ,
1271
1320
{
1272
1321
fn next_back( & mut self ) -> Option <Self :: Item > {
1273
- let is_uneven = self . iter. len > self . n_whole_chunks;
1322
+ let is_uneven = self . iter. end > self . n_whole_chunks;
1274
1323
let res = self . iter. next_back( ) ;
1275
1324
self . get_subview( res, is_uneven)
1276
1325
}
0 commit comments