diff --git a/src/impl_methods.rs b/src/impl_methods.rs
index d2f04ef1f..42d843781 100644
--- a/src/impl_methods.rs
+++ b/src/impl_methods.rs
@@ -3184,6 +3184,81 @@ impl ArrayRef
f(&*prev, &mut *curr)
});
}
+
+ /// Return a partitioned copy of the array.
+ ///
+ /// Creates a copy of the array and partially sorts it around the k-th element along the given axis.
+ /// The k-th element will be in its sorted position, with:
+ /// - All elements smaller than the k-th element to its left
+ /// - All elements equal or greater than the k-th element to its right
+ /// - The ordering within each partition is undefined
+ ///
+ /// # Parameters
+ ///
+ /// * `kth` - Index to partition by. The k-th element will be in its sorted position.
+ /// * `axis` - Axis along which to partition.
+ ///
+ /// # Returns
+ ///
+ /// A new array of the same shape and type as the input array, with elements partitioned.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use ndarray::prelude::*;
+ ///
+ /// let a = array![7, 1, 5, 2, 6, 0, 3, 4];
+ /// let p = a.partition(3, Axis(0));
+ ///
+ /// // The element at position 3 is now 3, with smaller elements to the left
+ /// // and greater elements to the right
+ /// assert_eq!(p[3], 3);
+ /// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3));
+ /// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3));
+ /// ```
+ pub fn partition(&self, kth: usize, axis: Axis) -> Array
+ where
+ A: Clone + Ord + num_traits::Zero,
+ D: Dimension,
+ {
+ // Bounds checking
+ let axis_len = self.len_of(axis);
+ if kth >= axis_len {
+ panic!("partition index {} is out of bounds for axis of length {}", kth, axis_len);
+ }
+
+ let mut result = self.to_owned();
+
+ // Check if the first lane is contiguous
+ let is_contiguous = result
+ .lanes_mut(axis)
+ .into_iter()
+ .next()
+ .unwrap()
+ .is_contiguous();
+
+ if is_contiguous {
+ Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
+ lane.as_slice_mut().unwrap().select_nth_unstable(kth);
+ });
+ } else {
+ let mut temp_vec = vec![A::zero(); axis_len];
+
+ Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
+ Zip::from(&mut temp_vec).and(&lane).for_each(|dest, src| {
+ *dest = src.clone();
+ });
+
+ temp_vec.select_nth_unstable(kth);
+
+ Zip::from(&mut lane).and(&temp_vec).for_each(|dest, src| {
+ *dest = src.clone();
+ });
+ });
+ }
+
+ result
+ }
}
/// Transmute from A to B.
@@ -3277,4 +3352,121 @@ mod tests
let _a2 = a.clone();
assert_first!(a);
}
+
+ #[test]
+ fn test_partition_1d()
+ {
+ // Test partitioning a 1D array
+ let array = arr1(&[3, 1, 4, 1, 5, 9, 2, 6]);
+ let result = array.partition(3, Axis(0));
+ // After partitioning, the element at index 3 should be in its final sorted position
+ assert!(result.slice(s![..3]).iter().all(|&x| x <= result[3]));
+ assert!(result.slice(s![4..]).iter().all(|&x| x >= result[3]));
+ }
+
+ #[test]
+ fn test_partition_2d()
+ {
+ // Test partitioning a 2D array along both axes
+ let array = arr2(&[[3, 1, 4], [1, 5, 9], [2, 6, 5]]);
+
+ // Partition along axis 0 (rows)
+ let result0 = array.partition(1, Axis(0));
+ // After partitioning along axis 0, each column should have its middle element in the correct position
+ assert!(result0[[0, 0]] <= result0[[1, 0]] && result0[[2, 0]] >= result0[[1, 0]]);
+ assert!(result0[[0, 1]] <= result0[[1, 1]] && result0[[2, 1]] >= result0[[1, 1]]);
+ assert!(result0[[0, 2]] <= result0[[1, 2]] && result0[[2, 2]] >= result0[[1, 2]]);
+
+ // Partition along axis 1 (columns)
+ let result1 = array.partition(1, Axis(1));
+ // After partitioning along axis 1, each row should have its middle element in the correct position
+ assert!(result1[[0, 0]] <= result1[[0, 1]] && result1[[0, 2]] >= result1[[0, 1]]);
+ assert!(result1[[1, 0]] <= result1[[1, 1]] && result1[[1, 2]] >= result1[[1, 1]]);
+ assert!(result1[[2, 0]] <= result1[[2, 1]] && result1[[2, 2]] >= result1[[2, 1]]);
+ }
+
+ #[test]
+ fn test_partition_3d()
+ {
+ // Test partitioning a 3D array
+ let array = arr3(&[[[3, 1], [4, 1]], [[5, 9], [2, 6]]]);
+
+ // Partition along axis 0
+ let result = array.partition(0, Axis(0));
+ // After partitioning, each 2x2 slice should have its first element in the correct position
+ assert!(result[[0, 0, 0]] <= result[[1, 0, 0]]);
+ assert!(result[[0, 0, 1]] <= result[[1, 0, 1]]);
+ assert!(result[[0, 1, 0]] <= result[[1, 1, 0]]);
+ assert!(result[[0, 1, 1]] <= result[[1, 1, 1]]);
+ }
+
+ #[test]
+ #[should_panic]
+ fn test_partition_invalid_kth()
+ {
+ let a = array![1, 2, 3, 4];
+ // This should panic because kth=4 is out of bounds
+ let _ = a.partition(4, Axis(0));
+ }
+
+ #[test]
+ #[should_panic]
+ fn test_partition_invalid_axis()
+ {
+ let a = array![1, 2, 3, 4];
+ // This should panic because axis=1 is out of bounds for a 1D array
+ let _ = a.partition(0, Axis(1));
+ }
+
+ #[test]
+ fn test_partition_contiguous_or_not()
+ {
+ // Test contiguous case (C-order)
+ let a = array![
+ [7, 1, 5],
+ [2, 6, 0],
+ [3, 4, 8]
+ ];
+
+ // Partition along axis 0 (contiguous)
+ let p_axis0 = a.partition(1, Axis(0));
+
+ // For each column, verify the partitioning:
+ // - First row should be <= middle row (kth element)
+ // - Last row should be >= middle row (kth element)
+ for col in 0..3 {
+ let kth = p_axis0[[1, col]];
+ assert!(p_axis0[[0, col]] <= kth,
+ "Column {}: First row {} should be <= middle row {}",
+ col, p_axis0[[0, col]], kth);
+ assert!(p_axis0[[2, col]] >= kth,
+ "Column {}: Last row {} should be >= middle row {}",
+ col, p_axis0[[2, col]], kth);
+ }
+
+ // Test non-contiguous case (F-order)
+ let a = array![
+ [7, 1, 5],
+ [2, 6, 0],
+ [3, 4, 8]
+ ];
+
+ // Make array non-contiguous by transposing
+ let a = a.t().to_owned();
+
+ // Partition along axis 1 (non-contiguous)
+ let p_axis1 = a.partition(1, Axis(1));
+
+ // For each row, verify the partitioning:
+ // - First column should be <= middle column
+ // - Last column should be >= middle column
+ for row in 0..3 {
+ assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]],
+ "Row {}: First column {} should be <= middle column {}",
+ row, p_axis1[[row, 0]], p_axis1[[row, 1]]);
+ assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]],
+ "Row {}: Last column {} should be >= middle column {}",
+ row, p_axis1[[row, 2]], p_axis1[[row, 1]]);
+ }
+ }
}