Skip to content

Commit 06b5026

Browse files
committed
FEAT: New method .move_into() for moving all array elements
.move_into() lets all elements move out of an Array, into an uninitialized array. We use a DropCounter to check duplication/drops of elements rigorously. The DropCounter code is taken from rayon collect tests, where I wrote it.
1 parent 1654125 commit 06b5026

File tree

4 files changed

+402
-36
lines changed

4 files changed

+402
-36
lines changed

src/impl_owned_array.rs

+164-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11

22
use alloc::vec::Vec;
3+
use std::mem::MaybeUninit;
34

45
use crate::imp_prelude::*;
6+
57
use crate::dimension;
68
use crate::error::{ErrorKind, ShapeError};
9+
use crate::iterators::Baseiter;
710
use crate::OwnedRepr;
811
use crate::Zip;
912

@@ -137,6 +140,125 @@ impl<A> Array<A, Ix2> {
137140
impl<A, D> Array<A, D>
138141
where D: Dimension
139142
{
143+
/// Move all elements from self into `new_array`, which must be of the same shape but
144+
/// can have a different memory layout. The destination is overwritten completely.
145+
///
146+
/// ***Panics*** if the shapes don't agree.
147+
pub fn move_into(mut self, new_array: ArrayViewMut<MaybeUninit<A>, D>) {
148+
unsafe {
149+
// Safety: copy_to_nonoverlapping cannot panic
150+
// Move all reachable elements
151+
Zip::from(self.raw_view_mut())
152+
.and(new_array)
153+
.for_each(|src, dst| {
154+
src.copy_to_nonoverlapping(dst.as_mut_ptr(), 1);
155+
});
156+
// Drop all unreachable elements
157+
self.drop_unreachable_elements();
158+
}
159+
}
160+
161+
/// This drops all "unreachable" elements in the data storage of self.
162+
///
163+
/// That means those elements that are not visible in the slicing of the array.
164+
/// *Reachable elements are assumed to already have been moved from.*
165+
///
166+
/// # Safety
167+
///
168+
/// This is a panic critical section since `self` is already moved-from.
169+
fn drop_unreachable_elements(mut self) -> OwnedRepr<A> {
170+
let self_len = self.len();
171+
172+
// "deconstruct" self; the owned repr releases ownership of all elements and we
173+
// and carry on with raw view methods
174+
let data_len = self.data.len();
175+
176+
let has_unreachable_elements = self_len != data_len;
177+
if !has_unreachable_elements || std::mem::size_of::<A>() == 0 {
178+
unsafe {
179+
self.data.set_len(0);
180+
}
181+
self.data
182+
} else {
183+
self.drop_unreachable_elements_slow()
184+
}
185+
}
186+
187+
#[inline(never)]
188+
#[cold]
189+
fn drop_unreachable_elements_slow(mut self) -> OwnedRepr<A> {
190+
// "deconstruct" self; the owned repr releases ownership of all elements and we
191+
// and carry on with raw view methods
192+
let self_len = self.len();
193+
let data_len = self.data.len();
194+
let data_ptr = self.data.as_nonnull_mut().as_ptr();
195+
196+
let mut self_;
197+
198+
unsafe {
199+
// Safety: self.data releases ownership of the elements
200+
self_ = self.raw_view_mut();
201+
self.data.set_len(0);
202+
}
203+
204+
205+
// uninvert axes where needed, so that stride > 0
206+
for i in 0..self_.ndim() {
207+
if self_.stride_of(Axis(i)) < 0 {
208+
self_.invert_axis(Axis(i));
209+
}
210+
}
211+
212+
// Sort axes to standard order, Axis(0) has biggest stride and Axis(n - 1) least stride
213+
// Note that self_ has holes, so self_ is not C-contiguous
214+
sort_axes_in_default_order(&mut self_);
215+
216+
unsafe {
217+
let array_memory_head_ptr = self_.ptr.as_ptr();
218+
let data_end_ptr = data_ptr.add(data_len);
219+
debug_assert!(data_ptr <= array_memory_head_ptr);
220+
debug_assert!(array_memory_head_ptr <= data_end_ptr);
221+
222+
// iter is a raw pointer iterator traversing self_ in its standard order
223+
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
224+
let mut dropped_elements = 0;
225+
226+
// The idea is simply this: the iterator will yield the elements of self_ in
227+
// increasing address order.
228+
//
229+
// The pointers produced by the iterator are those that we *do not* touch.
230+
// The pointers *not mentioned* by the iterator are those we have to drop.
231+
//
232+
// We have to drop elements in the range from `data_ptr` until (not including)
233+
// `data_end_ptr`, except those that are produced by `iter`.
234+
let mut last_ptr = data_ptr;
235+
236+
while let Some(elem_ptr) = iter.next() {
237+
// The interval from last_ptr up until (not including) elem_ptr
238+
// should now be dropped. This interval may be empty, then we just skip this loop.
239+
while last_ptr != elem_ptr {
240+
debug_assert!(last_ptr < data_end_ptr);
241+
std::ptr::drop_in_place(last_ptr as *mut A);
242+
last_ptr = last_ptr.add(1);
243+
dropped_elements += 1;
244+
}
245+
// Next interval will continue one past the current element
246+
last_ptr = elem_ptr.add(1);
247+
}
248+
249+
while last_ptr < data_end_ptr {
250+
std::ptr::drop_in_place(last_ptr as *mut A);
251+
last_ptr = last_ptr.add(1);
252+
dropped_elements += 1;
253+
}
254+
255+
assert_eq!(data_len, dropped_elements + self_len,
256+
"Internal error: inconsistency in move_into");
257+
}
258+
self.data
259+
}
260+
261+
140262
/// Append an array to the array
141263
///
142264
/// The axis-to-append-to `axis` must be the array's "growing axis" for this operation
@@ -313,7 +435,7 @@ impl<A, D> Array<A, D>
313435
array.invert_axis(Axis(i));
314436
}
315437
}
316-
sort_axes_to_standard_order(&mut tail_view, &mut array);
438+
sort_axes_to_standard_order_tandem(&mut tail_view, &mut array);
317439
}
318440
Zip::from(tail_view).and(array)
319441
.debug_assert_c_order()
@@ -336,7 +458,21 @@ impl<A, D> Array<A, D>
336458
}
337459
}
338460

339-
fn sort_axes_to_standard_order<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
461+
/// Sort axes to standard order, i.e Axis(0) has biggest stride and Axis(n - 1) least stride
462+
///
463+
/// The axes should have stride >= 0 before calling this method.
464+
fn sort_axes_in_default_order<S, D>(a: &mut ArrayBase<S, D>)
465+
where
466+
S: RawData,
467+
D: Dimension,
468+
{
469+
if a.ndim() <= 1 {
470+
return;
471+
}
472+
sort_axes1_impl(&mut a.dim, &mut a.strides);
473+
}
474+
475+
fn sort_axes_to_standard_order_tandem<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
340476
where
341477
S: RawData,
342478
S2: RawData,
@@ -350,6 +486,32 @@ where
350486
a.shape(), a.strides());
351487
}
352488

489+
fn sort_axes1_impl<D>(adim: &mut D, astrides: &mut D)
490+
where
491+
D: Dimension,
492+
{
493+
debug_assert!(adim.ndim() > 1);
494+
debug_assert_eq!(adim.ndim(), astrides.ndim());
495+
// bubble sort axes
496+
let mut changed = true;
497+
while changed {
498+
changed = false;
499+
for i in 0..adim.ndim() - 1 {
500+
let axis_i = i;
501+
let next_axis = i + 1;
502+
503+
// make sure higher stride axes sort before.
504+
debug_assert!(astrides.slice()[axis_i] as isize >= 0);
505+
if (astrides.slice()[axis_i] as isize) < astrides.slice()[next_axis] as isize {
506+
changed = true;
507+
adim.slice_mut().swap(axis_i, next_axis);
508+
astrides.slice_mut().swap(axis_i, next_axis);
509+
}
510+
}
511+
}
512+
}
513+
514+
353515
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)
354516
where
355517
D: Dimension,

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
clippy::deref_addrof,
1313
clippy::unreadable_literal,
1414
clippy::manual_map, // is not an error
15+
clippy::while_let_on_iterator, // is not an error
1516
)]
1617
#![cfg_attr(not(feature = "std"), no_std)]
1718

0 commit comments

Comments
 (0)