-
Notifications
You must be signed in to change notification settings - Fork 321
Add .split_at() methods for AxisChunksIter/Mut #691
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -825,6 +825,19 @@ impl<A, D: Dimension> AxisIterCore<A, D> { | |
}; | ||
(left, right) | ||
} | ||
|
||
/// Does the same thing as `.next()` but also returns the index of the item | ||
/// relative to the start of the axis. | ||
fn next_with_index(&mut self) -> Option<(usize, *mut A)> { | ||
let index = self.index; | ||
self.next().map(|ptr| (index, ptr)) | ||
} | ||
|
||
/// Does the same thing as `.next_back()` but also returns the index of the | ||
/// item relative to the start of the axis. | ||
fn next_back_with_index(&mut self) -> Option<(usize, *mut A)> { | ||
self.next_back().map(|ptr| (self.end, ptr)) | ||
} | ||
} | ||
|
||
impl<A, D> Iterator for AxisIterCore<A, D> | ||
|
@@ -1182,9 +1195,13 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> { | |
/// See [`.axis_chunks_iter()`](../struct.ArrayBase.html#method.axis_chunks_iter) for more information. | ||
pub struct AxisChunksIter<'a, A, D> { | ||
iter: AxisIterCore<A, D>, | ||
n_whole_chunks: usize, | ||
/// Dimension of the last (and possibly uneven) chunk | ||
last_dim: D, | ||
/// Index of the partial chunk (the chunk smaller than the specified chunk | ||
/// size due to the axis length not being evenly divisible). If the axis | ||
/// length is evenly divisible by the chunk size, this index is larger than | ||
/// the maximum valid index. | ||
partial_chunk_index: usize, | ||
/// Dimension of the partial chunk. | ||
partial_chunk_dim: D, | ||
life: PhantomData<&'a A>, | ||
} | ||
|
||
|
@@ -1193,10 +1210,10 @@ clone_bounds!( | |
AxisChunksIter['a, A, D] { | ||
@copy { | ||
life, | ||
n_whole_chunks, | ||
partial_chunk_index, | ||
} | ||
iter, | ||
last_dim, | ||
partial_chunk_dim, | ||
} | ||
); | ||
|
||
|
@@ -1233,12 +1250,9 @@ fn chunk_iter_parts<A, D: Dimension>( | |
let mut inner_dim = v.dim.clone(); | ||
inner_dim[axis] = size; | ||
|
||
let mut last_dim = v.dim; | ||
last_dim[axis] = if chunk_remainder == 0 { | ||
size | ||
} else { | ||
chunk_remainder | ||
}; | ||
let mut partial_chunk_dim = v.dim; | ||
partial_chunk_dim[axis] = chunk_remainder; | ||
let partial_chunk_index = n_whole_chunks; | ||
|
||
let iter = AxisIterCore { | ||
index: 0, | ||
|
@@ -1249,16 +1263,16 @@ fn chunk_iter_parts<A, D: Dimension>( | |
ptr: v.ptr, | ||
}; | ||
|
||
(iter, n_whole_chunks, last_dim) | ||
(iter, partial_chunk_index, partial_chunk_dim) | ||
} | ||
|
||
impl<'a, A, D: Dimension> AxisChunksIter<'a, A, D> { | ||
pub(crate) fn new(v: ArrayView<'a, A, D>, axis: Axis, size: usize) -> Self { | ||
let (iter, n_whole_chunks, last_dim) = chunk_iter_parts(v, axis, size); | ||
let (iter, partial_chunk_index, partial_chunk_dim) = chunk_iter_parts(v, axis, size); | ||
AxisChunksIter { | ||
iter, | ||
n_whole_chunks, | ||
last_dim, | ||
partial_chunk_index, | ||
partial_chunk_dim, | ||
life: PhantomData, | ||
} | ||
} | ||
|
@@ -1270,30 +1284,49 @@ macro_rules! chunk_iter_impl { | |
where | ||
D: Dimension, | ||
{ | ||
fn get_subview( | ||
&self, | ||
iter_item: Option<*mut A>, | ||
is_uneven: bool, | ||
) -> Option<$array<'a, A, D>> { | ||
iter_item.map(|ptr| { | ||
if !is_uneven { | ||
unsafe { | ||
$array::new_( | ||
ptr, | ||
self.iter.inner_dim.clone(), | ||
self.iter.inner_strides.clone(), | ||
) | ||
} | ||
} else { | ||
unsafe { | ||
$array::new_( | ||
ptr, | ||
self.last_dim.clone(), | ||
self.iter.inner_strides.clone(), | ||
) | ||
} | ||
fn get_subview(&self, index: usize, ptr: *mut A) -> $array<'a, A, D> { | ||
if index != self.partial_chunk_index { | ||
unsafe { | ||
$array::new_( | ||
ptr, | ||
self.iter.inner_dim.clone(), | ||
self.iter.inner_strides.clone(), | ||
) | ||
} | ||
} else { | ||
unsafe { | ||
$array::new_( | ||
ptr, | ||
self.partial_chunk_dim.clone(), | ||
self.iter.inner_strides.clone(), | ||
) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
/// Splits the iterator at index, yielding two disjoint iterators. | ||
/// | ||
/// `index` is relative to the current state of the iterator (which is not | ||
/// necessarily the start of the axis). | ||
/// | ||
/// **Panics** if `index` is strictly greater than the iterator's remaining | ||
/// length. | ||
pub fn split_at(self, index: usize) -> (Self, Self) { | ||
let (left, right) = self.iter.split_at(index); | ||
( | ||
Self { | ||
iter: left, | ||
partial_chunk_index: self.partial_chunk_index, | ||
partial_chunk_dim: self.partial_chunk_dim.clone(), | ||
life: self.life, | ||
}, | ||
Self { | ||
iter: right, | ||
partial_chunk_index: self.partial_chunk_index, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I haven't read the whole code unfortunately (what's not visible in the diff) - why doesn't this `partial_chunk_index` require adjusting, given that the right part of the iter now starts at a different index?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Here's an example:
use ndarray::prelude::*;
fn main() {
let a: Array1<i32> = (0..13).collect();
let mut iter = a.axis_chunks_iter(Axis(0), 3);
iter.next(); // skip the first element so that we consider a partially-consumed iterator
println!("before_split = {:#?}", iter);
let (left, right) = iter.split_at(2);
println!("left = {:#?}", left);
println!("right = {:#?}", right);
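}
which gives the output [program output not captured in this page extract].
We can visualize the situation like this: [diagram not captured in this page extract]
The [rest of this sentence was lost in extraction. Per the doc comments in the diff, the indices yielded by the inner iterator are relative to the start of the axis, and `partial_chunk_index` is compared against those indices, so it stays valid for both halves after a split.]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Thanks. If the only use of `index` is counting up to the `partial_chunk_index`, it makes total sense.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.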
|
||
partial_chunk_dim: self.partial_chunk_dim, | ||
life: self.life, | ||
}, | ||
) | ||
} | ||
} | ||
|
||
|
@@ -1304,9 +1337,9 @@ macro_rules! chunk_iter_impl { | |
type Item = $array<'a, A, D>; | ||
|
||
fn next(&mut self) -> Option<Self::Item> { | ||
let res = self.iter.next(); | ||
let is_uneven = self.iter.index > self.n_whole_chunks; | ||
self.get_subview(res, is_uneven) | ||
self.iter | ||
.next_with_index() | ||
.map(|(index, ptr)| self.get_subview(index, ptr)) | ||
} | ||
|
||
fn size_hint(&self) -> (usize, Option<usize>) { | ||
|
@@ -1319,9 +1352,9 @@ macro_rules! chunk_iter_impl { | |
D: Dimension, | ||
{ | ||
fn next_back(&mut self) -> Option<Self::Item> { | ||
let is_uneven = self.iter.end > self.n_whole_chunks; | ||
let res = self.iter.next_back(); | ||
self.get_subview(res, is_uneven) | ||
self.iter | ||
.next_back_with_index() | ||
.map(|(index, ptr)| self.get_subview(index, ptr)) | ||
} | ||
} | ||
|
||
|
@@ -1342,18 +1375,19 @@ macro_rules! chunk_iter_impl { | |
/// for more information. | ||
pub struct AxisChunksIterMut<'a, A, D> { | ||
iter: AxisIterCore<A, D>, | ||
n_whole_chunks: usize, | ||
last_dim: D, | ||
partial_chunk_index: usize, | ||
partial_chunk_dim: D, | ||
life: PhantomData<&'a mut A>, | ||
} | ||
|
||
impl<'a, A, D: Dimension> AxisChunksIterMut<'a, A, D> { | ||
pub(crate) fn new(v: ArrayViewMut<'a, A, D>, axis: Axis, size: usize) -> Self { | ||
let (iter, len, last_dim) = chunk_iter_parts(v.into_view(), axis, size); | ||
let (iter, partial_chunk_index, partial_chunk_dim) = | ||
chunk_iter_parts(v.into_view(), axis, size); | ||
AxisChunksIterMut { | ||
iter, | ||
n_whole_chunks: len, | ||
last_dim, | ||
partial_chunk_index, | ||
partial_chunk_dim, | ||
life: PhantomData, | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be beneficial to rephrase this as an `Option`, to make it clearer that we might (or might not) have a partial chunk? Something along the lines of: [code snippet not captured in this page extract]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it makes sense to use both the `Option` variant and the value of `partial_chunk_index` to represent whether or not there's a partial chunk. (The biggest reason is that I prefer data structures where there's a single source of truth, rather than having to keep multiple things in sync. There might also be a small performance cost to accessing `partial_chunk_index` through the `Option` (since accessing it requires checking whether the `Option` is the `Some` variant), but we'd need to test to determine if that would really be noticeable.) IMO, putting the fields in an `Option` would be additional complication over the current approach without much benefit.

It would be reasonable to eliminate `partial_chunk_index` and just use the `Option` variant to represent the presence of a partial chunk, like this: [code snippet not captured in this page extract]
or to always store the shape of the last chunk (regardless of whether or not it's a partial chunk): [code snippet not captured in this page extract]
These approaches have two disadvantages, since they rely on checking whether the iterator is at its end to handle the partial chunk instead of checking whether the current index is equal to `partial_chunk_index`:

- `.split_at()` needs to check whether or not the partial chunk is in the left piece and determine `partial_chunk` or `last_chunk_dim` of the left piece accordingly. (The partial chunk is in the left piece when `index == self.iter.len()`.)
- `.next_back()` needs to set `partial_chunk` to `None` or `last_chunk_dim` to `self.iter.inner_dim` each time it's called.

So, I'd rather keep the current approach and add more comments if necessary to make it clear.