diff --git a/src/impl_methods.rs b/src/impl_methods.rs index ea9c9a0d5..9a1741be6 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -576,11 +576,7 @@ where pub fn slice_move(mut self, info: I) -> ArrayBase where I: SliceArg { - assert_eq!( - info.in_ndim(), - self.ndim(), - "The input dimension of `info` must match the array to be sliced.", - ); + assert_eq!(info.in_ndim(), self.ndim(), "The input dimension of `info` must match the array to be sliced.",); let out_ndim = info.out_ndim(); let mut new_dim = I::OutDim::zeros(out_ndim); let mut new_strides = I::OutDim::zeros(out_ndim); @@ -648,11 +644,7 @@ impl LayoutRef pub fn slice_collapse(&mut self, info: I) where I: SliceArg { - assert_eq!( - info.in_ndim(), - self.ndim(), - "The input dimension of `info` must match the array to be sliced.", - ); + assert_eq!(info.in_ndim(), self.ndim(), "The input dimension of `info` must match the array to be sliced.",); let mut axis = 0; info.as_ref().iter().for_each(|&ax_info| match ax_info { SliceInfoElem::Slice { start, end, step } => { @@ -1120,8 +1112,7 @@ impl ArrayRef // bounds check the indices first if let Some(max_index) = indices.iter().cloned().max() { if max_index >= axis_len { - panic!("ndarray: index {} is out of bounds in array of len {}", - max_index, self.len_of(axis)); + panic!("ndarray: index {} is out of bounds in array of len {}", max_index, self.len_of(axis)); } } // else: indices empty is ok let view = self.view().into_dimensionality::().unwrap(); @@ -1530,10 +1521,7 @@ impl ArrayRef ndassert!( axis_index < self.ndim(), - concat!( - "Window axis {} does not match array dimension {} ", - "(with array of shape {:?})" - ), + concat!("Window axis {} does not match array dimension {} ", "(with array of shape {:?})"), axis_index, self.ndim(), self.shape() @@ -3119,8 +3107,7 @@ where /// ***Panics*** if not `index < self.len_of(axis)`. pub fn remove_index(&mut self, axis: Axis, index: usize) { - assert!(index < self.len_of(axis), "index {} must be less than length of Axis({})", - index, axis.index()); + assert!(index < self.len_of(axis), "index {} must be less than length of Axis({})", index, axis.index()); let (_, mut tail) = self.view_mut().split_at(axis, index); // shift elements to the front Zip::from(tail.lanes_mut(axis)).for_each(|mut lane| lane.rotate1_front()); @@ -3193,15 +3180,16 @@ impl ArrayRef /// - All elements equal or greater than the k-th element to its right /// - The ordering within each partition is undefined /// + /// Empty arrays (i.e., those with any zero-length axes) are considered partitioned already, + /// and will be returned unchanged. + /// + /// **Panics** if `k` is out of bounds for a non-zero axis length. + /// /// # Parameters /// /// * `kth` - Index to partition by. The k-th element will be in its sorted position. /// * `axis` - Axis along which to partition. /// - /// # Returns - /// - /// A new array of the same shape and type as the input array, with elements partitioned. - /// /// # Examples /// /// ``` @@ -3221,19 +3209,19 @@ impl ArrayRef A: Clone + Ord + num_traits::Zero, D: Dimension, { - // Bounds checking - let axis_len = self.len_of(axis); - if kth >= axis_len { - panic!("partition index {} is out of bounds for axis of length {}", kth, axis_len); - } - let mut result = self.to_owned(); - // Must guarantee that the array isn't empty before checking for contiguity - if result.shape().iter().any(|s| *s == 0) { + // Return early if the array has zero-length dimensions + if self.shape().iter().any(|s| *s == 0) { return result; } + // Bounds checking. Panics if kth is out of bounds + let axis_len = self.len_of(axis); + if kth >= axis_len { + panic!("Partition index {} is out of bounds for axis {} of length {}", kth, axis.0, axis_len); + } + // Check if the first lane is contiguous let is_contiguous = result .lanes_mut(axis) @@ -3428,11 +3416,7 @@ mod tests fn test_partition_contiguous_or_not() { // Test contiguous case (C-order) - let a = array![ - [7, 1, 5], - [2, 6, 0], - [3, 4, 8] - ]; + let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]]; // Partition along axis 0 (contiguous) let p_axis0 = a.partition(1, Axis(0)); @@ -3442,20 +3426,24 @@ mod tests // - Last row should be >= middle row (kth element) for col in 0..3 { let kth = p_axis0[[1, col]]; - assert!(p_axis0[[0, col]] <= kth, + assert!( + p_axis0[[0, col]] <= kth, "Column {}: First row {} should be <= middle row {}", - col, p_axis0[[0, col]], kth); - assert!(p_axis0[[2, col]] >= kth, + col, + p_axis0[[0, col]], + kth + ); + assert!( + p_axis0[[2, col]] >= kth, "Column {}: Last row {} should be >= middle row {}", - col, p_axis0[[2, col]], kth); + col, + p_axis0[[2, col]], + kth + ); } // Test non-contiguous case (F-order) - let a = array![ - [7, 1, 5], - [2, 6, 0], - [3, 4, 8] - ]; + let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]]; // Make array non-contiguous by transposing let a = a.t().to_owned(); @@ -3467,12 +3455,69 @@ mod tests // - First column should be <= middle column // - Last column should be >= middle column for row in 0..3 { - assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]], + assert!( + p_axis1[[row, 0]] <= p_axis1[[row, 1]], "Row {}: First column {} should be <= middle column {}", - row, p_axis1[[row, 0]], p_axis1[[row, 1]]); - assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]], + row, + p_axis1[[row, 0]], + p_axis1[[row, 1]] + ); + assert!( + p_axis1[[row, 2]] >= p_axis1[[row, 1]], "Row {}: Last column {} should be >= middle column {}", - row, p_axis1[[row, 2]], p_axis1[[row, 1]]); + row, + p_axis1[[row, 2]], + p_axis1[[row, 1]] + ); } } + + #[test] + fn test_partition_empty() + { + // Test 1D empty array + let empty1d = Array1::::zeros(0); + let result1d = empty1d.partition(0, Axis(0)); + assert_eq!(result1d.len(), 0); + + // Test 1D empty array with kth out of bounds + let result1d_out_of_bounds = empty1d.partition(1, Axis(0)); + assert_eq!(result1d_out_of_bounds.len(), 0); + + // Test 2D empty array + let empty2d = Array2::::zeros((0, 3)); + let result2d = empty2d.partition(0, Axis(0)); + assert_eq!(result2d.shape(), &[0, 3]); + + // Test 2D empty array with zero columns + let empty2d_cols = Array2::::zeros((2, 0)); + let result2d_cols = empty2d_cols.partition(0, Axis(1)); + assert_eq!(result2d_cols.shape(), &[2, 0]); + + // Test 3D empty array + let empty3d = Array3::::zeros((0, 2, 3)); + let result3d = empty3d.partition(0, Axis(0)); + assert_eq!(result3d.shape(), &[0, 2, 3]); + + // Test 3D empty array with zero in middle dimension + let empty3d_mid = Array3::::zeros((2, 0, 3)); + let result3d_mid = empty3d_mid.partition(0, Axis(1)); + assert_eq!(result3d_mid.shape(), &[2, 0, 3]); + + // Test 4D empty array + let empty4d = Array4::::zeros((0, 2, 3, 4)); + let result4d = empty4d.partition(0, Axis(0)); + assert_eq!(result4d.shape(), &[0, 2, 3, 4]); + + // Test empty array with non-zero dimensions in other axes + let empty_mixed = Array2::::zeros((0, 5)); + let result_mixed = empty_mixed.partition(0, Axis(0)); + assert_eq!(result_mixed.shape(), &[0, 5]); + + // Test empty array with negative strides + let arr = Array2::::zeros((3, 3)); + let empty_slice = arr.slice(s![0..0, ..]); + let result_slice = empty_slice.partition(0, Axis(0)); + assert_eq!(result_slice.shape(), &[0, 3]); + } }