Skip to content

Commit 769bd77

Browse files
committed
FEAT: Fixup layout in .append_array()
1 parent 06b5026 commit 769bd77

File tree

2 files changed

+67
-9
lines changed

2 files changed

+67
-9
lines changed

src/impl_owned_array.rs

+49-2
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,45 @@ impl<A, D> Array<A, D>
258258
self.data
259259
}
260260

261+
/// Create an empty array with an all-zeros shape
262+
///
263+
/// ***Panics*** if D is zero-dimensional, because it can't be empty
264+
pub(crate) fn empty() -> Array<A, D> {
265+
assert_ne!(D::NDIM, Some(0));
266+
let ndim = D::NDIM.unwrap_or(1);
267+
Array::from_shape_simple_fn(D::zeros(ndim), || unreachable!())
268+
}
269+
270+
/// Create new_array with the right layout for appending to `growing_axis`
271+
#[inline(never)]
272+
fn change_to_contig_append_layout(&mut self, growing_axis: Axis) {
273+
let ndim = self.ndim();
274+
let mut dim = self.raw_dim();
275+
276+
// The array will be created with 0 (C) or ndim-1 (F) as the biggest stride
277+
// axis. Rearrange the shape so that `growing_axis` is the biggest stride axis
278+
// afterwards.
279+
let prefer_f_layout = growing_axis == Axis(ndim - 1);
280+
if !prefer_f_layout {
281+
dim.slice_mut().swap(0, growing_axis.index());
282+
}
283+
let mut new_array = Self::uninit(dim.set_f(prefer_f_layout));
284+
if !prefer_f_layout {
285+
new_array.swap_axes(0, growing_axis.index());
286+
}
287+
288+
// self -> old_self.
289+
// dummy array -> self.
290+
// old_self elements are moved -> new_array.
291+
let old_self = std::mem::replace(self, Self::empty());
292+
old_self.move_into(new_array.view_mut());
293+
294+
// new_array -> self.
295+
unsafe {
296+
*self = new_array.assume_init();
297+
}
298+
}
299+
261300

262301
/// Append an array to the array
263302
///
@@ -339,19 +378,27 @@ impl<A, D> Array<A, D>
339378
}
340379

341380
let self_is_empty = self.is_empty();
381+
let mut incompatible_layout = false;
342382

343383
// array must be empty or have `axis` as the outermost (longest stride) axis
344384
if !self_is_empty && current_axis_len > 1 {
345385
// `axis` must be max stride axis or equal to its stride
346386
let max_stride_axis = self.axes().max_by_key(|ax| ax.stride).unwrap();
347387
if max_stride_axis.axis != axis && max_stride_axis.stride > self.stride_of(axis) {
348-
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
388+
incompatible_layout = true;
349389
}
350390
}
351391

352392
// array must be be "full" (have no exterior holes)
353393
if self.len() != self.data.len() {
354-
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
394+
incompatible_layout = true;
395+
}
396+
397+
if incompatible_layout {
398+
self.change_to_contig_append_layout(axis);
399+
// safety-check parameters after remodeling
400+
debug_assert_eq!(self_is_empty, self.is_empty());
401+
debug_assert_eq!(current_axis_len, self.len_of(axis));
355402
}
356403

357404
let strides = if self_is_empty {

tests/append.rs

+18-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ fn append_row() {
1818
assert_eq!(a.try_append_column(aview1(&[1.])),
1919
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
2020
assert_eq!(a.try_append_column(aview1(&[1., 2.])),
21-
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
21+
Ok(()));
22+
assert_eq!(a,
23+
array![[0., 1., 2., 3., 1.],
24+
[4., 5., 6., 7., 2.]]);
2225
}
2326

2427
#[test]
@@ -28,8 +31,7 @@ fn append_row_wrong_layout() {
2831
a.try_append_row(aview1(&[4., 5., 6., 7.])).unwrap();
2932
assert_eq!(a.shape(), &[2, 4]);
3033

31-
assert_eq!(a.try_append_column(aview1(&[1., 2.])),
32-
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
34+
//assert_eq!(a.try_append_column(aview1(&[1., 2.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
3335

3436
assert_eq!(a,
3537
array![[0., 1., 2., 3.],
@@ -56,7 +58,13 @@ fn append_row_error() {
5658
assert_eq!(a.try_append_column(aview1(&[1.])),
5759
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
5860
assert_eq!(a.try_append_column(aview1(&[1., 2., 3.])),
59-
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
61+
Ok(()));
62+
assert_eq!(a.t(),
63+
array![[0., 0., 0.],
64+
[0., 0., 0.],
65+
[0., 0., 0.],
66+
[0., 0., 0.],
67+
[1., 2., 3.]]);
6068
}
6169

6270
#[test]
@@ -76,7 +84,11 @@ fn append_row_existing() {
7684
assert_eq!(a.try_append_column(aview1(&[1.])),
7785
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
7886
assert_eq!(a.try_append_column(aview1(&[1., 2., 3.])),
79-
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
87+
Ok(()));
88+
assert_eq!(a,
89+
array![[0., 0., 0., 0., 1.],
90+
[0., 1., 2., 3., 2.],
91+
[4., 5., 6., 7., 3.]]);
8092
}
8193

8294
#[test]
@@ -87,8 +99,7 @@ fn append_row_col_len_1() {
8799
a.try_append_column(aview1(&[2., 3.])).unwrap(); // shape 2 x 2
88100
assert_eq!(a.try_append_row(aview1(&[1.])),
89101
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
90-
assert_eq!(a.try_append_row(aview1(&[1., 2.])),
91-
Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
102+
//assert_eq!(a.try_append_row(aview1(&[1., 2.])), Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)));
92103
a.try_append_column(aview1(&[4., 5.])).unwrap(); // shape 2 x 3
93104
assert_eq!(a.shape(), &[2, 3]);
94105

0 commit comments

Comments
 (0)