Skip to content

Commit 9f868f7

Browse files
authored
Merge pull request #986 from rust-ndarray/intoiterator
Implement by-value iterator for owned arrays
2 parents 4e31d2f + 5766f4b commit 9f868f7

File tree

9 files changed

+395
-128
lines changed

9 files changed

+395
-128
lines changed

src/data_repr.rs

+11
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ impl<A> OwnedRepr<A> {
5353
self.ptr.as_ptr()
5454
}
5555

56+
pub(crate) fn as_ptr_mut(&self) -> *mut A {
57+
self.ptr.as_ptr()
58+
}
59+
5660
pub(crate) fn as_nonnull_mut(&mut self) -> NonNull<A> {
5761
self.ptr
5862
}
@@ -88,6 +92,13 @@ impl<A> OwnedRepr<A> {
8892
self.len = new_len;
8993
}
9094

95+
/// Set the valid length to 0 and return the previous length (total number of elements)
96+
pub(crate) fn release_all_elements(&mut self) -> usize {
97+
let ret = self.len;
98+
self.len = 0;
99+
ret
100+
}
101+
91102
/// Cast self into equivalent repr of other element type
92103
///
93104
/// ## Safety

src/impl_constructors.rs

+22
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use crate::indices;
2929
#[cfg(feature = "std")]
3030
use crate::iterators::to_vec;
3131
use crate::iterators::to_vec_mapped;
32+
use crate::iterators::TrustedIterator;
3233
use crate::StrideShape;
3334
#[cfg(feature = "std")]
3435
use crate::{geomspace, linspace, logspace};
@@ -495,6 +496,27 @@ where
495496
ArrayBase::from_data_ptr(DataOwned::new(v), ptr).with_strides_dim(strides, dim)
496497
}
497498

499+
/// Creates an array from an iterator, mapped by `map`, interpreting it according to the
500+
/// provided shape and strides.
501+
///
502+
/// # Safety
503+
///
504+
/// See from_shape_vec_unchecked
505+
pub(crate) unsafe fn from_shape_trusted_iter_unchecked<Sh, I, F>(shape: Sh, iter: I, map: F)
506+
-> Self
507+
where
508+
Sh: Into<StrideShape<D>>,
509+
I: TrustedIterator + ExactSizeIterator,
510+
F: FnMut(I::Item) -> A,
511+
{
512+
let shape = shape.into();
513+
let dim = shape.dim;
514+
let strides = shape.strides.strides_for_dim(&dim);
515+
let v = to_vec_mapped(iter, map);
516+
Self::from_vec_dim_stride_unchecked(dim, strides, v)
517+
}
518+
519+
498520
/// Create an array with uninitialized elements, shape `shape`.
499521
///
500522
/// The uninitialized elements of type `A` are represented by the type `MaybeUninit<A>`,

src/impl_methods.rs

+11-13
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ where
498498
///
499499
/// **Panics** if an index is out of bounds or step size is zero.<br>
500500
/// **Panics** if `axis` is out of bounds.
501+
#[must_use = "slice_axis returns an array view with the sliced result"]
501502
pub fn slice_axis(&self, axis: Axis, indices: Slice) -> ArrayView<'_, A, D>
502503
where
503504
S: Data,
@@ -511,6 +512,7 @@ where
511512
///
512513
/// **Panics** if an index is out of bounds or step size is zero.<br>
513514
/// **Panics** if `axis` is out of bounds.
515+
#[must_use = "slice_axis_mut returns an array view with the sliced result"]
514516
pub fn slice_axis_mut(&mut self, axis: Axis, indices: Slice) -> ArrayViewMut<'_, A, D>
515517
where
516518
S: DataMut,
@@ -2224,17 +2226,14 @@ where
22242226
A: 'a,
22252227
S: Data,
22262228
{
2227-
if let Some(slc) = self.as_slice_memory_order() {
2228-
let v = crate::iterators::to_vec_mapped(slc.iter(), f);
2229-
unsafe {
2230-
ArrayBase::from_shape_vec_unchecked(
2229+
unsafe {
2230+
if let Some(slc) = self.as_slice_memory_order() {
2231+
ArrayBase::from_shape_trusted_iter_unchecked(
22312232
self.dim.clone().strides(self.strides.clone()),
2232-
v,
2233-
)
2233+
slc.iter(), f)
2234+
} else {
2235+
ArrayBase::from_shape_trusted_iter_unchecked(self.dim.clone(), self.iter(), f)
22342236
}
2235-
} else {
2236-
let v = crate::iterators::to_vec_mapped(self.iter(), f);
2237-
unsafe { ArrayBase::from_shape_vec_unchecked(self.dim.clone(), v) }
22382237
}
22392238
}
22402239

@@ -2254,11 +2253,10 @@ where
22542253
if self.is_contiguous() {
22552254
let strides = self.strides.clone();
22562255
let slc = self.as_slice_memory_order_mut().unwrap();
2257-
let v = crate::iterators::to_vec_mapped(slc.iter_mut(), f);
2258-
unsafe { ArrayBase::from_shape_vec_unchecked(dim.strides(strides), v) }
2256+
unsafe { ArrayBase::from_shape_trusted_iter_unchecked(dim.strides(strides),
2257+
slc.iter_mut(), f) }
22592258
} else {
2260-
let v = crate::iterators::to_vec_mapped(self.iter_mut(), f);
2261-
unsafe { ArrayBase::from_shape_vec_unchecked(dim, v) }
2259+
unsafe { ArrayBase::from_shape_trusted_iter_unchecked(dim, self.iter_mut(), f) }
22622260
}
22632261
}
22642262

src/impl_owned_array.rs

+78-73
Original file line numberDiff line numberDiff line change
@@ -223,89 +223,18 @@ impl<A, D> Array<A, D>
223223
fn drop_unreachable_elements_slow(mut self) -> OwnedRepr<A> {
224224
// "deconstruct" self; the owned repr releases ownership of all elements and we
225225
// carry on with raw view methods
226-
let self_len = self.len();
227226
let data_len = self.data.len();
228227
let data_ptr = self.data.as_nonnull_mut().as_ptr();
229228

230-
let mut self_;
231-
232229
unsafe {
233230
// Safety: self.data releases ownership of the elements. Any panics below this point
234231
// will result in leaking elements instead of double drops.
235-
self_ = self.raw_view_mut();
232+
let self_ = self.raw_view_mut();
236233
self.data.set_len(0);
237-
}
238234

239-
240-
// uninvert axes where needed, so that stride > 0
241-
for i in 0..self_.ndim() {
242-
if self_.stride_of(Axis(i)) < 0 {
243-
self_.invert_axis(Axis(i));
244-
}
235+
drop_unreachable_raw(self_, data_ptr, data_len);
245236
}
246237

247-
// Sort axes to standard order, Axis(0) has biggest stride and Axis(n - 1) least stride
248-
// Note that self_ has holes, so self_ is not C-contiguous
249-
sort_axes_in_default_order(&mut self_);
250-
251-
unsafe {
252-
// with uninverted axes this is now the element with lowest address
253-
let array_memory_head_ptr = self_.ptr.as_ptr();
254-
let data_end_ptr = data_ptr.add(data_len);
255-
debug_assert!(data_ptr <= array_memory_head_ptr);
256-
debug_assert!(array_memory_head_ptr <= data_end_ptr);
257-
258-
// The idea is simply this: the iterator will yield the elements of self_ in
259-
// increasing address order.
260-
//
261-
// The pointers produced by the iterator are those that we *do not* touch.
262-
// The pointers *not mentioned* by the iterator are those we have to drop.
263-
//
264-
// We have to drop elements in the range from `data_ptr` until (not including)
265-
// `data_end_ptr`, except those that are produced by `iter`.
266-
267-
// As an optimization, the innermost axis is removed if it has stride 1, because
268-
// we then have a long stretch of contiguous elements we can skip as one.
269-
let inner_lane_len;
270-
if self_.ndim() > 1 && self_.strides.last_elem() == 1 {
271-
self_.dim.slice_mut().rotate_right(1);
272-
self_.strides.slice_mut().rotate_right(1);
273-
inner_lane_len = self_.dim[0];
274-
self_.dim[0] = 1;
275-
self_.strides[0] = 1;
276-
} else {
277-
inner_lane_len = 1;
278-
}
279-
280-
// iter is a raw pointer iterator traversing the array in memory order now with the
281-
// sorted axes.
282-
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
283-
let mut dropped_elements = 0;
284-
285-
let mut last_ptr = data_ptr;
286-
287-
while let Some(elem_ptr) = iter.next() {
288-
// The interval from last_ptr up until (not including) elem_ptr
289-
// should now be dropped. This interval may be empty, then we just skip this loop.
290-
while last_ptr != elem_ptr {
291-
debug_assert!(last_ptr < data_end_ptr);
292-
std::ptr::drop_in_place(last_ptr);
293-
last_ptr = last_ptr.add(1);
294-
dropped_elements += 1;
295-
}
296-
// Next interval will continue one past the current lane
297-
last_ptr = elem_ptr.add(inner_lane_len);
298-
}
299-
300-
while last_ptr < data_end_ptr {
301-
std::ptr::drop_in_place(last_ptr);
302-
last_ptr = last_ptr.add(1);
303-
dropped_elements += 1;
304-
}
305-
306-
assert_eq!(data_len, dropped_elements + self_len,
307-
"Internal error: inconsistency in move_into");
308-
}
309238
self.data
310239
}
311240

@@ -594,6 +523,82 @@ impl<A, D> Array<A, D>
594523
}
595524
}
596525

526+
/// This drops all "unreachable" elements in `self_` given the data pointer and data length.
527+
///
528+
/// # Safety
529+
///
530+
/// This is an internal function for use by move_into and IntoIter only, safety invariants may need
531+
/// to be upheld across the calls from those implementations.
532+
pub(crate) unsafe fn drop_unreachable_raw<A, D>(mut self_: RawArrayViewMut<A, D>, data_ptr: *mut A, data_len: usize)
533+
where
534+
D: Dimension,
535+
{
536+
let self_len = self_.len();
537+
538+
for i in 0..self_.ndim() {
539+
if self_.stride_of(Axis(i)) < 0 {
540+
self_.invert_axis(Axis(i));
541+
}
542+
}
543+
sort_axes_in_default_order(&mut self_);
544+
// with uninverted axes this is now the element with lowest address
545+
let array_memory_head_ptr = self_.ptr.as_ptr();
546+
let data_end_ptr = data_ptr.add(data_len);
547+
debug_assert!(data_ptr <= array_memory_head_ptr);
548+
debug_assert!(array_memory_head_ptr <= data_end_ptr);
549+
550+
// The idea is simply this: the iterator will yield the elements of self_ in
551+
// increasing address order.
552+
//
553+
// The pointers produced by the iterator are those that we *do not* touch.
554+
// The pointers *not mentioned* by the iterator are those we have to drop.
555+
//
556+
// We have to drop elements in the range from `data_ptr` until (not including)
557+
// `data_end_ptr`, except those that are produced by `iter`.
558+
559+
// As an optimization, the innermost axis is removed if it has stride 1, because
560+
// we then have a long stretch of contiguous elements we can skip as one.
561+
let inner_lane_len;
562+
if self_.ndim() > 1 && self_.strides.last_elem() == 1 {
563+
self_.dim.slice_mut().rotate_right(1);
564+
self_.strides.slice_mut().rotate_right(1);
565+
inner_lane_len = self_.dim[0];
566+
self_.dim[0] = 1;
567+
self_.strides[0] = 1;
568+
} else {
569+
inner_lane_len = 1;
570+
}
571+
572+
// iter is a raw pointer iterator traversing the array in memory order now with the
573+
// sorted axes.
574+
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
575+
let mut dropped_elements = 0;
576+
577+
let mut last_ptr = data_ptr;
578+
579+
while let Some(elem_ptr) = iter.next() {
580+
// The interval from last_ptr up until (not including) elem_ptr
581+
// should now be dropped. This interval may be empty, then we just skip this loop.
582+
while last_ptr != elem_ptr {
583+
debug_assert!(last_ptr < data_end_ptr);
584+
std::ptr::drop_in_place(last_ptr);
585+
last_ptr = last_ptr.add(1);
586+
dropped_elements += 1;
587+
}
588+
// Next interval will continue one past the current lane
589+
last_ptr = elem_ptr.add(inner_lane_len);
590+
}
591+
592+
while last_ptr < data_end_ptr {
593+
std::ptr::drop_in_place(last_ptr);
594+
last_ptr = last_ptr.add(1);
595+
dropped_elements += 1;
596+
}
597+
598+
assert_eq!(data_len, dropped_elements + self_len,
599+
"Internal error: inconsistency in move_into");
600+
}
601+
597602
/// Sort axes to standard order, i.e Axis(0) has biggest stride and Axis(n - 1) least stride
598603
///
599604
/// The axes should have stride >= 0 before calling this method.

0 commit comments

Comments
 (0)