Skip to content

Commit f592bee

Browse files
committed
WIP: .move_into() and better .try_append_array()
FEAT: automatic layout fix in append_to_array (doc not yet updated). FEAT: new method .move_into() for moving all array elements. TEST: add tests for .move_into(); we use a DropCounter to check duplication/drops of elements rigorously. The DropCounter code is taken from rayon's collect tests, where I wrote it.
1 parent 1654125 commit f592bee

File tree

4 files changed

+474
-45
lines changed

4 files changed

+474
-45
lines changed

src/impl_owned_array.rs

+215-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11

22
use alloc::vec::Vec;
3+
use std::mem::MaybeUninit;
34

45
use crate::imp_prelude::*;
6+
57
use crate::dimension;
68
use crate::error::{ErrorKind, ShapeError};
9+
use crate::iterators::Baseiter;
710
use crate::OwnedRepr;
811
use crate::Zip;
912

@@ -137,6 +140,166 @@ impl<A> Array<A, Ix2> {
137140
impl<A, D> Array<A, D>
138141
where D: Dimension
139142
{
143+
/// Move all elements from self into `new_array`, which must be of the same shape but
144+
/// can have a different memory layout. The destination is overwritten completely.
145+
///
146+
/// ***Panics*** if the shapes don't agree.
147+
pub fn move_into(mut self, new_array: &mut Array<MaybeUninit<A>, D>) {
148+
unsafe {
149+
// Safety: copy_to_nonoverlapping cannot panic
150+
// Move all reachable elements
151+
Zip::from(self.raw_view_mut()).and(new_array)
152+
.for_each(|src, dst| {
153+
src.copy_to_nonoverlapping(dst.as_mut_ptr(), 1);
154+
});
155+
// Drop all unreachable elements
156+
self.drop_unreachable_elements();
157+
}
158+
}
159+
160+
/// This drops all "unreachable" elements in the data storage of self.
161+
///
162+
/// That means those elements that are not visible in the slicing of the array.
163+
/// *Reachable elements are assumed to already have been moved from.*
164+
///
165+
/// # Safety
166+
///
167+
/// This is a panic-critical section since `self` is already moved-from.
168+
fn drop_unreachable_elements(mut self) -> OwnedRepr<A> {
169+
let self_len = self.len();
170+
171+
// "deconstruct" self; the owned repr releases ownership of all elements and we
172+
// carry on with raw view methods
173+
let data_len = self.data.len();
174+
175+
let has_unreachable_elements = self_len != data_len;
176+
if !has_unreachable_elements || std::mem::size_of::<A>() == 0 {
177+
unsafe {
178+
self.data.set_len(0);
179+
}
180+
self.data
181+
} else {
182+
self.drop_unreachable_elements_slow()
183+
}
184+
}
185+
186+
#[inline(never)]
187+
#[cold]
188+
fn drop_unreachable_elements_slow(mut self) -> OwnedRepr<A> {
189+
// "deconstruct" self; the owned repr releases ownership of all elements and we
190+
// carry on with raw view methods
191+
let self_len = self.len();
192+
let data_len = self.data.len();
193+
let data_ptr = self.data.as_nonnull_mut().as_ptr();
194+
195+
let mut self_;
196+
197+
unsafe {
198+
// Safety: self.data releases ownership of the elements
199+
self_ = self.raw_view_mut();
200+
self.data.set_len(0);
201+
}
202+
203+
204+
// uninvert axes where needed, so that stride > 0
205+
for i in 0..self_.ndim() {
206+
if self_.stride_of(Axis(i)) < 0 {
207+
self_.invert_axis(Axis(i));
208+
}
209+
}
210+
211+
// Sort axes to standard order, Axis(0) has biggest stride and Axis(n - 1) least stride
212+
// Note that self_ has holes, so self_ is not C-contiguous
213+
sort_axes_in_default_order(&mut self_);
214+
215+
unsafe {
216+
let array_memory_head_ptr = self_.ptr.as_ptr();
217+
let data_end_ptr = data_ptr.add(data_len);
218+
debug_assert!(data_ptr <= array_memory_head_ptr);
219+
debug_assert!(array_memory_head_ptr <= data_end_ptr);
220+
221+
// iter is a raw pointer iterator traversing self_ in its standard order
222+
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
223+
let mut dropped_elements = 0;
224+
225+
// The idea is simply this: the iterator will yield the elements of self_ in
226+
// increasing address order.
227+
//
228+
// The pointers produced by the iterator are those that we *do not* touch.
229+
// The pointers *not mentioned* by the iterator are those we have to drop.
230+
//
231+
// We have to drop elements in the range from `data_ptr` until (not including)
232+
// `data_end_ptr`, except those that are produced by `iter`.
233+
let mut last_ptr = data_ptr;
234+
235+
while let Some(elem_ptr) = iter.next() {
236+
// The interval from last_ptr up until (not including) elem_ptr
237+
// should now be dropped. This interval may be empty, then we just skip this loop.
238+
while last_ptr != elem_ptr {
239+
debug_assert!(last_ptr < data_end_ptr);
240+
std::ptr::drop_in_place(last_ptr as *mut A);
241+
last_ptr = last_ptr.add(1);
242+
dropped_elements += 1;
243+
}
244+
// Next interval will continue one past the current element
245+
last_ptr = elem_ptr.add(1);
246+
}
247+
248+
while last_ptr < data_end_ptr {
249+
std::ptr::drop_in_place(last_ptr as *mut A);
250+
last_ptr = last_ptr.add(1);
251+
dropped_elements += 1;
252+
}
253+
254+
assert_eq!(data_len, dropped_elements + self_len,
255+
"Internal error: inconsistency in move_into");
256+
}
257+
self.data
258+
}
259+
260+
/// Create an empty array with dimension `ndim` and all zeros shape
261+
///
262+
/// ***Panics*** if ndim is 0 or D is zero-dim
263+
pub(crate) fn empty(ndim: usize) -> Array<A, D> {
264+
assert_ne!(ndim, 0);
265+
assert_ne!(D::NDIM, Some(0));
266+
unsafe {
267+
// Safety: all elements (zero elements) are initialized
268+
Array::uninit(D::zeros(ndim)).assume_init()
269+
}
270+
}
271+
272+
/// Recreate `self` with a contiguous memory layout suitable for appending to `growing_axis`
273+
#[inline(never)]
274+
fn change_to_contiguous_layout(&mut self, growing_axis: Axis) {
275+
let ndim = self.ndim();
276+
let mut dim = self.raw_dim();
277+
278+
// The array will be created with 0 (C) or ndim-1 (F) as the biggest stride
279+
// axis. Rearrange the shape so that `growing_axis` is the biggest stride axis
280+
// afterwards.
281+
let prefer_f_layout = growing_axis == Axis(ndim - 1);
282+
if !prefer_f_layout {
283+
dim.slice_mut().swap(0, growing_axis.index());
284+
}
285+
let mut new_array = Self::uninit(dim.set_f(prefer_f_layout));
286+
if !prefer_f_layout {
287+
new_array.swap_axes(0, growing_axis.index());
288+
}
289+
290+
// self -> old_self.
291+
// dummy array -> self.
292+
// old_self elements are moved -> new_array.
293+
let old_self = std::mem::replace(self, Self::empty(ndim));
294+
old_self.move_into(&mut new_array);
295+
296+
// new_array -> self.
297+
unsafe {
298+
*self = new_array.assume_init();
299+
}
300+
}
301+
302+
140303
/// Append an array to the array
141304
///
142305
/// The axis-to-append-to `axis` must be the array's "growing axis" for this operation
@@ -217,19 +380,27 @@ impl<A, D> Array<A, D>
217380
}
218381

219382
let self_is_empty = self.is_empty();
383+
let mut incompatible_layout = false;
220384

221385
// array must be empty or have `axis` as the outermost (longest stride) axis
222386
if !self_is_empty && current_axis_len > 1 {
223387
// `axis` must be max stride axis or equal to its stride
224388
let max_stride_axis = self.axes().max_by_key(|ax| ax.stride).unwrap();
225389
if max_stride_axis.axis != axis && max_stride_axis.stride > self.stride_of(axis) {
226-
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
390+
incompatible_layout = true;
227391
}
228392
}
229393

230394
// array must be "full" (have no exterior holes)
231395
if self.len() != self.data.len() {
232-
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
396+
incompatible_layout = true;
397+
}
398+
399+
if incompatible_layout {
400+
self.change_to_contiguous_layout(axis);
401+
// safety-check parameters after remodeling
402+
debug_assert_eq!(self_is_empty, self.is_empty());
403+
debug_assert_eq!(current_axis_len, self.len_of(axis));
233404
}
234405

235406
let strides = if self_is_empty {
@@ -313,7 +484,7 @@ impl<A, D> Array<A, D>
313484
array.invert_axis(Axis(i));
314485
}
315486
}
316-
sort_axes_to_standard_order(&mut tail_view, &mut array);
487+
sort_axes_to_standard_order_tandem(&mut tail_view, &mut array);
317488
}
318489
Zip::from(tail_view).and(array)
319490
.debug_assert_c_order()
@@ -336,7 +507,21 @@ impl<A, D> Array<A, D>
336507
}
337508
}
338509

339-
fn sort_axes_to_standard_order<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
510+
/// Sort axes to standard order, i.e. Axis(0) has the biggest stride and Axis(n - 1) the least stride
511+
///
512+
/// The axes should have stride >= 0 before calling this method.
513+
fn sort_axes_in_default_order<S, D>(a: &mut ArrayBase<S, D>)
514+
where
515+
S: RawData,
516+
D: Dimension,
517+
{
518+
if a.ndim() <= 1 {
519+
return;
520+
}
521+
sort_axes1_impl(&mut a.dim, &mut a.strides);
522+
}
523+
524+
fn sort_axes_to_standard_order_tandem<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
340525
where
341526
S: RawData,
342527
S2: RawData,
@@ -350,6 +535,32 @@ where
350535
a.shape(), a.strides());
351536
}
352537

538+
fn sort_axes1_impl<D>(adim: &mut D, astrides: &mut D)
539+
where
540+
D: Dimension,
541+
{
542+
debug_assert!(adim.ndim() > 1);
543+
debug_assert_eq!(adim.ndim(), astrides.ndim());
544+
// bubble sort axes
545+
let mut changed = true;
546+
while changed {
547+
changed = false;
548+
for i in 0..adim.ndim() - 1 {
549+
let axis_i = i;
550+
let next_axis = i + 1;
551+
552+
// make sure higher stride axes sort before lower stride axes.
553+
debug_assert!(astrides.slice()[axis_i] as isize >= 0);
554+
if (astrides.slice()[axis_i] as isize) < astrides.slice()[next_axis] as isize {
555+
changed = true;
556+
adim.slice_mut().swap(axis_i, next_axis);
557+
astrides.slice_mut().swap(axis_i, next_axis);
558+
}
559+
}
560+
}
561+
}
562+
563+
353564
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)
354565
where
355566
D: Dimension,

tests/append.rs

+22-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ fn append_row() {
1818
assert_eq!(a.try_append_column(aview1(&[1.])),
1919
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
2020
assert_eq!(a.try_append_column(aview1(&[1., 2.])),
21-
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
21+
Ok(()));
22+
assert_eq!(a,
23+
array![[0., 1., 2., 3., 1.],
24+
[4., 5., 6., 7., 2.]]);
2225
}
2326

2427
#[test]
@@ -28,8 +31,7 @@ fn append_row_wrong_layout() {
2831
a.try_append_row(aview1(&[4., 5., 6., 7.])).unwrap();
2932
assert_eq!(a.shape(), &[2, 4]);
3033

31-
assert_eq!(a.try_append_column(aview1(&[1., 2.])),
32-
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
34+
//assert_eq!(a.try_append_column(aview1(&[1., 2.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
3335

3436
assert_eq!(a,
3537
array![[0., 1., 2., 3.],
@@ -56,7 +58,13 @@ fn append_row_error() {
5658
assert_eq!(a.try_append_column(aview1(&[1.])),
5759
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
5860
assert_eq!(a.try_append_column(aview1(&[1., 2., 3.])),
59-
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
61+
Ok(()));
62+
assert_eq!(a.t(),
63+
array![[0., 0., 0.],
64+
[0., 0., 0.],
65+
[0., 0., 0.],
66+
[0., 0., 0.],
67+
[1., 2., 3.]]);
6068
}
6169

6270
#[test]
@@ -76,7 +84,11 @@ fn append_row_existing() {
7684
assert_eq!(a.try_append_column(aview1(&[1.])),
7785
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
7886
assert_eq!(a.try_append_column(aview1(&[1., 2., 3.])),
79-
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
87+
Ok(()));
88+
assert_eq!(a,
89+
array![[0., 0., 0., 0., 1.],
90+
[0., 1., 2., 3., 2.],
91+
[4., 5., 6., 7., 3.]]);
8092
}
8193

8294
#[test]
@@ -87,8 +99,7 @@ fn append_row_col_len_1() {
8799
a.try_append_column(aview1(&[2., 3.])).unwrap(); // shape 2 x 2
88100
assert_eq!(a.try_append_row(aview1(&[1.])),
89101
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
90-
assert_eq!(a.try_append_row(aview1(&[1., 2.])),
91-
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
102+
//assert_eq!(a.try_append_row(aview1(&[1., 2.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
92103
a.try_append_column(aview1(&[4., 5.])).unwrap(); // shape 2 x 3
93104
assert_eq!(a.shape(), &[2, 3]);
94105

@@ -240,3 +251,7 @@ fn test_append_zero_size() {
240251
assert_eq!(a.shape(), &[0, 2]);
241252
}
242253
}
254+
255+
#[test]
256+
fn move_into() {
257+
}

0 commit comments

Comments
 (0)