Skip to content

Commit 352aceb

Browse files
committed
FIX: Solve axis iteration order problem by sorting axes
1 parent 751c25e commit 352aceb

File tree

3 files changed

+110
-17
lines changed

3 files changed

+110
-17
lines changed

src/impl_owned_array.rs

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use crate::dimension;
66
use crate::error::{ErrorKind, ShapeError};
77
use crate::OwnedRepr;
88
use crate::Zip;
9-
use crate::NdProducer;
109

1110
/// Methods specific to `Array0`.
1211
///
@@ -253,15 +252,12 @@ impl<A, D> Array<A, D>
253252
/// [1., 1., 1., 1.],
254253
/// [1., 1., 1., 1.]]);
255254
/// ```
256-
pub fn try_append_array(&mut self, axis: Axis, array: ArrayView<A, D>)
255+
pub fn try_append_array(&mut self, axis: Axis, mut array: ArrayView<A, D>)
257256
-> Result<(), ShapeError>
258257
where
259258
A: Clone,
260259
D: RemoveAxis,
261260
{
262-
let self_axis_len = self.len_of(axis);
263-
let array_axis_len = array.len_of(axis);
264-
265261
let remaining_shape = self.raw_dim().remove_axis(axis);
266262
let array_rem_shape = array.raw_dim().remove_axis(axis);
267263

@@ -312,7 +308,7 @@ impl<A, D> Array<A, D>
312308
// make a raw view with the new row
313309
// safe because the data was "full"
314310
let tail_ptr = self.data.as_end_nonnull();
315-
let tail_view = RawArrayViewMut::new(tail_ptr, array_shape, strides.clone());
311+
let mut tail_view = RawArrayViewMut::new(tail_ptr, array_shape, strides.clone());
316312

317313
struct SetLenOnDrop<'a, A: 'a> {
318314
len: usize,
@@ -332,37 +328,86 @@ impl<A, D> Array<A, D>
332328
}
333329
}
334330

335-
// we have a problem here XXX
336-
//
337331
// To be robust for panics and drop the right elements, we want
338332
// to fill the tail in-order, so that we can drop the right elements on
339-
// panic. Don't know how to achieve that.
333+
// panic.
340334
//
341-
// It might be easier to retrace our steps in a scope guard to drop the right
342-
// elements.. (PartialArray style).
335+
// We have: Zip::from(tail_view).and(array)
336+
// Transform tail_view into standard order by inverting and moving its axes.
337+
// Keep the Zip traversal unchanged by applying the same axis transformations to
338+
// `array`. This ensures the Zip traverses the underlying memory in order.
343339
//
344-
// assign the new elements
340+
// XXX It would be possible to skip this transformation if the element
341+
// doesn't have drop. However, in the interest of code coverage, all elements
342+
// use this code initially.
343+
344+
if tail_view.ndim() > 1 {
345+
for i in 0..tail_view.ndim() {
346+
if tail_view.stride_of(Axis(i)) < 0 {
347+
tail_view.invert_axis(Axis(i));
348+
array.invert_axis(Axis(i));
349+
}
350+
}
351+
sort_axes_to_standard_order(&mut tail_view, &mut array);
352+
}
345353
Zip::from(tail_view).and(array)
354+
.debug_assert_c_order()
346355
.for_each(|to, from| {
347356
to.write(from.clone());
348357
length_guard.len += 1;
349358
});
350359

351-
//length_guard.len += len_to_append;
352-
dbg!(len_to_append);
353360
drop(length_guard);
354361

355362
// update array dimension
356363
self.strides = strides;
357364
self.dim = res_dim;
358-
dbg!(&self.dim);
359-
360365
}
361366
// multiple assertions after pointer & dimension update
362367
debug_assert_eq!(self.data.len(), self.len());
363368
debug_assert_eq!(self.len(), new_len);
364-
debug_assert!(self.is_standard_layout());
365369

366370
Ok(())
367371
}
368372
}
373+
374+
fn sort_axes_to_standard_order<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
375+
where
376+
S: RawData,
377+
S2: RawData,
378+
D: Dimension,
379+
{
380+
if a.ndim() <= 1 {
381+
return;
382+
}
383+
sort_axes_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides);
384+
debug_assert!(a.is_standard_layout());
385+
}
386+
387+
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)
388+
where
389+
D: Dimension,
390+
{
391+
debug_assert!(adim.ndim() > 1);
392+
debug_assert_eq!(adim.ndim(), bdim.ndim());
393+
// bubble sort axes
394+
let mut changed = true;
395+
while changed {
396+
changed = false;
397+
for i in 0..adim.ndim() - 1 {
398+
let axis_i = i;
399+
let next_axis = i + 1;
400+
401+
// make sure higher stride axes sort before.
402+
debug_assert!(astrides.slice()[axis_i] as isize >= 0);
403+
if (astrides.slice()[axis_i] as isize) < astrides.slice()[next_axis] as isize {
404+
changed = true;
405+
adim.slice_mut().swap(axis_i, next_axis);
406+
astrides.slice_mut().swap(axis_i, next_axis);
407+
bdim.slice_mut().swap(axis_i, next_axis);
408+
bstrides.slice_mut().swap(axis_i, next_axis);
409+
}
410+
}
411+
}
412+
}
413+

src/zip/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,13 @@ macro_rules! map_impl {
673673
self.build_and(part)
674674
}
675675

676+
#[allow(unused)]
677+
#[inline]
678+
pub(crate) fn debug_assert_c_order(self) -> Self {
679+
debug_assert!(self.layout.is(CORDER) || self.layout_tendency >= 0);
680+
self
681+
}
682+
676683
fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
677684
where P: NdProducer<Dim=D>,
678685
{

tests/append.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,44 @@ fn append_array1() {
8787
[5., 5., 4., 4.],
8888
[3., 3., 2., 2.]]);
8989
}
90+
91+
#[test]
92+
fn append_array_3d() {
93+
let mut a = Array::zeros((0, 2, 2));
94+
a.try_append_array(Axis(0), array![[[0, 1], [2, 3]]].view()).unwrap();
95+
println!("{:?}", a);
96+
97+
let aa = array![[[51, 52], [53, 54]], [[55, 56], [57, 58]]];
98+
let av = aa.view();
99+
println!("Send {:?} to append", av);
100+
a.try_append_array(Axis(0), av.clone()).unwrap();
101+
102+
a.swap_axes(0, 1);
103+
let aa = array![[[71, 72], [73, 74]], [[75, 76], [77, 78]]];
104+
let mut av = aa.view();
105+
av.swap_axes(0, 1);
106+
println!("Send {:?} to append", av);
107+
a.try_append_array(Axis(1), av.clone()).unwrap();
108+
println!("{:?}", a);
109+
let aa = array![[[81, 82], [83, 84]], [[85, 86], [87, 88]]];
110+
let mut av = aa.view();
111+
av.swap_axes(0, 1);
112+
println!("Send {:?} to append", av);
113+
a.try_append_array(Axis(1), av).unwrap();
114+
println!("{:?}", a);
115+
assert_eq!(a,
116+
array![[[0, 1],
117+
[51, 52],
118+
[55, 56],
119+
[71, 72],
120+
[75, 76],
121+
[81, 82],
122+
[85, 86]],
123+
[[2, 3],
124+
[53, 54],
125+
[57, 58],
126+
[73, 74],
127+
[77, 78],
128+
[83, 84],
129+
[87, 88]]]);
130+
}

0 commit comments

Comments
 (0)