Skip to content

Commit 8bd70b0

Browse files
add test case for partition on empty array (#1504)
* add test case for empty array * return early when the array has zero lenth dims --------- Co-authored-by: Adam Kern <[email protected]>
1 parent da115c9 commit 8bd70b0

File tree

1 file changed

+93
-48
lines changed

1 file changed

+93
-48
lines changed

src/impl_methods.rs

+93-48
Original file line numberDiff line numberDiff line change
@@ -576,11 +576,7 @@ where
576576
pub fn slice_move<I>(mut self, info: I) -> ArrayBase<S, I::OutDim>
577577
where I: SliceArg<D>
578578
{
579-
assert_eq!(
580-
info.in_ndim(),
581-
self.ndim(),
582-
"The input dimension of `info` must match the array to be sliced.",
583-
);
579+
assert_eq!(info.in_ndim(), self.ndim(), "The input dimension of `info` must match the array to be sliced.",);
584580
let out_ndim = info.out_ndim();
585581
let mut new_dim = I::OutDim::zeros(out_ndim);
586582
let mut new_strides = I::OutDim::zeros(out_ndim);
@@ -648,11 +644,7 @@ impl<A, D: Dimension> LayoutRef<A, D>
648644
pub fn slice_collapse<I>(&mut self, info: I)
649645
where I: SliceArg<D>
650646
{
651-
assert_eq!(
652-
info.in_ndim(),
653-
self.ndim(),
654-
"The input dimension of `info` must match the array to be sliced.",
655-
);
647+
assert_eq!(info.in_ndim(), self.ndim(), "The input dimension of `info` must match the array to be sliced.",);
656648
let mut axis = 0;
657649
info.as_ref().iter().for_each(|&ax_info| match ax_info {
658650
SliceInfoElem::Slice { start, end, step } => {
@@ -1120,8 +1112,7 @@ impl<A, D: Dimension> ArrayRef<A, D>
11201112
// bounds check the indices first
11211113
if let Some(max_index) = indices.iter().cloned().max() {
11221114
if max_index >= axis_len {
1123-
panic!("ndarray: index {} is out of bounds in array of len {}",
1124-
max_index, self.len_of(axis));
1115+
panic!("ndarray: index {} is out of bounds in array of len {}", max_index, self.len_of(axis));
11251116
}
11261117
} // else: indices empty is ok
11271118
let view = self.view().into_dimensionality::<Ix1>().unwrap();
@@ -1530,10 +1521,7 @@ impl<A, D: Dimension> ArrayRef<A, D>
15301521

15311522
ndassert!(
15321523
axis_index < self.ndim(),
1533-
concat!(
1534-
"Window axis {} does not match array dimension {} ",
1535-
"(with array of shape {:?})"
1536-
),
1524+
concat!("Window axis {} does not match array dimension {} ", "(with array of shape {:?})"),
15371525
axis_index,
15381526
self.ndim(),
15391527
self.shape()
@@ -3119,8 +3107,7 @@ where
31193107
/// ***Panics*** if not `index < self.len_of(axis)`.
31203108
pub fn remove_index(&mut self, axis: Axis, index: usize)
31213109
{
3122-
assert!(index < self.len_of(axis), "index {} must be less than length of Axis({})",
3123-
index, axis.index());
3110+
assert!(index < self.len_of(axis), "index {} must be less than length of Axis({})", index, axis.index());
31243111
let (_, mut tail) = self.view_mut().split_at(axis, index);
31253112
// shift elements to the front
31263113
Zip::from(tail.lanes_mut(axis)).for_each(|mut lane| lane.rotate1_front());
@@ -3193,15 +3180,16 @@ impl<A, D: Dimension> ArrayRef<A, D>
31933180
/// - All elements equal or greater than the k-th element to its right
31943181
/// - The ordering within each partition is undefined
31953182
///
3183+
/// Empty arrays (i.e., those with any zero-length axes) are considered partitioned already,
3184+
/// and will be returned unchanged.
3185+
///
3186+
/// **Panics** if `k` is out of bounds for a non-zero axis length.
3187+
///
31963188
/// # Parameters
31973189
///
31983190
/// * `kth` - Index to partition by. The k-th element will be in its sorted position.
31993191
/// * `axis` - Axis along which to partition.
32003192
///
3201-
/// # Returns
3202-
///
3203-
/// A new array of the same shape and type as the input array, with elements partitioned.
3204-
///
32053193
/// # Examples
32063194
///
32073195
/// ```
@@ -3221,19 +3209,19 @@ impl<A, D: Dimension> ArrayRef<A, D>
32213209
A: Clone + Ord + num_traits::Zero,
32223210
D: Dimension,
32233211
{
3224-
// Bounds checking
3225-
let axis_len = self.len_of(axis);
3226-
if kth >= axis_len {
3227-
panic!("partition index {} is out of bounds for axis of length {}", kth, axis_len);
3228-
}
3229-
32303212
let mut result = self.to_owned();
32313213

3232-
// Must guarantee that the array isn't empty before checking for contiguity
3233-
if result.shape().iter().any(|s| *s == 0) {
3214+
// Return early if the array has zero-length dimensions
3215+
if self.shape().iter().any(|s| *s == 0) {
32343216
return result;
32353217
}
32363218

3219+
// Bounds checking. Panics if kth is out of bounds
3220+
let axis_len = self.len_of(axis);
3221+
if kth >= axis_len {
3222+
panic!("Partition index {} is out of bounds for axis {} of length {}", kth, axis.0, axis_len);
3223+
}
3224+
32373225
// Check if the first lane is contiguous
32383226
let is_contiguous = result
32393227
.lanes_mut(axis)
@@ -3428,11 +3416,7 @@ mod tests
34283416
fn test_partition_contiguous_or_not()
34293417
{
34303418
// Test contiguous case (C-order)
3431-
let a = array![
3432-
[7, 1, 5],
3433-
[2, 6, 0],
3434-
[3, 4, 8]
3435-
];
3419+
let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]];
34363420

34373421
// Partition along axis 0 (contiguous)
34383422
let p_axis0 = a.partition(1, Axis(0));
@@ -3442,20 +3426,24 @@ mod tests
34423426
// - Last row should be >= middle row (kth element)
34433427
for col in 0..3 {
34443428
let kth = p_axis0[[1, col]];
3445-
assert!(p_axis0[[0, col]] <= kth,
3429+
assert!(
3430+
p_axis0[[0, col]] <= kth,
34463431
"Column {}: First row {} should be <= middle row {}",
3447-
col, p_axis0[[0, col]], kth);
3448-
assert!(p_axis0[[2, col]] >= kth,
3432+
col,
3433+
p_axis0[[0, col]],
3434+
kth
3435+
);
3436+
assert!(
3437+
p_axis0[[2, col]] >= kth,
34493438
"Column {}: Last row {} should be >= middle row {}",
3450-
col, p_axis0[[2, col]], kth);
3439+
col,
3440+
p_axis0[[2, col]],
3441+
kth
3442+
);
34513443
}
34523444

34533445
// Test non-contiguous case (F-order)
3454-
let a = array![
3455-
[7, 1, 5],
3456-
[2, 6, 0],
3457-
[3, 4, 8]
3458-
];
3446+
let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]];
34593447

34603448
// Make array non-contiguous by transposing
34613449
let a = a.t().to_owned();
@@ -3467,12 +3455,69 @@ mod tests
34673455
// - First column should be <= middle column
34683456
// - Last column should be >= middle column
34693457
for row in 0..3 {
3470-
assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]],
3458+
assert!(
3459+
p_axis1[[row, 0]] <= p_axis1[[row, 1]],
34713460
"Row {}: First column {} should be <= middle column {}",
3472-
row, p_axis1[[row, 0]], p_axis1[[row, 1]]);
3473-
assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]],
3461+
row,
3462+
p_axis1[[row, 0]],
3463+
p_axis1[[row, 1]]
3464+
);
3465+
assert!(
3466+
p_axis1[[row, 2]] >= p_axis1[[row, 1]],
34743467
"Row {}: Last column {} should be >= middle column {}",
3475-
row, p_axis1[[row, 2]], p_axis1[[row, 1]]);
3468+
row,
3469+
p_axis1[[row, 2]],
3470+
p_axis1[[row, 1]]
3471+
);
34763472
}
34773473
}
3474+
3475+
#[test]
3476+
fn test_partition_empty()
3477+
{
3478+
// Test 1D empty array
3479+
let empty1d = Array1::<i32>::zeros(0);
3480+
let result1d = empty1d.partition(0, Axis(0));
3481+
assert_eq!(result1d.len(), 0);
3482+
3483+
// Test 1D empty array with kth out of bounds
3484+
let result1d_out_of_bounds = empty1d.partition(1, Axis(0));
3485+
assert_eq!(result1d_out_of_bounds.len(), 0);
3486+
3487+
// Test 2D empty array
3488+
let empty2d = Array2::<i32>::zeros((0, 3));
3489+
let result2d = empty2d.partition(0, Axis(0));
3490+
assert_eq!(result2d.shape(), &[0, 3]);
3491+
3492+
// Test 2D empty array with zero columns
3493+
let empty2d_cols = Array2::<i32>::zeros((2, 0));
3494+
let result2d_cols = empty2d_cols.partition(0, Axis(1));
3495+
assert_eq!(result2d_cols.shape(), &[2, 0]);
3496+
3497+
// Test 3D empty array
3498+
let empty3d = Array3::<i32>::zeros((0, 2, 3));
3499+
let result3d = empty3d.partition(0, Axis(0));
3500+
assert_eq!(result3d.shape(), &[0, 2, 3]);
3501+
3502+
// Test 3D empty array with zero in middle dimension
3503+
let empty3d_mid = Array3::<i32>::zeros((2, 0, 3));
3504+
let result3d_mid = empty3d_mid.partition(0, Axis(1));
3505+
assert_eq!(result3d_mid.shape(), &[2, 0, 3]);
3506+
3507+
// Test 4D empty array
3508+
let empty4d = Array4::<i32>::zeros((0, 2, 3, 4));
3509+
let result4d = empty4d.partition(0, Axis(0));
3510+
assert_eq!(result4d.shape(), &[0, 2, 3, 4]);
3511+
3512+
// Test empty array with non-zero dimensions in other axes
3513+
let empty_mixed = Array2::<i32>::zeros((0, 5));
3514+
let result_mixed = empty_mixed.partition(0, Axis(0));
3515+
assert_eq!(result_mixed.shape(), &[0, 5]);
3516+
3517+
// Test empty array with negative strides
3518+
let arr = Array2::<i32>::zeros((3, 3));
3519+
let empty_slice = arr.slice(s![0..0, ..]);
3520+
let result_slice = empty_slice.partition(0, Axis(0));
3521+
assert_eq!(result_slice.shape(), &[0, 3]);
3522+
}
34783523
}

0 commit comments

Comments
 (0)