@@ -3184,6 +3184,81 @@ impl<A, D: Dimension> ArrayRef<A, D>
3184
3184
f ( & * prev, & mut * curr)
3185
3185
} ) ;
3186
3186
}
3187
+
3188
+ /// Return a partitioned copy of the array.
3189
+ ///
3190
+ /// Creates a copy of the array and partially sorts it around the k-th element along the given axis.
3191
+ /// The k-th element will be in its sorted position, with:
3192
+ /// - All elements smaller than the k-th element to its left
3193
+ /// - All elements equal or greater than the k-th element to its right
3194
+ /// - The ordering within each partition is undefined
3195
+ ///
3196
+ /// # Parameters
3197
+ ///
3198
+ /// * `kth` - Index to partition by. The k-th element will be in its sorted position.
3199
+ /// * `axis` - Axis along which to partition.
3200
+ ///
3201
+ /// # Returns
3202
+ ///
3203
+ /// A new array of the same shape and type as the input array, with elements partitioned.
3204
+ ///
3205
+ /// # Examples
3206
+ ///
3207
+ /// ```
3208
+ /// use ndarray::prelude::*;
3209
+ ///
3210
+ /// let a = array![7, 1, 5, 2, 6, 0, 3, 4];
3211
+ /// let p = a.partition(3, Axis(0));
3212
+ ///
3213
+ /// // The element at position 3 is now 3, with smaller elements to the left
3214
+ /// // and greater elements to the right
3215
+ /// assert_eq!(p[3], 3);
3216
+ /// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3));
3217
+ /// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3));
3218
+ /// ```
3219
+ pub fn partition ( & self , kth : usize , axis : Axis ) -> Array < A , D >
3220
+ where
3221
+ A : Clone + Ord + num_traits:: Zero ,
3222
+ D : Dimension ,
3223
+ {
3224
+ // Bounds checking
3225
+ let axis_len = self . len_of ( axis) ;
3226
+ if kth >= axis_len {
3227
+ panic ! ( "partition index {} is out of bounds for axis of length {}" , kth, axis_len) ;
3228
+ }
3229
+
3230
+ let mut result = self . to_owned ( ) ;
3231
+
3232
+ // Check if the first lane is contiguous
3233
+ let is_contiguous = result
3234
+ . lanes_mut ( axis)
3235
+ . into_iter ( )
3236
+ . next ( )
3237
+ . unwrap ( )
3238
+ . is_contiguous ( ) ;
3239
+
3240
+ if is_contiguous {
3241
+ Zip :: from ( result. lanes_mut ( axis) ) . for_each ( |mut lane| {
3242
+ lane. as_slice_mut ( ) . unwrap ( ) . select_nth_unstable ( kth) ;
3243
+ } ) ;
3244
+ } else {
3245
+ let mut temp_vec = vec ! [ A :: zero( ) ; axis_len] ;
3246
+
3247
+ Zip :: from ( result. lanes_mut ( axis) ) . for_each ( |mut lane| {
3248
+ Zip :: from ( & mut temp_vec) . and ( & lane) . for_each ( |dest, src| {
3249
+ * dest = src. clone ( ) ;
3250
+ } ) ;
3251
+
3252
+ temp_vec. select_nth_unstable ( kth) ;
3253
+
3254
+ Zip :: from ( & mut lane) . and ( & temp_vec) . for_each ( |dest, src| {
3255
+ * dest = src. clone ( ) ;
3256
+ } ) ;
3257
+ } ) ;
3258
+ }
3259
+
3260
+ result
3261
+ }
3187
3262
}
3188
3263
3189
3264
/// Transmute from A to B.
@@ -3277,4 +3352,121 @@ mod tests
3277
3352
let _a2 = a. clone ( ) ;
3278
3353
assert_first ! ( a) ;
3279
3354
}
3355
+
3356
+ #[ test]
3357
+ fn test_partition_1d ( )
3358
+ {
3359
+ // Test partitioning a 1D array
3360
+ let array = arr1 ( & [ 3 , 1 , 4 , 1 , 5 , 9 , 2 , 6 ] ) ;
3361
+ let result = array. partition ( 3 , Axis ( 0 ) ) ;
3362
+ // After partitioning, the element at index 3 should be in its final sorted position
3363
+ assert ! ( result. slice( s![ ..3 ] ) . iter( ) . all( |& x| x <= result[ 3 ] ) ) ;
3364
+ assert ! ( result. slice( s![ 4 ..] ) . iter( ) . all( |& x| x >= result[ 3 ] ) ) ;
3365
+ }
3366
+
3367
+ #[ test]
3368
+ fn test_partition_2d ( )
3369
+ {
3370
+ // Test partitioning a 2D array along both axes
3371
+ let array = arr2 ( & [ [ 3 , 1 , 4 ] , [ 1 , 5 , 9 ] , [ 2 , 6 , 5 ] ] ) ;
3372
+
3373
+ // Partition along axis 0 (rows)
3374
+ let result0 = array. partition ( 1 , Axis ( 0 ) ) ;
3375
+ // After partitioning along axis 0, each column should have its middle element in the correct position
3376
+ assert ! ( result0[ [ 0 , 0 ] ] <= result0[ [ 1 , 0 ] ] && result0[ [ 2 , 0 ] ] >= result0[ [ 1 , 0 ] ] ) ;
3377
+ assert ! ( result0[ [ 0 , 1 ] ] <= result0[ [ 1 , 1 ] ] && result0[ [ 2 , 1 ] ] >= result0[ [ 1 , 1 ] ] ) ;
3378
+ assert ! ( result0[ [ 0 , 2 ] ] <= result0[ [ 1 , 2 ] ] && result0[ [ 2 , 2 ] ] >= result0[ [ 1 , 2 ] ] ) ;
3379
+
3380
+ // Partition along axis 1 (columns)
3381
+ let result1 = array. partition ( 1 , Axis ( 1 ) ) ;
3382
+ // After partitioning along axis 1, each row should have its middle element in the correct position
3383
+ assert ! ( result1[ [ 0 , 0 ] ] <= result1[ [ 0 , 1 ] ] && result1[ [ 0 , 2 ] ] >= result1[ [ 0 , 1 ] ] ) ;
3384
+ assert ! ( result1[ [ 1 , 0 ] ] <= result1[ [ 1 , 1 ] ] && result1[ [ 1 , 2 ] ] >= result1[ [ 1 , 1 ] ] ) ;
3385
+ assert ! ( result1[ [ 2 , 0 ] ] <= result1[ [ 2 , 1 ] ] && result1[ [ 2 , 2 ] ] >= result1[ [ 2 , 1 ] ] ) ;
3386
+ }
3387
+
3388
+ #[ test]
3389
+ fn test_partition_3d ( )
3390
+ {
3391
+ // Test partitioning a 3D array
3392
+ let array = arr3 ( & [ [ [ 3 , 1 ] , [ 4 , 1 ] ] , [ [ 5 , 9 ] , [ 2 , 6 ] ] ] ) ;
3393
+
3394
+ // Partition along axis 0
3395
+ let result = array. partition ( 0 , Axis ( 0 ) ) ;
3396
+ // After partitioning, each 2x2 slice should have its first element in the correct position
3397
+ assert ! ( result[ [ 0 , 0 , 0 ] ] <= result[ [ 1 , 0 , 0 ] ] ) ;
3398
+ assert ! ( result[ [ 0 , 0 , 1 ] ] <= result[ [ 1 , 0 , 1 ] ] ) ;
3399
+ assert ! ( result[ [ 0 , 1 , 0 ] ] <= result[ [ 1 , 1 , 0 ] ] ) ;
3400
+ assert ! ( result[ [ 0 , 1 , 1 ] ] <= result[ [ 1 , 1 , 1 ] ] ) ;
3401
+ }
3402
+
3403
+ #[ test]
3404
+ #[ should_panic]
3405
+ fn test_partition_invalid_kth ( )
3406
+ {
3407
+ let a = array ! [ 1 , 2 , 3 , 4 ] ;
3408
+ // This should panic because kth=4 is out of bounds
3409
+ let _ = a. partition ( 4 , Axis ( 0 ) ) ;
3410
+ }
3411
+
3412
+ #[ test]
3413
+ #[ should_panic]
3414
+ fn test_partition_invalid_axis ( )
3415
+ {
3416
+ let a = array ! [ 1 , 2 , 3 , 4 ] ;
3417
+ // This should panic because axis=1 is out of bounds for a 1D array
3418
+ let _ = a. partition ( 0 , Axis ( 1 ) ) ;
3419
+ }
3420
+
3421
+ #[ test]
3422
+ fn test_partition_contiguous_or_not ( )
3423
+ {
3424
+ // Test contiguous case (C-order)
3425
+ let a = array ! [
3426
+ [ 7 , 1 , 5 ] ,
3427
+ [ 2 , 6 , 0 ] ,
3428
+ [ 3 , 4 , 8 ]
3429
+ ] ;
3430
+
3431
+ // Partition along axis 0 (contiguous)
3432
+ let p_axis0 = a. partition ( 1 , Axis ( 0 ) ) ;
3433
+
3434
+ // For each column, verify the partitioning:
3435
+ // - First row should be <= middle row (kth element)
3436
+ // - Last row should be >= middle row (kth element)
3437
+ for col in 0 ..3 {
3438
+ let kth = p_axis0[ [ 1 , col] ] ;
3439
+ assert ! ( p_axis0[ [ 0 , col] ] <= kth,
3440
+ "Column {}: First row {} should be <= middle row {}" ,
3441
+ col, p_axis0[ [ 0 , col] ] , kth) ;
3442
+ assert ! ( p_axis0[ [ 2 , col] ] >= kth,
3443
+ "Column {}: Last row {} should be >= middle row {}" ,
3444
+ col, p_axis0[ [ 2 , col] ] , kth) ;
3445
+ }
3446
+
3447
+ // Test non-contiguous case (F-order)
3448
+ let a = array ! [
3449
+ [ 7 , 1 , 5 ] ,
3450
+ [ 2 , 6 , 0 ] ,
3451
+ [ 3 , 4 , 8 ]
3452
+ ] ;
3453
+
3454
+ // Make array non-contiguous by transposing
3455
+ let a = a. t ( ) . to_owned ( ) ;
3456
+
3457
+ // Partition along axis 1 (non-contiguous)
3458
+ let p_axis1 = a. partition ( 1 , Axis ( 1 ) ) ;
3459
+
3460
+ // For each row, verify the partitioning:
3461
+ // - First column should be <= middle column
3462
+ // - Last column should be >= middle column
3463
+ for row in 0 ..3 {
3464
+ assert ! ( p_axis1[ [ row, 0 ] ] <= p_axis1[ [ row, 1 ] ] ,
3465
+ "Row {}: First column {} should be <= middle column {}" ,
3466
+ row, p_axis1[ [ row, 0 ] ] , p_axis1[ [ row, 1 ] ] ) ;
3467
+ assert ! ( p_axis1[ [ row, 2 ] ] >= p_axis1[ [ row, 1 ] ] ,
3468
+ "Row {}: Last column {} should be >= middle column {}" ,
3469
+ row, p_axis1[ [ row, 2 ] ] , p_axis1[ [ row, 1 ] ] ) ;
3470
+ }
3471
+ }
3280
3472
}
0 commit comments