diff --git a/src/impl_methods.rs b/src/impl_methods.rs index d2f04ef1f..42d843781 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -3184,6 +3184,81 @@ impl ArrayRef f(&*prev, &mut *curr) }); } + + /// Return a partitioned copy of the array. + /// + /// Creates a copy of the array and partially sorts it around the k-th element along the given axis. + /// The k-th element will be in its sorted position, with: + /// - All elements smaller than the k-th element to its left + /// - All elements equal or greater than the k-th element to its right + /// - The ordering within each partition is undefined + /// + /// # Parameters + /// + /// * `kth` - Index to partition by. The k-th element will be in its sorted position. + /// * `axis` - Axis along which to partition. + /// + /// # Returns + /// + /// A new array of the same shape and type as the input array, with elements partitioned. + /// + /// # Examples + /// + /// ``` + /// use ndarray::prelude::*; + /// + /// let a = array![7, 1, 5, 2, 6, 0, 3, 4]; + /// let p = a.partition(3, Axis(0)); + /// + /// // The element at position 3 is now 3, with smaller elements to the left + /// // and greater elements to the right + /// assert_eq!(p[3], 3); + /// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3)); + /// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3)); + /// ``` + pub fn partition(&self, kth: usize, axis: Axis) -> Array + where + A: Clone + Ord + num_traits::Zero, + D: Dimension, + { + // Bounds checking + let axis_len = self.len_of(axis); + if kth >= axis_len { + panic!("partition index {} is out of bounds for axis of length {}", kth, axis_len); + } + + let mut result = self.to_owned(); + + // Check if the first lane is contiguous + let is_contiguous = result + .lanes_mut(axis) + .into_iter() + .next() + .unwrap() + .is_contiguous(); + + if is_contiguous { + Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { + lane.as_slice_mut().unwrap().select_nth_unstable(kth); + }); + } else { + let mut temp_vec = vec![A::zero(); axis_len]; + + Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { + Zip::from(&mut temp_vec).and(&lane).for_each(|dest, src| { + *dest = src.clone(); + }); + + temp_vec.select_nth_unstable(kth); + + Zip::from(&mut lane).and(&temp_vec).for_each(|dest, src| { + *dest = src.clone(); + }); + }); + } + + result + } } /// Transmute from A to B. @@ -3277,4 +3352,121 @@ mod tests let _a2 = a.clone(); assert_first!(a); } + + #[test] + fn test_partition_1d() + { + // Test partitioning a 1D array + let array = arr1(&[3, 1, 4, 1, 5, 9, 2, 6]); + let result = array.partition(3, Axis(0)); + // After partitioning, the element at index 3 should be in its final sorted position + assert!(result.slice(s![..3]).iter().all(|&x| x <= result[3])); + assert!(result.slice(s![4..]).iter().all(|&x| x >= result[3])); + } + + #[test] + fn test_partition_2d() + { + // Test partitioning a 2D array along both axes + let array = arr2(&[[3, 1, 4], [1, 5, 9], [2, 6, 5]]); + + // Partition along axis 0 (rows) + let result0 = array.partition(1, Axis(0)); + // After partitioning along axis 0, each column should have its middle element in the correct position + assert!(result0[[0, 0]] <= result0[[1, 0]] && result0[[2, 0]] >= result0[[1, 0]]); + assert!(result0[[0, 1]] <= result0[[1, 1]] && result0[[2, 1]] >= result0[[1, 1]]); + assert!(result0[[0, 2]] <= result0[[1, 2]] && result0[[2, 2]] >= result0[[1, 2]]); + + // Partition along axis 1 (columns) + let result1 = array.partition(1, Axis(1)); + // After partitioning along axis 1, each row should have its middle element in the correct position + assert!(result1[[0, 0]] <= result1[[0, 1]] && result1[[0, 2]] >= result1[[0, 1]]); + assert!(result1[[1, 0]] <= result1[[1, 1]] && result1[[1, 2]] >= result1[[1, 1]]); + assert!(result1[[2, 0]] <= result1[[2, 1]] && result1[[2, 2]] >= result1[[2, 1]]); + } + + #[test] + fn test_partition_3d() + { + // Test partitioning a 3D array + let array = arr3(&[[[3, 1], [4, 1]], [[5, 9], [2, 6]]]); + + // Partition along axis 0 + let result = array.partition(0, Axis(0)); + // After partitioning, each 2x2 slice should have its first element in the correct position + assert!(result[[0, 0, 0]] <= result[[1, 0, 0]]); + assert!(result[[0, 0, 1]] <= result[[1, 0, 1]]); + assert!(result[[0, 1, 0]] <= result[[1, 1, 0]]); + assert!(result[[0, 1, 1]] <= result[[1, 1, 1]]); + } + + #[test] + #[should_panic] + fn test_partition_invalid_kth() + { + let a = array![1, 2, 3, 4]; + // This should panic because kth=4 is out of bounds + let _ = a.partition(4, Axis(0)); + } + + #[test] + #[should_panic] + fn test_partition_invalid_axis() + { + let a = array![1, 2, 3, 4]; + // This should panic because axis=1 is out of bounds for a 1D array + let _ = a.partition(0, Axis(1)); + } + + #[test] + fn test_partition_contiguous_or_not() + { + // Test contiguous case (C-order) + let a = array![ + [7, 1, 5], + [2, 6, 0], + [3, 4, 8] + ]; + + // Partition along axis 0 (contiguous) + let p_axis0 = a.partition(1, Axis(0)); + + // For each column, verify the partitioning: + // - First row should be <= middle row (kth element) + // - Last row should be >= middle row (kth element) + for col in 0..3 { + let kth = p_axis0[[1, col]]; + assert!(p_axis0[[0, col]] <= kth, + "Column {}: First row {} should be <= middle row {}", + col, p_axis0[[0, col]], kth); + assert!(p_axis0[[2, col]] >= kth, + "Column {}: Last row {} should be >= middle row {}", + col, p_axis0[[2, col]], kth); + } + + // Test non-contiguous case (F-order) + let a = array![ + [7, 1, 5], + [2, 6, 0], + [3, 4, 8] + ]; + + // Make array non-contiguous by transposing + let a = a.t().to_owned(); + + // Partition along axis 1 (non-contiguous) + let p_axis1 = a.partition(1, Axis(1)); + + // For each row, verify the partitioning: + // - First column should be <= middle column + // - Last column should be >= middle column + for row in 0..3 { + assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]], + "Row {}: First column {} should be <= middle column {}", + row, p_axis1[[row, 0]], p_axis1[[row, 1]]); + assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]], + "Row {}: Last column {} should be >= middle column {}", + row, p_axis1[[row, 2]], p_axis1[[row, 1]]); + } + } }