@@ -576,11 +576,7 @@ where
576
576
pub fn slice_move < I > ( mut self , info : I ) -> ArrayBase < S , I :: OutDim >
577
577
where I : SliceArg < D >
578
578
{
579
- assert_eq ! (
580
- info. in_ndim( ) ,
581
- self . ndim( ) ,
582
- "The input dimension of `info` must match the array to be sliced." ,
583
- ) ;
579
+ assert_eq ! ( info. in_ndim( ) , self . ndim( ) , "The input dimension of `info` must match the array to be sliced." , ) ;
584
580
let out_ndim = info. out_ndim ( ) ;
585
581
let mut new_dim = I :: OutDim :: zeros ( out_ndim) ;
586
582
let mut new_strides = I :: OutDim :: zeros ( out_ndim) ;
@@ -648,11 +644,7 @@ impl<A, D: Dimension> LayoutRef<A, D>
648
644
pub fn slice_collapse < I > ( & mut self , info : I )
649
645
where I : SliceArg < D >
650
646
{
651
- assert_eq ! (
652
- info. in_ndim( ) ,
653
- self . ndim( ) ,
654
- "The input dimension of `info` must match the array to be sliced." ,
655
- ) ;
647
+ assert_eq ! ( info. in_ndim( ) , self . ndim( ) , "The input dimension of `info` must match the array to be sliced." , ) ;
656
648
let mut axis = 0 ;
657
649
info. as_ref ( ) . iter ( ) . for_each ( |& ax_info| match ax_info {
658
650
SliceInfoElem :: Slice { start, end, step } => {
@@ -1120,8 +1112,7 @@ impl<A, D: Dimension> ArrayRef<A, D>
1120
1112
// bounds check the indices first
1121
1113
if let Some ( max_index) = indices. iter ( ) . cloned ( ) . max ( ) {
1122
1114
if max_index >= axis_len {
1123
- panic ! ( "ndarray: index {} is out of bounds in array of len {}" ,
1124
- max_index, self . len_of( axis) ) ;
1115
+ panic ! ( "ndarray: index {} is out of bounds in array of len {}" , max_index, self . len_of( axis) ) ;
1125
1116
}
1126
1117
} // else: indices empty is ok
1127
1118
let view = self . view ( ) . into_dimensionality :: < Ix1 > ( ) . unwrap ( ) ;
@@ -1530,10 +1521,7 @@ impl<A, D: Dimension> ArrayRef<A, D>
1530
1521
1531
1522
ndassert ! (
1532
1523
axis_index < self . ndim( ) ,
1533
- concat!(
1534
- "Window axis {} does not match array dimension {} " ,
1535
- "(with array of shape {:?})"
1536
- ) ,
1524
+ concat!( "Window axis {} does not match array dimension {} " , "(with array of shape {:?})" ) ,
1537
1525
axis_index,
1538
1526
self . ndim( ) ,
1539
1527
self . shape( )
@@ -3119,8 +3107,7 @@ where
3119
3107
/// ***Panics*** if not `index < self.len_of(axis)`.
3120
3108
pub fn remove_index ( & mut self , axis : Axis , index : usize )
3121
3109
{
3122
- assert ! ( index < self . len_of( axis) , "index {} must be less than length of Axis({})" ,
3123
- index, axis. index( ) ) ;
3110
+ assert ! ( index < self . len_of( axis) , "index {} must be less than length of Axis({})" , index, axis. index( ) ) ;
3124
3111
let ( _, mut tail) = self . view_mut ( ) . split_at ( axis, index) ;
3125
3112
// shift elements to the front
3126
3113
Zip :: from ( tail. lanes_mut ( axis) ) . for_each ( |mut lane| lane. rotate1_front ( ) ) ;
@@ -3193,15 +3180,16 @@ impl<A, D: Dimension> ArrayRef<A, D>
3193
3180
/// - All elements equal or greater than the k-th element to its right
3194
3181
/// - The ordering within each partition is undefined
3195
3182
///
3183
+ /// Empty arrays (i.e., those with any zero-length axes) are considered partitioned already,
3184
+ /// and will be returned unchanged.
3185
+ ///
3186
+ /// **Panics** if `k` is out of bounds for a non-zero axis length.
3187
+ ///
3196
3188
/// # Parameters
3197
3189
///
3198
3190
/// * `kth` - Index to partition by. The k-th element will be in its sorted position.
3199
3191
/// * `axis` - Axis along which to partition.
3200
3192
///
3201
- /// # Returns
3202
- ///
3203
- /// A new array of the same shape and type as the input array, with elements partitioned.
3204
- ///
3205
3193
/// # Examples
3206
3194
///
3207
3195
/// ```
@@ -3221,19 +3209,19 @@ impl<A, D: Dimension> ArrayRef<A, D>
3221
3209
A : Clone + Ord + num_traits:: Zero ,
3222
3210
D : Dimension ,
3223
3211
{
3224
- // Bounds checking
3225
- let axis_len = self . len_of ( axis) ;
3226
- if kth >= axis_len {
3227
- panic ! ( "partition index {} is out of bounds for axis of length {}" , kth, axis_len) ;
3228
- }
3229
-
3230
3212
let mut result = self . to_owned ( ) ;
3231
3213
3232
- // Must guarantee that the array isn't empty before checking for contiguity
3233
- if result . shape ( ) . iter ( ) . any ( |s| * s == 0 ) {
3214
+ // Return early if the array has zero-length dimensions
3215
+ if self . shape ( ) . iter ( ) . any ( |s| * s == 0 ) {
3234
3216
return result;
3235
3217
}
3236
3218
3219
+ // Bounds checking. Panics if kth is out of bounds
3220
+ let axis_len = self . len_of ( axis) ;
3221
+ if kth >= axis_len {
3222
+ panic ! ( "Partition index {} is out of bounds for axis {} of length {}" , kth, axis. 0 , axis_len) ;
3223
+ }
3224
+
3237
3225
// Check if the first lane is contiguous
3238
3226
let is_contiguous = result
3239
3227
. lanes_mut ( axis)
@@ -3428,11 +3416,7 @@ mod tests
3428
3416
fn test_partition_contiguous_or_not ( )
3429
3417
{
3430
3418
// Test contiguous case (C-order)
3431
- let a = array ! [
3432
- [ 7 , 1 , 5 ] ,
3433
- [ 2 , 6 , 0 ] ,
3434
- [ 3 , 4 , 8 ]
3435
- ] ;
3419
+ let a = array ! [ [ 7 , 1 , 5 ] , [ 2 , 6 , 0 ] , [ 3 , 4 , 8 ] ] ;
3436
3420
3437
3421
// Partition along axis 0 (contiguous)
3438
3422
let p_axis0 = a. partition ( 1 , Axis ( 0 ) ) ;
@@ -3442,20 +3426,24 @@ mod tests
3442
3426
// - Last row should be >= middle row (kth element)
3443
3427
for col in 0 ..3 {
3444
3428
let kth = p_axis0[ [ 1 , col] ] ;
3445
- assert ! ( p_axis0[ [ 0 , col] ] <= kth,
3429
+ assert ! (
3430
+ p_axis0[ [ 0 , col] ] <= kth,
3446
3431
"Column {}: First row {} should be <= middle row {}" ,
3447
- col, p_axis0[ [ 0 , col] ] , kth) ;
3448
- assert ! ( p_axis0[ [ 2 , col] ] >= kth,
3432
+ col,
3433
+ p_axis0[ [ 0 , col] ] ,
3434
+ kth
3435
+ ) ;
3436
+ assert ! (
3437
+ p_axis0[ [ 2 , col] ] >= kth,
3449
3438
"Column {}: Last row {} should be >= middle row {}" ,
3450
- col, p_axis0[ [ 2 , col] ] , kth) ;
3439
+ col,
3440
+ p_axis0[ [ 2 , col] ] ,
3441
+ kth
3442
+ ) ;
3451
3443
}
3452
3444
3453
3445
// Test non-contiguous case (F-order)
3454
- let a = array ! [
3455
- [ 7 , 1 , 5 ] ,
3456
- [ 2 , 6 , 0 ] ,
3457
- [ 3 , 4 , 8 ]
3458
- ] ;
3446
+ let a = array ! [ [ 7 , 1 , 5 ] , [ 2 , 6 , 0 ] , [ 3 , 4 , 8 ] ] ;
3459
3447
3460
3448
// Make array non-contiguous by transposing
3461
3449
let a = a. t ( ) . to_owned ( ) ;
@@ -3467,12 +3455,69 @@ mod tests
3467
3455
// - First column should be <= middle column
3468
3456
// - Last column should be >= middle column
3469
3457
for row in 0 ..3 {
3470
- assert ! ( p_axis1[ [ row, 0 ] ] <= p_axis1[ [ row, 1 ] ] ,
3458
+ assert ! (
3459
+ p_axis1[ [ row, 0 ] ] <= p_axis1[ [ row, 1 ] ] ,
3471
3460
"Row {}: First column {} should be <= middle column {}" ,
3472
- row, p_axis1[ [ row, 0 ] ] , p_axis1[ [ row, 1 ] ] ) ;
3473
- assert ! ( p_axis1[ [ row, 2 ] ] >= p_axis1[ [ row, 1 ] ] ,
3461
+ row,
3462
+ p_axis1[ [ row, 0 ] ] ,
3463
+ p_axis1[ [ row, 1 ] ]
3464
+ ) ;
3465
+ assert ! (
3466
+ p_axis1[ [ row, 2 ] ] >= p_axis1[ [ row, 1 ] ] ,
3474
3467
"Row {}: Last column {} should be >= middle column {}" ,
3475
- row, p_axis1[ [ row, 2 ] ] , p_axis1[ [ row, 1 ] ] ) ;
3468
+ row,
3469
+ p_axis1[ [ row, 2 ] ] ,
3470
+ p_axis1[ [ row, 1 ] ]
3471
+ ) ;
3476
3472
}
3477
3473
}
3474
+
3475
+ #[ test]
3476
+ fn test_partition_empty ( )
3477
+ {
3478
+ // Test 1D empty array
3479
+ let empty1d = Array1 :: < i32 > :: zeros ( 0 ) ;
3480
+ let result1d = empty1d. partition ( 0 , Axis ( 0 ) ) ;
3481
+ assert_eq ! ( result1d. len( ) , 0 ) ;
3482
+
3483
+ // Test 1D empty array with kth out of bounds
3484
+ let result1d_out_of_bounds = empty1d. partition ( 1 , Axis ( 0 ) ) ;
3485
+ assert_eq ! ( result1d_out_of_bounds. len( ) , 0 ) ;
3486
+
3487
+ // Test 2D empty array
3488
+ let empty2d = Array2 :: < i32 > :: zeros ( ( 0 , 3 ) ) ;
3489
+ let result2d = empty2d. partition ( 0 , Axis ( 0 ) ) ;
3490
+ assert_eq ! ( result2d. shape( ) , & [ 0 , 3 ] ) ;
3491
+
3492
+ // Test 2D empty array with zero columns
3493
+ let empty2d_cols = Array2 :: < i32 > :: zeros ( ( 2 , 0 ) ) ;
3494
+ let result2d_cols = empty2d_cols. partition ( 0 , Axis ( 1 ) ) ;
3495
+ assert_eq ! ( result2d_cols. shape( ) , & [ 2 , 0 ] ) ;
3496
+
3497
+ // Test 3D empty array
3498
+ let empty3d = Array3 :: < i32 > :: zeros ( ( 0 , 2 , 3 ) ) ;
3499
+ let result3d = empty3d. partition ( 0 , Axis ( 0 ) ) ;
3500
+ assert_eq ! ( result3d. shape( ) , & [ 0 , 2 , 3 ] ) ;
3501
+
3502
+ // Test 3D empty array with zero in middle dimension
3503
+ let empty3d_mid = Array3 :: < i32 > :: zeros ( ( 2 , 0 , 3 ) ) ;
3504
+ let result3d_mid = empty3d_mid. partition ( 0 , Axis ( 1 ) ) ;
3505
+ assert_eq ! ( result3d_mid. shape( ) , & [ 2 , 0 , 3 ] ) ;
3506
+
3507
+ // Test 4D empty array
3508
+ let empty4d = Array4 :: < i32 > :: zeros ( ( 0 , 2 , 3 , 4 ) ) ;
3509
+ let result4d = empty4d. partition ( 0 , Axis ( 0 ) ) ;
3510
+ assert_eq ! ( result4d. shape( ) , & [ 0 , 2 , 3 , 4 ] ) ;
3511
+
3512
+ // Test empty array with non-zero dimensions in other axes
3513
+ let empty_mixed = Array2 :: < i32 > :: zeros ( ( 0 , 5 ) ) ;
3514
+ let result_mixed = empty_mixed. partition ( 0 , Axis ( 0 ) ) ;
3515
+ assert_eq ! ( result_mixed. shape( ) , & [ 0 , 5 ] ) ;
3516
+
3517
+ // Test empty array with negative strides
3518
+ let arr = Array2 :: < i32 > :: zeros ( ( 3 , 3 ) ) ;
3519
+ let empty_slice = arr. slice ( s ! [ 0 ..0 , ..] ) ;
3520
+ let result_slice = empty_slice. partition ( 0 , Axis ( 0 ) ) ;
3521
+ assert_eq ! ( result_slice. shape( ) , & [ 0 , 3 ] ) ;
3522
+ }
3478
3523
}
0 commit comments