Skip to content

Commit b8aea7a

Browse files
committed
FIX: Fix situations where we need to recompute stride
When the axis has length 0, or 1, we need to carefully compute new strides.
1 parent 352aceb commit b8aea7a

File tree

2 files changed

+86
-10
lines changed

2 files changed

+86
-10
lines changed

src/impl_owned_array.rs

+37-10
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,8 @@ impl<A, D> Array<A, D>
258258
A: Clone,
259259
D: RemoveAxis,
260260
{
261+
assert_ne!(self.ndim(), 0, "Impossible to append to 0-dim array");
262+
let current_axis_len = self.len_of(axis);
261263
let remaining_shape = self.raw_dim().remove_axis(axis);
262264
let array_rem_shape = array.raw_dim().remove_axis(axis);
263265

@@ -277,22 +279,46 @@ impl<A, D> Array<A, D>
277279

278280
let self_is_empty = self.is_empty();
279281

280-
// array must be empty or have `axis` as the outermost (longest stride)
281-
// axis
282-
if !(self_is_empty ||
283-
self.axes().max_by_key(|ax| ax.stride).map(|ax| ax.axis) == Some(axis))
284-
{
285-
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
282+
// array must be empty or have `axis` as the outermost (longest stride) axis
283+
if !self_is_empty && current_axis_len > 1 {
284+
// `axis` must be max stride axis or equal to its stride
285+
let max_stride_axis = self.axes().max_by_key(|ax| ax.stride).unwrap();
286+
if max_stride_axis.axis != axis && max_stride_axis.stride > self.stride_of(axis) {
287+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
288+
}
286289
}
287290

288291
// array must be be "full" (have no exterior holes)
289292
if self.len() != self.data.len() {
290293
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
291294
}
295+
292296
let strides = if self_is_empty {
293-
// recompute strides - if the array was previously empty, it could have
294-
// zeros in strides.
295-
res_dim.default_strides()
297+
// recompute strides - if the array was previously empty, it could have zeros in
298+
// strides.
299+
// The new order is based on c/f-contig but must have `axis` as outermost axis.
300+
if axis == Axis(self.ndim() - 1) {
301+
// prefer f-contig when appending to the last axis
302+
// Axis n - 1 is outermost axis
303+
res_dim.fortran_strides()
304+
} else {
305+
// Default with modification
306+
res_dim.slice_mut().swap(0, axis.index());
307+
let mut strides = res_dim.default_strides();
308+
res_dim.slice_mut().swap(0, axis.index());
309+
strides.slice_mut().swap(0, axis.index());
310+
strides
311+
}
312+
} else if current_axis_len == 1 {
313+
// This is the outermost/longest stride axis; so we find the max across the other axes
314+
let new_stride = self.axes().fold(1, |acc, ax| {
315+
if ax.axis == axis { acc } else {
316+
Ord::max(acc, ax.len as isize * ax.stride)
317+
}
318+
});
319+
let mut strides = self.strides.clone();
320+
strides[axis.index()] = new_stride as usize;
321+
strides
296322
} else {
297323
self.strides.clone()
298324
};
@@ -381,7 +407,8 @@ where
381407
return;
382408
}
383409
sort_axes_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides);
384-
debug_assert!(a.is_standard_layout());
410+
debug_assert!(a.is_standard_layout(), "not std layout dim: {:?}, strides: {:?}",
411+
a.shape(), a.strides());
385412
}
386413

387414
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)

tests/append.rs

+49
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,52 @@ fn append_array_3d() {
128128
[83, 84],
129129
[87, 88]]]);
130130
}
131+
132+
#[test]
133+
fn test_append_2d() {
134+
// create an empty array and append
135+
let mut a = Array::zeros((0, 4));
136+
let ones = ArrayView::from(&[1.; 12]).into_shape((3, 4)).unwrap();
137+
let zeros = ArrayView::from(&[0.; 8]).into_shape((2, 4)).unwrap();
138+
a.try_append_array(Axis(0), ones).unwrap();
139+
a.try_append_array(Axis(0), zeros).unwrap();
140+
a.try_append_array(Axis(0), ones).unwrap();
141+
println!("{:?}", a);
142+
assert_eq!(a.shape(), &[8, 4]);
143+
for (i, row) in a.rows().into_iter().enumerate() {
144+
let ones = i < 3 || i >= 5;
145+
assert!(row.iter().all(|&x| x == ones as i32 as f64), "failed on lane {}", i);
146+
}
147+
148+
let mut a = Array::zeros((0, 4));
149+
a = a.reversed_axes();
150+
let ones = ones.reversed_axes();
151+
let zeros = zeros.reversed_axes();
152+
a.try_append_array(Axis(1), ones).unwrap();
153+
a.try_append_array(Axis(1), zeros).unwrap();
154+
a.try_append_array(Axis(1), ones).unwrap();
155+
println!("{:?}", a);
156+
assert_eq!(a.shape(), &[4, 8]);
157+
158+
for (i, row) in a.columns().into_iter().enumerate() {
159+
let ones = i < 3 || i >= 5;
160+
assert!(row.iter().all(|&x| x == ones as i32 as f64), "failed on lane {}", i);
161+
}
162+
}
163+
164+
#[test]
165+
fn test_append_middle_axis() {
166+
// ensure we can append to Axis(1) by letting it become outermost
167+
let mut a = Array::<i32, _>::zeros((3, 0, 2));
168+
a.try_append_array(Axis(1), Array::from_iter(0..12).into_shape((3, 2, 2)).unwrap().view()).unwrap();
169+
println!("{:?}", a);
170+
a.try_append_array(Axis(1), Array::from_iter(12..24).into_shape((3, 2, 2)).unwrap().view()).unwrap();
171+
println!("{:?}", a);
172+
173+
// ensure we can append to Axis(1) by letting it become outermost
174+
let mut a = Array::<i32, _>::zeros((3, 1, 2));
175+
a.try_append_array(Axis(1), Array::from_iter(0..12).into_shape((3, 2, 2)).unwrap().view()).unwrap();
176+
println!("{:?}", a);
177+
a.try_append_array(Axis(1), Array::from_iter(12..24).into_shape((3, 2, 2)).unwrap().view()).unwrap();
178+
println!("{:?}", a);
179+
}

0 commit comments

Comments
 (0)