Skip to content

Commit 2e8d2fc

Browse files
committed
FIX: Fix situations where we need to recompute stride
When the axis has length 0, or 1, we need to carefully compute new strides.
1 parent 7d5c3d3 commit 2e8d2fc

File tree

2 files changed

+86
-10
lines changed

2 files changed

+86
-10
lines changed

src/impl_owned_array.rs

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ impl<A, D> Array<A, D>
256256
A: Clone,
257257
D: RemoveAxis,
258258
{
259+
assert_ne!(self.ndim(), 0, "Impossible to append to 0-dim array");
260+
let current_axis_len = self.len_of(axis);
259261
let remaining_shape = self.raw_dim().remove_axis(axis);
260262
let array_rem_shape = array.raw_dim().remove_axis(axis);
261263

@@ -275,22 +277,46 @@ impl<A, D> Array<A, D>
275277

276278
let self_is_empty = self.is_empty();
277279

278-
// array must be empty or have `axis` as the outermost (longest stride)
279-
// axis
280-
if !(self_is_empty ||
281-
self.axes().max_by_key(|ax| ax.stride).map(|ax| ax.axis) == Some(axis))
282-
{
283-
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
280+
// array must be empty or have `axis` as the outermost (longest stride) axis
281+
if !self_is_empty && current_axis_len > 1 {
282+
// `axis` must be max stride axis or equal to its stride
283+
let max_stride_axis = self.axes().max_by_key(|ax| ax.stride).unwrap();
284+
if max_stride_axis.axis != axis && max_stride_axis.stride > self.stride_of(axis) {
285+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
286+
}
284287
}
285288

286289
// array must be be "full" (have no exterior holes)
287290
if self.len() != self.data.len() {
288291
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
289292
}
293+
290294
let strides = if self_is_empty {
291-
// recompute strides - if the array was previously empty, it could have
292-
// zeros in strides.
293-
res_dim.default_strides()
295+
// recompute strides - if the array was previously empty, it could have zeros in
296+
// strides.
297+
// The new order is based on c/f-contig but must have `axis` as outermost axis.
298+
if axis == Axis(self.ndim() - 1) {
299+
// prefer f-contig when appending to the last axis
300+
// Axis n - 1 is outermost axis
301+
res_dim.fortran_strides()
302+
} else {
303+
// Default with modification
304+
res_dim.slice_mut().swap(0, axis.index());
305+
let mut strides = res_dim.default_strides();
306+
res_dim.slice_mut().swap(0, axis.index());
307+
strides.slice_mut().swap(0, axis.index());
308+
strides
309+
}
310+
} else if current_axis_len == 1 {
311+
// This is the outermost/longest stride axis; so we find the max across the other axes
312+
let new_stride = self.axes().fold(1, |acc, ax| {
313+
if ax.axis == axis { acc } else {
314+
Ord::max(acc, ax.len as isize * ax.stride)
315+
}
316+
});
317+
let mut strides = self.strides.clone();
318+
strides[axis.index()] = new_stride as usize;
319+
strides
294320
} else {
295321
self.strides.clone()
296322
};
@@ -379,7 +405,8 @@ where
379405
return;
380406
}
381407
sort_axes_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides);
382-
debug_assert!(a.is_standard_layout());
408+
debug_assert!(a.is_standard_layout(), "not std layout dim: {:?}, strides: {:?}",
409+
a.shape(), a.strides());
383410
}
384411

385412
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)

tests/append.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,52 @@ fn append_array_3d() {
128128
[83, 84],
129129
[87, 88]]]);
130130
}
131+
132+
#[test]
133+
fn test_append_2d() {
134+
// create an empty array and append
135+
let mut a = Array::zeros((0, 4));
136+
let ones = ArrayView::from(&[1.; 12]).into_shape((3, 4)).unwrap();
137+
let zeros = ArrayView::from(&[0.; 8]).into_shape((2, 4)).unwrap();
138+
a.try_append_array(Axis(0), ones).unwrap();
139+
a.try_append_array(Axis(0), zeros).unwrap();
140+
a.try_append_array(Axis(0), ones).unwrap();
141+
println!("{:?}", a);
142+
assert_eq!(a.shape(), &[8, 4]);
143+
for (i, row) in a.rows().into_iter().enumerate() {
144+
let ones = i < 3 || i >= 5;
145+
assert!(row.iter().all(|&x| x == ones as i32 as f64), "failed on lane {}", i);
146+
}
147+
148+
let mut a = Array::zeros((0, 4));
149+
a = a.reversed_axes();
150+
let ones = ones.reversed_axes();
151+
let zeros = zeros.reversed_axes();
152+
a.try_append_array(Axis(1), ones).unwrap();
153+
a.try_append_array(Axis(1), zeros).unwrap();
154+
a.try_append_array(Axis(1), ones).unwrap();
155+
println!("{:?}", a);
156+
assert_eq!(a.shape(), &[4, 8]);
157+
158+
for (i, row) in a.columns().into_iter().enumerate() {
159+
let ones = i < 3 || i >= 5;
160+
assert!(row.iter().all(|&x| x == ones as i32 as f64), "failed on lane {}", i);
161+
}
162+
}
163+
164+
#[test]
165+
fn test_append_middle_axis() {
166+
// ensure we can append to Axis(1) by letting it become outermost
167+
let mut a = Array::<i32, _>::zeros((3, 0, 2));
168+
a.try_append_array(Axis(1), Array::from_iter(0..12).into_shape((3, 2, 2)).unwrap().view()).unwrap();
169+
println!("{:?}", a);
170+
a.try_append_array(Axis(1), Array::from_iter(12..24).into_shape((3, 2, 2)).unwrap().view()).unwrap();
171+
println!("{:?}", a);
172+
173+
// ensure we can append to Axis(1) by letting it become outermost
174+
let mut a = Array::<i32, _>::zeros((3, 1, 2));
175+
a.try_append_array(Axis(1), Array::from_iter(0..12).into_shape((3, 2, 2)).unwrap().view()).unwrap();
176+
println!("{:?}", a);
177+
a.try_append_array(Axis(1), Array::from_iter(12..24).into_shape((3, 2, 2)).unwrap().view()).unwrap();
178+
println!("{:?}", a);
179+
}

0 commit comments

Comments
 (0)