Skip to content

Commit 1206bd6

Browse files
blussjturner314
andcommitted
append: Use positive stride and ignore stride of len 1 axes
Fix two bugs that Jim found in how we calculate the new stride for the growing axis. Tests by Jim Turner. Co-authored-by: Jim Turner <[email protected]>
1 parent 3f59442 commit 1206bd6

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

src/impl_owned_array.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -480,9 +480,11 @@ impl<A, D> Array<A, D>
480480
} else if current_axis_len == 1 {
481481
// This is the outermost/longest stride axis; so we find the max across the other axes
482482
let new_stride = self.axes().fold(1, |acc, ax| {
483-
if ax.axis == axis { acc } else {
484-
let this_ax = ax.len as isize * ax.stride;
485-
if this_ax.abs() > acc { this_ax } else { acc }
483+
if ax.axis == axis || ax.len <= 1 {
484+
acc
485+
} else {
486+
let this_ax = ax.len as isize * ax.stride.abs();
487+
if this_ax > acc { this_ax } else { acc }
486488
}
487489
});
488490
let mut strides = self.strides.clone();

tests/append.rs

+31
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,34 @@ fn test_append_zero_size() {
345345
assert_eq!(a.shape(), &[0, 2]);
346346
}
347347
}
348+
349+
#[test]
350+
fn append_row_neg_stride_3() {
351+
let mut a = Array::zeros((0, 4));
352+
a.append_row(aview1(&[0., 1., 2., 3.])).unwrap();
353+
a.invert_axis(Axis(1));
354+
a.append_row(aview1(&[4., 5., 6., 7.])).unwrap();
355+
assert_eq!(a.shape(), &[2, 4]);
356+
assert_eq!(a, array![[3., 2., 1., 0.], [4., 5., 6., 7.]]);
357+
assert_eq!(a.strides(), &[4, -1]);
358+
}
359+
360+
#[test]
361+
fn append_row_ignore_strides_length_one_axes() {
362+
let strides = &[0, 1, 10, 20];
363+
for invert in &[vec![], vec![0], vec![1], vec![0, 1]] {
364+
for &stride0 in strides {
365+
for &stride1 in strides {
366+
let mut a =
367+
Array::from_shape_vec([1, 1].strides([stride0, stride1]), vec![0.]).unwrap();
368+
for &ax in invert {
369+
a.invert_axis(Axis(ax));
370+
}
371+
a.append_row(aview1(&[1.])).unwrap();
372+
assert_eq!(a.shape(), &[2, 1]);
373+
assert_eq!(a, array![[0.], [1.]]);
374+
assert_eq!(a.stride_of(Axis(0)), 1);
375+
}
376+
}
377+
}
378+
}

0 commit comments

Comments
 (0)