From b800bec4cddde8093af839aa377c111271442ecf Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Fri, 28 Mar 2025 00:46:48 +0900 Subject: [PATCH 1/6] fn partition, first draft --- src/impl_methods.rs | 165 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index d2f04ef1f..dc2889ba0 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -3184,6 +3184,92 @@ impl ArrayRef f(&*prev, &mut *curr) }); } + + /// Return a partitioned copy of the array. + /// + /// Creates a copy of the array and partially sorts it around the k-th element along the given axis. + /// The k-th element will be in its sorted position, with: + /// - All elements smaller than the k-th element to its left + /// - All elements equal or greater than the k-th element to its right + /// - The ordering within each partition is undefined + /// + /// # Parameters + /// + /// * `kth` - Index to partition by. The k-th element will be in its sorted position. + /// * `axis` - Axis along which to partition. Default is the last axis (`Axis(ndim-1)`). + /// + /// # Returns + /// + /// A new array of the same shape and type as the input array, with elements partitioned. 
+ /// + /// # Examples + /// + /// ``` + /// use ndarray::prelude::*; + /// + /// let a = array![7, 1, 5, 2, 6, 0, 3, 4]; + /// let p = a.partition(3, Axis(0)); + /// + /// // The element at position 3 is now 3, with smaller elements to the left + /// // and greater elements to the right + /// assert_eq!(p[3], 3); + /// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3)); + /// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3)); + /// ``` + pub fn partition(&self, kth: usize, axis: Axis) -> Array + where A: Clone + Ord + { + // Check if axis is valid + if axis.index() >= self.ndim() { + panic!("axis {} is out of bounds for array of dimension {}", axis.index(), self.ndim()); + } + + // Check if kth is valid + if kth >= self.len_of(axis) { + panic!("kth {} is out of bounds for axis {} with length {}", kth, axis.index(), self.len_of(axis)); + } + + // If the array is empty, return a copy + if self.is_empty() { + return self.to_owned(); + } + + // If the array is 1D, handle as a special case + if self.ndim() == 1 { + let mut result = self.to_owned(); + if let Some(slice) = result.as_slice_mut() { + slice.select_nth_unstable(kth); + } + return result; + } + + // For multi-dimensional arrays, partition along the specified axis + let mut result = self.to_owned(); + + // Use Zip to efficiently iterate over the lanes + Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { + // For each lane, perform the partitioning operation + if let Some(slice) = lane.as_slice_mut() { + // If the lane's memory is contiguous, use select_nth_unstable directly + slice.select_nth_unstable(kth); + } else { + // For non-contiguous memory, create a temporary array with contiguous memory + let mut temp_arr = Array::from_iter(lane.iter().cloned()); + + // Partition the temporary array + if let Some(slice) = temp_arr.as_slice_mut() { + slice.select_nth_unstable(kth); + } + + // Copy values back to original lane + Zip::from(&mut lane).and(&temp_arr).for_each(|dest, src| { + *dest = src.clone(); + }); 
+ } + }); + + result + } } /// Transmute from A to B. @@ -3277,4 +3363,83 @@ mod tests let _a2 = a.clone(); assert_first!(a); } + + #[test] + fn test_partition_1d() + { + let a = array![7, 1, 5, 2, 6, 0, 3, 4]; + let kth = 3; + let p = a.partition(kth, Axis(0)); + + // The element at position kth is in its sorted position + assert_eq!(p[kth], 3); + + // All elements to the left are less than or equal to the kth element + for i in 0..kth { + assert!(p[i] <= p[kth]); + } + + // All elements to the right are greater than or equal to the kth element + for i in (kth + 1)..p.len() { + assert!(p[i] >= p[kth]); + } + } + + #[test] + fn test_partition_2d() + { + let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]]; + + // Partition along axis 0 (rows) + let p_axis0 = a.partition(1, Axis(0)); + + // For each column, the middle row should be in its sorted position + for col in 0..3 { + assert!(p_axis0[[0, col]] <= p_axis0[[1, col]]); + assert!(p_axis0[[2, col]] >= p_axis0[[1, col]]); + } + + // Partition along axis 1 (columns) + let p_axis1 = a.partition(1, Axis(1)); + + // For each row, the middle column should be in its sorted position + for row in 0..3 { + assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]]); + assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]]); + } + } + + #[test] + fn test_partition_3d() + { + let a = arr3(&[[[9, 2], [3, 4]], [[5, 6], [7, 8]]]); + + // Partition along the last axis + let p = a.partition(0, Axis(2)); + + // Check the partitioning along the last axis + for i in 0..2 { + for j in 0..2 { + assert!(p[[i, j, 0]] <= p[[i, j, 1]]); + } + } + } + + #[test] + #[should_panic] + fn test_partition_invalid_kth() + { + let a = array![1, 2, 3, 4]; + // This should panic because kth=4 is out of bounds + let _ = a.partition(4, Axis(0)); + } + + #[test] + #[should_panic] + fn test_partition_invalid_axis() + { + let a = array![1, 2, 3, 4]; + // This should panic because axis=1 is out of bounds for a 1D array + let _ = a.partition(0, Axis(1)); + } } From 
df5917908c6ae0f3a6fb8b50b1dd0e0dffa3137f Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Fri, 28 Mar 2025 11:08:43 +0900 Subject: [PATCH 2/6] handle non-contiguous with create and copy back --- src/impl_methods.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index dc2889ba0..106b3add5 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -3196,7 +3196,7 @@ impl ArrayRef /// # Parameters /// /// * `kth` - Index to partition by. The k-th element will be in its sorted position. - /// * `axis` - Axis along which to partition. Default is the last axis (`Axis(ndim-1)`). + /// * `axis` - Axis along which to partition. /// /// # Returns /// @@ -3246,23 +3246,21 @@ impl ArrayRef // For multi-dimensional arrays, partition along the specified axis let mut result = self.to_owned(); - // Use Zip to efficiently iterate over the lanes + // Process each lane with partitioning Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { - // For each lane, perform the partitioning operation + // For each lane, we need a contiguous slice to partition if let Some(slice) = lane.as_slice_mut() { // If the lane's memory is contiguous, use select_nth_unstable directly slice.select_nth_unstable(kth); } else { - // For non-contiguous memory, create a temporary array with contiguous memory - let mut temp_arr = Array::from_iter(lane.iter().cloned()); + // For non-contiguous memory, create a temporary vector + let mut values = lane.iter().cloned().collect::>(); - // Partition the temporary array - if let Some(slice) = temp_arr.as_slice_mut() { - slice.select_nth_unstable(kth); - } + // Partition the vector + values.select_nth_unstable(kth); - // Copy values back to original lane - Zip::from(&mut lane).and(&temp_arr).for_each(|dest, src| { + // Copy values back to the lane + Zip::from(&mut lane).and(&values).for_each(|dest, src| { *dest = src.clone(); }); } From 
8f721177b75b76ec9deee2dd063cf5e5eefcbc58 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Thu, 3 Apr 2025 22:39:41 +0900 Subject: [PATCH 3/6] include review: not to allocate n-vector, reuse a single vector --- src/impl_methods.rs | 210 ++++++++++++++++++++++++-------------------- 1 file changed, 117 insertions(+), 93 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 106b3add5..31809b97d 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -3217,55 +3217,44 @@ impl ArrayRef /// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3)); /// ``` pub fn partition(&self, kth: usize, axis: Axis) -> Array - where A: Clone + Ord + where + A: Clone + Ord, + D: Dimension, { - // Check if axis is valid - if axis.index() >= self.ndim() { - panic!("axis {} is out of bounds for array of dimension {}", axis.index(), self.ndim()); - } - - // Check if kth is valid - if kth >= self.len_of(axis) { - panic!("kth {} is out of bounds for axis {} with length {}", kth, axis.index(), self.len_of(axis)); - } - - // If the array is empty, return a copy - if self.is_empty() { - return self.to_owned(); + // Bounds checking + let axis_len = self.len_of(axis); + if kth >= axis_len { + panic!("partition index {} is out of bounds for axis of length {}", kth, axis_len); } - // If the array is 1D, handle as a special case - if self.ndim() == 1 { - let mut result = self.to_owned(); - if let Some(slice) = result.as_slice_mut() { - slice.select_nth_unstable(kth); - } - return result; - } - - // For multi-dimensional arrays, partition along the specified axis let mut result = self.to_owned(); - - // Process each lane with partitioning - Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { - // For each lane, we need a contiguous slice to partition - if let Some(slice) = lane.as_slice_mut() { - // If the lane's memory is contiguous, use select_nth_unstable directly - slice.select_nth_unstable(kth); - } else { - // For non-contiguous memory, create a temporary vector - let 
mut values = lane.iter().cloned().collect::>(); - - // Partition the vector - values.select_nth_unstable(kth); - - // Copy values back to the lane - Zip::from(&mut lane).and(&values).for_each(|dest, src| { + + // Check if the first lane is contiguous + let is_contiguous = result.lanes_mut(axis) + .into_iter() + .next() + .map(|lane| lane.is_contiguous()) + .unwrap_or(false); + + if is_contiguous { + Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { + lane.as_slice_mut().unwrap().select_nth_unstable(kth); + }); + } else { + let mut temp_vec = Vec::with_capacity(axis_len); + + Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { + temp_vec.clear(); + temp_vec.extend(lane.iter().cloned()); + + temp_vec.select_nth_unstable(kth); + + Zip::from(&mut lane).and(&temp_vec).for_each(|dest, src| { *dest = src.clone(); }); - } - }); - + }); + } + result } } @@ -3363,64 +3352,47 @@ mod tests } #[test] - fn test_partition_1d() - { - let a = array![7, 1, 5, 2, 6, 0, 3, 4]; - let kth = 3; - let p = a.partition(kth, Axis(0)); - - // The element at position kth is in its sorted position - assert_eq!(p[kth], 3); - - // All elements to the left are less than or equal to the kth element - for i in 0..kth { - assert!(p[i] <= p[kth]); - } - - // All elements to the right are greater than or equal to the kth element - for i in (kth + 1)..p.len() { - assert!(p[i] >= p[kth]); - } + fn test_partition_1d() { + // Test partitioning a 1D array + let array = arr1(&[3, 1, 4, 1, 5, 9, 2, 6]); + let result = array.partition(3, Axis(0)); + // After partitioning, the element at index 3 should be in its final sorted position + assert!(result.slice(s![..3]).iter().all(|&x| x <= result[3])); + assert!(result.slice(s![4..]).iter().all(|&x| x >= result[3])); } #[test] - fn test_partition_2d() - { - let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]]; - + fn test_partition_2d() { + // Test partitioning a 2D array along both axes + let array = arr2(&[[3, 1, 4], [1, 5, 9], [2, 6, 5]]); + // Partition 
along axis 0 (rows) - let p_axis0 = a.partition(1, Axis(0)); - - // For each column, the middle row should be in its sorted position - for col in 0..3 { - assert!(p_axis0[[0, col]] <= p_axis0[[1, col]]); - assert!(p_axis0[[2, col]] >= p_axis0[[1, col]]); - } - + let result0 = array.partition(1, Axis(0)); + // After partitioning along axis 0, each column should have its middle element in the correct position + assert!(result0[[0, 0]] <= result0[[1, 0]] && result0[[2, 0]] >= result0[[1, 0]]); + assert!(result0[[0, 1]] <= result0[[1, 1]] && result0[[2, 1]] >= result0[[1, 1]]); + assert!(result0[[0, 2]] <= result0[[1, 2]] && result0[[2, 2]] >= result0[[1, 2]]); + // Partition along axis 1 (columns) - let p_axis1 = a.partition(1, Axis(1)); - - // For each row, the middle column should be in its sorted position - for row in 0..3 { - assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]]); - assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]]); - } + let result1 = array.partition(1, Axis(1)); + // After partitioning along axis 1, each row should have its middle element in the correct position + assert!(result1[[0, 0]] <= result1[[0, 1]] && result1[[0, 2]] >= result1[[0, 1]]); + assert!(result1[[1, 0]] <= result1[[1, 1]] && result1[[1, 2]] >= result1[[1, 1]]); + assert!(result1[[2, 0]] <= result1[[2, 1]] && result1[[2, 2]] >= result1[[2, 1]]); } #[test] - fn test_partition_3d() - { - let a = arr3(&[[[9, 2], [3, 4]], [[5, 6], [7, 8]]]); - - // Partition along the last axis - let p = a.partition(0, Axis(2)); - - // Check the partitioning along the last axis - for i in 0..2 { - for j in 0..2 { - assert!(p[[i, j, 0]] <= p[[i, j, 1]]); - } - } + fn test_partition_3d() { + // Test partitioning a 3D array + let array = arr3(&[[[3, 1], [4, 1]], [[5, 9], [2, 6]]]); + + // Partition along axis 0 + let result = array.partition(0, Axis(0)); + // After partitioning, each 2x2 slice should have its first element in the correct position + assert!(result[[0, 0, 0]] <= result[[1, 0, 0]]); + 
assert!(result[[0, 0, 1]] <= result[[1, 0, 1]]); + assert!(result[[0, 1, 0]] <= result[[1, 1, 0]]); + assert!(result[[0, 1, 1]] <= result[[1, 1, 1]]); } #[test] @@ -3440,4 +3412,56 @@ mod tests // This should panic because axis=1 is out of bounds for a 1D array let _ = a.partition(0, Axis(1)); } + + #[test] + fn test_partition_contiguous_or_not() + { + // Test contiguous case (C-order) + let a = array![ + [7, 1, 5], + [2, 6, 0], + [3, 4, 8] + ]; + + // Partition along axis 0 (contiguous) + let p_axis0 = a.partition(1, Axis(0)); + + // For each column, verify the partitioning: + // - First row should be <= middle row (kth element) + // - Last row should be >= middle row (kth element) + for col in 0..3 { + let kth = p_axis0[[1, col]]; + assert!(p_axis0[[0, col]] <= kth, + "Column {}: First row {} should be <= middle row {}", + col, p_axis0[[0, col]], kth); + assert!(p_axis0[[2, col]] >= kth, + "Column {}: Last row {} should be >= middle row {}", + col, p_axis0[[2, col]], kth); + } + + // Test non-contiguous case (F-order) + let a = array![ + [7, 1, 5], + [2, 6, 0], + [3, 4, 8] + ]; + + // Make array non-contiguous by transposing + let a = a.t().to_owned(); + + // Partition along axis 1 (non-contiguous) + let p_axis1 = a.partition(1, Axis(1)); + + // For each row, verify the partitioning: + // - First column should be <= middle column + // - Last column should be >= middle column + for row in 0..3 { + assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]], + "Row {}: First column {} should be <= middle column {}", + row, p_axis1[[row, 0]], p_axis1[[row, 1]]); + assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]], + "Row {}: Last column {} should be >= middle column {}", + row, p_axis1[[row, 2]], p_axis1[[row, 1]]); + } + } } From a0df0660cd3d682dc983ba0ca9b47e6adc47b644 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Thu, 3 Apr 2025 22:41:30 +0900 Subject: [PATCH 4/6] format nightly --- src/impl_methods.rs | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 
13 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 31809b97d..b289d4720 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -3228,33 +3228,34 @@ impl ArrayRef } let mut result = self.to_owned(); - + // Check if the first lane is contiguous - let is_contiguous = result.lanes_mut(axis) + let is_contiguous = result + .lanes_mut(axis) .into_iter() .next() .map(|lane| lane.is_contiguous()) .unwrap_or(false); - + if is_contiguous { Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { lane.as_slice_mut().unwrap().select_nth_unstable(kth); }); } else { let mut temp_vec = Vec::with_capacity(axis_len); - + Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { temp_vec.clear(); temp_vec.extend(lane.iter().cloned()); - + temp_vec.select_nth_unstable(kth); - + Zip::from(&mut lane).and(&temp_vec).for_each(|dest, src| { *dest = src.clone(); }); }); } - + result } } @@ -3352,7 +3353,8 @@ mod tests } #[test] - fn test_partition_1d() { + fn test_partition_1d() + { // Test partitioning a 1D array let array = arr1(&[3, 1, 4, 1, 5, 9, 2, 6]); let result = array.partition(3, Axis(0)); @@ -3362,17 +3364,18 @@ mod tests } #[test] - fn test_partition_2d() { + fn test_partition_2d() + { // Test partitioning a 2D array along both axes let array = arr2(&[[3, 1, 4], [1, 5, 9], [2, 6, 5]]); - + // Partition along axis 0 (rows) let result0 = array.partition(1, Axis(0)); // After partitioning along axis 0, each column should have its middle element in the correct position assert!(result0[[0, 0]] <= result0[[1, 0]] && result0[[2, 0]] >= result0[[1, 0]]); assert!(result0[[0, 1]] <= result0[[1, 1]] && result0[[2, 1]] >= result0[[1, 1]]); assert!(result0[[0, 2]] <= result0[[1, 2]] && result0[[2, 2]] >= result0[[1, 2]]); - + // Partition along axis 1 (columns) let result1 = array.partition(1, Axis(1)); // After partitioning along axis 1, each row should have its middle element in the correct position @@ -3382,10 +3385,11 @@ mod tests } #[test] - fn 
test_partition_3d() { + fn test_partition_3d() + { // Test partitioning a 3D array let array = arr3(&[[[3, 1], [4, 1]], [[5, 9], [2, 6]]]); - + // Partition along axis 0 let result = array.partition(0, Axis(0)); // After partitioning, each 2x2 slice should have its first element in the correct position From 3344cf8b73fe8c5760d3c0cbdad31f5b52078722 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Sat, 5 Apr 2025 08:22:59 +0900 Subject: [PATCH 5/6] applied feedback: safe unwrap(), avoid naive iter() --- src/impl_methods.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index b289d4720..3d5be43c6 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -3218,7 +3218,7 @@ impl ArrayRef /// ``` pub fn partition(&self, kth: usize, axis: Axis) -> Array where - A: Clone + Ord, + A: Clone + Ord + num_traits::Zero, D: Dimension, { // Bounds checking @@ -3234,8 +3234,8 @@ impl ArrayRef .lanes_mut(axis) .into_iter() .next() - .map(|lane| lane.is_contiguous()) - .unwrap_or(false); + .unwrap() + .is_contiguous(); if is_contiguous { Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { @@ -3246,7 +3246,11 @@ impl ArrayRef Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { temp_vec.clear(); - temp_vec.extend(lane.iter().cloned()); + temp_vec.resize(axis_len, A::zero()); + + Zip::from(&mut temp_vec).and(&lane).for_each(|dest, src| { + *dest = src.clone(); + }); temp_vec.select_nth_unstable(kth); From 05dee0e2deab58c5a8e69f6946651273e38160e3 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Sat, 5 Apr 2025 10:29:46 +0900 Subject: [PATCH 6/6] initialize temp vector in advance --- src/impl_methods.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 3d5be43c6..42d843781 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -3242,12 +3242,9 @@ impl ArrayRef lane.as_slice_mut().unwrap().select_nth_unstable(kth); }); } 
else { - let mut temp_vec = Vec::with_capacity(axis_len); + let mut temp_vec = vec![A::zero(); axis_len]; Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { - temp_vec.clear(); - temp_vec.resize(axis_len, A::zero()); - Zip::from(&mut temp_vec).and(&lane).for_each(|dest, src| { *dest = src.clone(); });