Skip to content

Commit 5766f4b

Browse files
committed
intoiter: Implement by-value iterator for owned arrays
1 parent 297a5a4 commit 5766f4b

File tree

3 files changed

+230
-1
lines changed

3 files changed

+230
-1
lines changed

src/iterators/into_iter.rs

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// Copyright 2020-2021 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use std::mem;
10+
use std::ptr::NonNull;
11+
12+
use crate::imp_prelude::*;
13+
use crate::OwnedRepr;
14+
15+
use super::Baseiter;
16+
use crate::impl_owned_array::drop_unreachable_raw;
17+
18+
19+
/// By-value iterator for an array
20+
pub struct IntoIter<A, D>
21+
where
22+
D: Dimension,
23+
{
24+
array_data: OwnedRepr<A>,
25+
inner: Baseiter<A, D>,
26+
data_len: usize,
27+
/// first memory address of an array element
28+
array_head_ptr: NonNull<A>,
29+
// if true, the array owns elements that are not reachable by indexing
30+
// through all the indices of the dimension.
31+
has_unreachable_elements: bool,
32+
}
33+
34+
impl<A, D> IntoIter<A, D>
35+
where
36+
D: Dimension,
37+
{
38+
/// Create a new by-value iterator that consumes `array`
39+
pub(crate) fn new(mut array: Array<A, D>) -> Self {
40+
unsafe {
41+
let array_head_ptr = array.ptr;
42+
let ptr = array.as_mut_ptr();
43+
let mut array_data = array.data;
44+
let data_len = array_data.release_all_elements();
45+
debug_assert!(data_len >= array.dim.size());
46+
let has_unreachable_elements = array.dim.size() != data_len;
47+
let inner = Baseiter::new(ptr, array.dim, array.strides);
48+
49+
IntoIter {
50+
array_data,
51+
inner,
52+
data_len,
53+
array_head_ptr,
54+
has_unreachable_elements,
55+
}
56+
}
57+
}
58+
}
59+
60+
impl<A, D: Dimension> Iterator for IntoIter<A, D> {
61+
type Item = A;
62+
63+
#[inline]
64+
fn next(&mut self) -> Option<A> {
65+
self.inner.next().map(|p| unsafe { p.read() })
66+
}
67+
68+
fn size_hint(&self) -> (usize, Option<usize>) {
69+
self.inner.size_hint()
70+
}
71+
}
72+
73+
impl<A, D: Dimension> ExactSizeIterator for IntoIter<A, D> {
74+
fn len(&self) -> usize { self.inner.len() }
75+
}
76+
77+
impl<A, D> Drop for IntoIter<A, D>
78+
where
79+
D: Dimension
80+
{
81+
fn drop(&mut self) {
82+
if !self.has_unreachable_elements || mem::size_of::<A>() == 0 || !mem::needs_drop::<A>() {
83+
return;
84+
}
85+
86+
// iterate til the end
87+
while let Some(_) = self.next() { }
88+
89+
unsafe {
90+
let data_ptr = self.array_data.as_ptr_mut();
91+
let view = RawArrayViewMut::new(self.array_head_ptr, self.inner.dim.clone(),
92+
self.inner.strides.clone());
93+
debug_assert!(self.inner.dim.size() < self.data_len, "data_len {} and dim size {}",
94+
self.data_len, self.inner.dim.size());
95+
drop_unreachable_raw(view, data_ptr, self.data_len);
96+
}
97+
}
98+
}
99+
100+
impl<A, D> IntoIterator for Array<A, D>
101+
where
102+
D: Dimension
103+
{
104+
type Item = A;
105+
type IntoIter = IntoIter<A, D>;
106+
107+
fn into_iter(self) -> Self::IntoIter {
108+
IntoIter::new(self)
109+
}
110+
}
111+
112+
impl<A, D> IntoIterator for ArcArray<A, D>
113+
where
114+
D: Dimension,
115+
A: Clone,
116+
{
117+
type Item = A;
118+
type IntoIter = IntoIter<A, D>;
119+
120+
fn into_iter(self) -> Self::IntoIter {
121+
IntoIter::new(self.into_owned())
122+
}
123+
}
124+
125+
impl<A, D> IntoIterator for CowArray<'_, A, D>
126+
where
127+
D: Dimension,
128+
A: Clone,
129+
{
130+
type Item = A;
131+
type IntoIter = IntoIter<A, D>;
132+
133+
fn into_iter(self) -> Self::IntoIter {
134+
IntoIter::new(self.into_owned())
135+
}
136+
}

src/iterators/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#[macro_use]
1010
mod macros;
1111
mod chunks;
12+
mod into_iter;
1213
pub mod iter;
1314
mod lanes;
1415
mod windows;
@@ -26,6 +27,7 @@ use super::{Dimension, Ix, Ixs};
2627
pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
2728
pub use self::lanes::{Lanes, LanesMut};
2829
pub use self::windows::Windows;
30+
pub use self::into_iter::IntoIter;
2931

3032
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
3133

@@ -1465,6 +1467,7 @@ unsafe impl TrustedIterator for ::std::ops::Range<usize> {}
14651467
// FIXME: These indices iter are dubious -- size needs to be checked up front.
14661468
unsafe impl<D> TrustedIterator for IndicesIter<D> where D: Dimension {}
14671469
unsafe impl<D> TrustedIterator for IndicesIterF<D> where D: Dimension {}
1470+
unsafe impl<A, D> TrustedIterator for IntoIter<A, D> where D: Dimension {}
14681471

14691472
/// Like Iterator::collect, but only for trusted length iterators
14701473
pub fn to_vec<I>(iter: I) -> Vec<I::Item>

tests/iterators.rs

+91-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
use ndarray::prelude::*;
99
use ndarray::{arr3, aview1, indices, s, Axis, Slice, Zip};
1010

11-
use itertools::{assert_equal, enumerate};
11+
use itertools::assert_equal;
12+
use itertools::enumerate;
13+
use std::cell::Cell;
1214

1315
macro_rules! assert_panics {
1416
($body:expr) => {
@@ -892,3 +894,91 @@ fn test_rfold() {
892894
);
893895
}
894896
}
897+
898+
#[test]
899+
fn test_into_iter() {
900+
let a = Array1::from(vec![1, 2, 3, 4]);
901+
let v = a.into_iter().collect::<Vec<_>>();
902+
assert_eq!(v, [1, 2, 3, 4]);
903+
}
904+
905+
#[test]
906+
fn test_into_iter_2d() {
907+
let a = Array1::from(vec![1, 2, 3, 4]).into_shape((2, 2)).unwrap();
908+
let v = a.into_iter().collect::<Vec<_>>();
909+
assert_eq!(v, [1, 2, 3, 4]);
910+
911+
let a = Array1::from(vec![1, 2, 3, 4]).into_shape((2, 2)).unwrap().reversed_axes();
912+
let v = a.into_iter().collect::<Vec<_>>();
913+
assert_eq!(v, [1, 3, 2, 4]);
914+
}
915+
916+
#[test]
917+
fn test_into_iter_sliced() {
918+
let (m, n) = (4, 5);
919+
let drops = Cell::new(0);
920+
921+
for i in 0..m - 1 {
922+
for j in 0..n - 1 {
923+
for i2 in i + 1 .. m {
924+
for j2 in j + 1 .. n {
925+
for invert in 0..3 {
926+
drops.set(0);
927+
let i = i as isize;
928+
let j = j as isize;
929+
let i2 = i2 as isize;
930+
let j2 = j2 as isize;
931+
let mut a = Array1::from_iter(0..(m * n) as i32)
932+
.mapv(|v| DropCount::new(v, &drops))
933+
.into_shape((m, n)).unwrap();
934+
a.slice_collapse(s![i..i2, j..j2]);
935+
if invert < a.ndim() {
936+
a.invert_axis(Axis(invert));
937+
}
938+
939+
println!("{:?}, {:?}", i..i2, j..j2);
940+
println!("{:?}", a);
941+
let answer = a.iter().cloned().collect::<Vec<_>>();
942+
let v = a.into_iter().collect::<Vec<_>>();
943+
assert_eq!(v, answer);
944+
945+
assert_eq!(drops.get(), m * n - v.len());
946+
drop(v);
947+
assert_eq!(drops.get(), m * n);
948+
}
949+
}
950+
}
951+
}
952+
}
953+
}
954+
955+
/// Helper struct that counts its drops Asserts that it's not dropped twice. Also global number of
956+
/// drops is counted in the cell.
957+
///
958+
/// Compares equal by its "represented value".
959+
#[derive(Clone, Debug)]
960+
struct DropCount<'a> {
961+
value: i32,
962+
my_drops: usize,
963+
drops: &'a Cell<usize>
964+
}
965+
966+
impl PartialEq for DropCount<'_> {
967+
fn eq(&self, other: &Self) -> bool {
968+
self.value == other.value
969+
}
970+
}
971+
972+
impl<'a> DropCount<'a> {
973+
fn new(value: i32, drops: &'a Cell<usize>) -> Self {
974+
DropCount { value, my_drops: 0, drops }
975+
}
976+
}
977+
978+
impl Drop for DropCount<'_> {
979+
fn drop(&mut self) {
980+
assert_eq!(self.my_drops, 0);
981+
self.my_drops += 1;
982+
self.drops.set(self.drops.get() + 1);
983+
}
984+
}

0 commit comments

Comments
 (0)