Skip to content

Commit 2324d2a

Browse files
Add partition(similar to numpy.partition) (#1498)
* fn partition
1 parent 4e2a70f commit 2324d2a

File tree

1 file changed

+192
-0
lines changed

1 file changed

+192
-0
lines changed

Diff for: src/impl_methods.rs

+192
Original file line numberDiff line numberDiff line change
@@ -3184,6 +3184,81 @@ impl<A, D: Dimension> ArrayRef<A, D>
31843184
f(&*prev, &mut *curr)
31853185
});
31863186
}
3187+
3188+
/// Return a partitioned copy of the array.
3189+
///
3190+
/// Creates a copy of the array and partially sorts it around the k-th element along the given axis.
3191+
/// The k-th element will be in its sorted position, with:
3192+
/// - All elements smaller than the k-th element to its left
3193+
/// - All elements equal or greater than the k-th element to its right
3194+
/// - The ordering within each partition is undefined
3195+
///
3196+
/// # Parameters
3197+
///
3198+
/// * `kth` - Index to partition by. The k-th element will be in its sorted position.
3199+
/// * `axis` - Axis along which to partition.
3200+
///
3201+
/// # Returns
3202+
///
3203+
/// A new array of the same shape and type as the input array, with elements partitioned.
3204+
///
3205+
/// # Examples
3206+
///
3207+
/// ```
3208+
/// use ndarray::prelude::*;
3209+
///
3210+
/// let a = array![7, 1, 5, 2, 6, 0, 3, 4];
3211+
/// let p = a.partition(3, Axis(0));
3212+
///
3213+
/// // The element at position 3 is now 3, with smaller elements to the left
3214+
/// // and greater elements to the right
3215+
/// assert_eq!(p[3], 3);
3216+
/// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3));
3217+
/// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3));
3218+
/// ```
3219+
pub fn partition(&self, kth: usize, axis: Axis) -> Array<A, D>
3220+
where
3221+
A: Clone + Ord + num_traits::Zero,
3222+
D: Dimension,
3223+
{
3224+
// Bounds checking
3225+
let axis_len = self.len_of(axis);
3226+
if kth >= axis_len {
3227+
panic!("partition index {} is out of bounds for axis of length {}", kth, axis_len);
3228+
}
3229+
3230+
let mut result = self.to_owned();
3231+
3232+
// Check if the first lane is contiguous
3233+
let is_contiguous = result
3234+
.lanes_mut(axis)
3235+
.into_iter()
3236+
.next()
3237+
.unwrap()
3238+
.is_contiguous();
3239+
3240+
if is_contiguous {
3241+
Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
3242+
lane.as_slice_mut().unwrap().select_nth_unstable(kth);
3243+
});
3244+
} else {
3245+
let mut temp_vec = vec![A::zero(); axis_len];
3246+
3247+
Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
3248+
Zip::from(&mut temp_vec).and(&lane).for_each(|dest, src| {
3249+
*dest = src.clone();
3250+
});
3251+
3252+
temp_vec.select_nth_unstable(kth);
3253+
3254+
Zip::from(&mut lane).and(&temp_vec).for_each(|dest, src| {
3255+
*dest = src.clone();
3256+
});
3257+
});
3258+
}
3259+
3260+
result
3261+
}
31873262
}
31883263

31893264
/// Transmute from A to B.
@@ -3277,4 +3352,121 @@ mod tests
32773352
let _a2 = a.clone();
32783353
assert_first!(a);
32793354
}
3355+
3356+
#[test]
3357+
fn test_partition_1d()
3358+
{
3359+
// Test partitioning a 1D array
3360+
let array = arr1(&[3, 1, 4, 1, 5, 9, 2, 6]);
3361+
let result = array.partition(3, Axis(0));
3362+
// After partitioning, the element at index 3 should be in its final sorted position
3363+
assert!(result.slice(s![..3]).iter().all(|&x| x <= result[3]));
3364+
assert!(result.slice(s![4..]).iter().all(|&x| x >= result[3]));
3365+
}
3366+
3367+
#[test]
3368+
fn test_partition_2d()
3369+
{
3370+
// Test partitioning a 2D array along both axes
3371+
let array = arr2(&[[3, 1, 4], [1, 5, 9], [2, 6, 5]]);
3372+
3373+
// Partition along axis 0 (rows)
3374+
let result0 = array.partition(1, Axis(0));
3375+
// After partitioning along axis 0, each column should have its middle element in the correct position
3376+
assert!(result0[[0, 0]] <= result0[[1, 0]] && result0[[2, 0]] >= result0[[1, 0]]);
3377+
assert!(result0[[0, 1]] <= result0[[1, 1]] && result0[[2, 1]] >= result0[[1, 1]]);
3378+
assert!(result0[[0, 2]] <= result0[[1, 2]] && result0[[2, 2]] >= result0[[1, 2]]);
3379+
3380+
// Partition along axis 1 (columns)
3381+
let result1 = array.partition(1, Axis(1));
3382+
// After partitioning along axis 1, each row should have its middle element in the correct position
3383+
assert!(result1[[0, 0]] <= result1[[0, 1]] && result1[[0, 2]] >= result1[[0, 1]]);
3384+
assert!(result1[[1, 0]] <= result1[[1, 1]] && result1[[1, 2]] >= result1[[1, 1]]);
3385+
assert!(result1[[2, 0]] <= result1[[2, 1]] && result1[[2, 2]] >= result1[[2, 1]]);
3386+
}
3387+
3388+
#[test]
3389+
fn test_partition_3d()
3390+
{
3391+
// Test partitioning a 3D array
3392+
let array = arr3(&[[[3, 1], [4, 1]], [[5, 9], [2, 6]]]);
3393+
3394+
// Partition along axis 0
3395+
let result = array.partition(0, Axis(0));
3396+
// After partitioning, each 2x2 slice should have its first element in the correct position
3397+
assert!(result[[0, 0, 0]] <= result[[1, 0, 0]]);
3398+
assert!(result[[0, 0, 1]] <= result[[1, 0, 1]]);
3399+
assert!(result[[0, 1, 0]] <= result[[1, 1, 0]]);
3400+
assert!(result[[0, 1, 1]] <= result[[1, 1, 1]]);
3401+
}
3402+
3403+
#[test]
3404+
#[should_panic]
3405+
fn test_partition_invalid_kth()
3406+
{
3407+
let a = array![1, 2, 3, 4];
3408+
// This should panic because kth=4 is out of bounds
3409+
let _ = a.partition(4, Axis(0));
3410+
}
3411+
3412+
#[test]
3413+
#[should_panic]
3414+
fn test_partition_invalid_axis()
3415+
{
3416+
let a = array![1, 2, 3, 4];
3417+
// This should panic because axis=1 is out of bounds for a 1D array
3418+
let _ = a.partition(0, Axis(1));
3419+
}
3420+
3421+
#[test]
3422+
fn test_partition_contiguous_or_not()
3423+
{
3424+
// Test contiguous case (C-order)
3425+
let a = array![
3426+
[7, 1, 5],
3427+
[2, 6, 0],
3428+
[3, 4, 8]
3429+
];
3430+
3431+
// Partition along axis 0 (contiguous)
3432+
let p_axis0 = a.partition(1, Axis(0));
3433+
3434+
// For each column, verify the partitioning:
3435+
// - First row should be <= middle row (kth element)
3436+
// - Last row should be >= middle row (kth element)
3437+
for col in 0..3 {
3438+
let kth = p_axis0[[1, col]];
3439+
assert!(p_axis0[[0, col]] <= kth,
3440+
"Column {}: First row {} should be <= middle row {}",
3441+
col, p_axis0[[0, col]], kth);
3442+
assert!(p_axis0[[2, col]] >= kth,
3443+
"Column {}: Last row {} should be >= middle row {}",
3444+
col, p_axis0[[2, col]], kth);
3445+
}
3446+
3447+
// Test non-contiguous case (F-order)
3448+
let a = array![
3449+
[7, 1, 5],
3450+
[2, 6, 0],
3451+
[3, 4, 8]
3452+
];
3453+
3454+
// Make array non-contiguous by transposing
3455+
let a = a.t().to_owned();
3456+
3457+
// Partition along axis 1 (non-contiguous)
3458+
let p_axis1 = a.partition(1, Axis(1));
3459+
3460+
// For each row, verify the partitioning:
3461+
// - First column should be <= middle column
3462+
// - Last column should be >= middle column
3463+
for row in 0..3 {
3464+
assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]],
3465+
"Row {}: First column {} should be <= middle column {}",
3466+
row, p_axis1[[row, 0]], p_axis1[[row, 1]]);
3467+
assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]],
3468+
"Row {}: Last column {} should be >= middle column {}",
3469+
row, p_axis1[[row, 2]], p_axis1[[row, 1]]);
3470+
}
3471+
}
32803472
}

0 commit comments

Comments
 (0)