Skip to content

Add partition (similar to numpy.partition) #1498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

192 changes: 192 additions & 0 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3184,6 +3184,81 @@ impl<A, D: Dimension> ArrayRef<A, D>
f(&*prev, &mut *curr)
});
}

/// Return a partitioned copy of the array.
///
/// Creates a copy of the array and partially sorts it around the k-th element along the given
/// axis. For every lane along `axis` of the returned array:
/// - The element at index `kth` is the one that would be there if the lane were fully sorted
/// - All elements before index `kth` are less than or equal to it
/// - All elements after index `kth` are greater than or equal to it
/// - The ordering within each side is unspecified
///
/// If the array has no lanes along `axis` (because some other axis has length zero), an empty
/// copy is returned.
///
/// # Parameters
///
/// * `kth` - Index to partition by. The k-th element will be in its sorted position.
/// * `axis` - Axis along which to partition.
///
/// # Returns
///
/// A new array of the same shape and type as the input array, with elements partitioned.
///
/// # Panics
///
/// Panics if `axis` is out of bounds for the array, or if `kth` is not smaller than the length
/// of `axis` (which is always the case when that axis has length zero).
///
/// # Examples
///
/// ```
/// use ndarray::prelude::*;
///
/// let a = array![7, 1, 5, 2, 6, 0, 3, 4];
/// let p = a.partition(3, Axis(0));
///
/// // The element at position 3 is now 3, with no larger element to its left
/// // and no smaller element to its right
/// assert_eq!(p[3], 3);
/// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3));
/// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3));
/// ```
pub fn partition(&self, kth: usize, axis: Axis) -> Array<A, D>
where
    A: Clone + Ord,
    D: Dimension,
{
    // Bounds checking; `len_of` itself rejects an axis the array does not have.
    let axis_len = self.len_of(axis);
    assert!(
        kth < axis_len,
        "partition index {} is out of bounds for axis of length {}",
        kth,
        axis_len
    );

    let mut result = self.to_owned();

    // Scratch space for lanes that cannot be viewed as a slice. It is shared by all lanes and
    // starts empty, so nothing is allocated when every lane is contiguous.
    let mut scratch: Vec<A> = Vec::new();

    // Deciding per lane (instead of probing the first lane up front) means an array without
    // any lanes is simply returned unchanged rather than hitting an `unwrap` on `None`.
    Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
        if let Some(slice) = lane.as_slice_mut() {
            // Contiguous lane: partition in place.
            slice.select_nth_unstable(kth);
        } else {
            // Strided lane: gather into the scratch buffer, partition there, then move the
            // elements back. `drain` empties the buffer but keeps its capacity for the next lane.
            scratch.clear();
            scratch.extend(lane.iter().cloned());
            scratch.select_nth_unstable(kth);
            for (dest, src) in lane.iter_mut().zip(scratch.drain(..)) {
                *dest = src;
            }
        }
    });

    result
}
}

/// Transmute from A to B.
Expand Down Expand Up @@ -3277,4 +3352,121 @@ mod tests
let _a2 = a.clone();
assert_first!(a);
}

#[test]
fn test_partition_1d()
{
    // Partition a one-dimensional array around index 3.
    let input = arr1(&[3, 1, 4, 1, 5, 9, 2, 6]);
    let partitioned = input.partition(3, Axis(0));

    // Nothing left of the pivot may exceed it, and nothing right of it may be smaller.
    let pivot = partitioned[3];
    let left = partitioned.slice(s![..3]);
    let right = partitioned.slice(s![4..]);
    assert!(left.iter().all(|&x| x <= pivot));
    assert!(right.iter().all(|&x| x >= pivot));
}

#[test]
fn test_partition_2d()
{
    // Partition a 3x3 array around the middle index, once per axis.
    let input = arr2(&[[3, 1, 4], [1, 5, 9], [2, 6, 5]]);

    // Along axis 0 the lanes are the columns: in every column the middle entry is the pivot.
    let by_axis0 = input.partition(1, Axis(0));
    for col in 0..3 {
        let pivot = by_axis0[[1, col]];
        assert!(by_axis0[[0, col]] <= pivot && by_axis0[[2, col]] >= pivot);
    }

    // Along axis 1 the lanes are the rows: in every row the middle entry is the pivot.
    let by_axis1 = input.partition(1, Axis(1));
    for row in 0..3 {
        let pivot = by_axis1[[row, 1]];
        assert!(by_axis1[[row, 0]] <= pivot && by_axis1[[row, 2]] >= pivot);
    }
}

#[test]
fn test_partition_3d()
{
    // Partition a 2x2x2 array along its first axis with kth = 0.
    let input = arr3(&[[[3, 1], [4, 1]], [[5, 9], [2, 6]]]);
    let partitioned = input.partition(0, Axis(0));

    // Every lane along axis 0 has two entries; the one at index 0 must not exceed the other.
    for j in 0..2 {
        for k in 0..2 {
            assert!(partitioned[[0, j, k]] <= partitioned[[1, j, k]]);
        }
    }
}

#[test]
#[should_panic]
fn test_partition_invalid_kth()
{
    // The only axis has length 4, so valid indices are 0..=3 and kth = 4 must panic.
    let input = array![1, 2, 3, 4];
    let _ = input.partition(4, Axis(0));
}

#[test]
#[should_panic]
fn test_partition_invalid_axis()
{
    // A one-dimensional array only has Axis(0); asking for Axis(1) must panic.
    let input = array![1, 2, 3, 4];
    let _ = input.partition(0, Axis(1));
}

#[test]
fn test_partition_contiguous_or_not()
{
    // With lanes of length 3 and kth = 1, the k-th element of every lane is its median, so the
    // expected values below are fully determined by the input.
    let a = array![
        [7, 1, 5],
        [2, 6, 0],
        [3, 4, 8]
    ];
    assert!(a.is_standard_layout());

    // Row-major input, axis 1: every lane is a row, i.e. adjacent in memory, so this is the
    // case that exercises the contiguous code path.
    let p_rows = a.partition(1, Axis(1));
    for (row, &expected) in [5, 2, 4].iter().enumerate() {
        let kth = p_rows[[row, 1]];
        assert_eq!(kth, expected, "Row {}: wrong k-th element", row);
        assert!(p_rows[[row, 0]] <= kth,
            "Row {}: First column {} should be <= middle column {}",
            row, p_rows[[row, 0]], kth);
        assert!(p_rows[[row, 2]] >= kth,
            "Row {}: Last column {} should be >= middle column {}",
            row, p_rows[[row, 2]], kth);
    }

    // Row-major input, axis 0: every lane is a column, whose elements are a whole row apart,
    // so this exercises the non-contiguous code path.
    let p_cols = a.partition(1, Axis(0));
    for (col, &expected) in [3, 4, 5].iter().enumerate() {
        let kth = p_cols[[1, col]];
        assert_eq!(kth, expected, "Column {}: wrong k-th element", col);
        assert!(p_cols[[0, col]] <= kth,
            "Column {}: First row {} should be <= middle row {}",
            col, p_cols[[0, col]], kth);
        assert!(p_cols[[2, col]] >= kth,
            "Column {}: Last row {} should be >= middle row {}",
            col, p_cols[[2, col]], kth);
    }

    // Transposed copy, axis 1: its rows are the columns of `a`.
    // NOTE(review): this is presumably column-major (assuming `to_owned` keeps the transposed
    // memory layout), which would make these lanes strided as well - confirm.
    let t = a.t().to_owned();
    let p_t = t.partition(1, Axis(1));
    for (row, &expected) in [3, 4, 5].iter().enumerate() {
        let kth = p_t[[row, 1]];
        assert_eq!(kth, expected, "Row {}: wrong k-th element", row);
        assert!(p_t[[row, 0]] <= kth,
            "Row {}: First column {} should be <= middle column {}",
            row, p_t[[row, 0]], kth);
        assert!(p_t[[row, 2]] >= kth,
            "Row {}: Last column {} should be >= middle column {}",
            row, p_t[[row, 2]], kth);
    }
}
}