@@ -3184,6 +3184,92 @@ impl<A, D: Dimension> ArrayRef<A, D>
31843184 f ( & * prev, & mut * curr)
31853185 } ) ;
31863186 }
3187+
3188+ /// Return a partitioned copy of the array.
3189+ ///
3190+ /// Creates a copy of the array and partially sorts it around the k-th element along the given axis.
3191+ /// The k-th element will be in its sorted position, with:
3192+ /// - All elements smaller than the k-th element to its left
3193+ /// - All elements equal or greater than the k-th element to its right
3194+ /// - The ordering within each partition is undefined
3195+ ///
3196+ /// # Parameters
3197+ ///
3198+ /// * `kth` - Index to partition by. The k-th element will be in its sorted position.
3199+ /// * `axis` - Axis along which to partition. Default is the last axis (`Axis(ndim-1)`).
3200+ ///
3201+ /// # Returns
3202+ ///
3203+ /// A new array of the same shape and type as the input array, with elements partitioned.
3204+ ///
3205+ /// # Examples
3206+ ///
3207+ /// ```
3208+ /// use ndarray::prelude::*;
3209+ ///
3210+ /// let a = array![7, 1, 5, 2, 6, 0, 3, 4];
3211+ /// let p = a.partition(3, Axis(0));
3212+ ///
3213+ /// // The element at position 3 is now 3, with smaller elements to the left
3214+ /// // and greater elements to the right
3215+ /// assert_eq!(p[3], 3);
3216+ /// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3));
3217+ /// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3));
3218+ /// ```
3219+ pub fn partition ( & self , kth : usize , axis : Axis ) -> Array < A , D >
3220+ where A : Clone + Ord
3221+ {
3222+ // Check if axis is valid
3223+ if axis. index ( ) >= self . ndim ( ) {
3224+ panic ! ( "axis {} is out of bounds for array of dimension {}" , axis. index( ) , self . ndim( ) ) ;
3225+ }
3226+
3227+ // Check if kth is valid
3228+ if kth >= self . len_of ( axis) {
3229+ panic ! ( "kth {} is out of bounds for axis {} with length {}" , kth, axis. index( ) , self . len_of( axis) ) ;
3230+ }
3231+
3232+ // If the array is empty, return a copy
3233+ if self . is_empty ( ) {
3234+ return self . to_owned ( ) ;
3235+ }
3236+
3237+ // If the array is 1D, handle as a special case
3238+ if self . ndim ( ) == 1 {
3239+ let mut result = self . to_owned ( ) ;
3240+ if let Some ( slice) = result. as_slice_mut ( ) {
3241+ slice. select_nth_unstable ( kth) ;
3242+ }
3243+ return result;
3244+ }
3245+
3246+ // For multi-dimensional arrays, partition along the specified axis
3247+ let mut result = self . to_owned ( ) ;
3248+
3249+ // Use Zip to efficiently iterate over the lanes
3250+ Zip :: from ( result. lanes_mut ( axis) ) . for_each ( |mut lane| {
3251+ // For each lane, perform the partitioning operation
3252+ if let Some ( slice) = lane. as_slice_mut ( ) {
3253+ // If the lane's memory is contiguous, use select_nth_unstable directly
3254+ slice. select_nth_unstable ( kth) ;
3255+ } else {
3256+ // For non-contiguous memory, create a temporary array with contiguous memory
3257+ let mut temp_arr = Array :: from_iter ( lane. iter ( ) . cloned ( ) ) ;
3258+
3259+ // Partition the temporary array
3260+ if let Some ( slice) = temp_arr. as_slice_mut ( ) {
3261+ slice. select_nth_unstable ( kth) ;
3262+ }
3263+
3264+ // Copy values back to original lane
3265+ Zip :: from ( & mut lane) . and ( & temp_arr) . for_each ( |dest, src| {
3266+ * dest = src. clone ( ) ;
3267+ } ) ;
3268+ }
3269+ } ) ;
3270+
3271+ result
3272+ }
31873273}
31883274
31893275/// Transmute from A to B.
@@ -3277,4 +3363,83 @@ mod tests
32773363 let _a2 = a. clone ( ) ;
32783364 assert_first ! ( a) ;
32793365 }
3366+
3367+ #[ test]
3368+ fn test_partition_1d ( )
3369+ {
3370+ let a = array ! [ 7 , 1 , 5 , 2 , 6 , 0 , 3 , 4 ] ;
3371+ let kth = 3 ;
3372+ let p = a. partition ( kth, Axis ( 0 ) ) ;
3373+
3374+ // The element at position kth is in its sorted position
3375+ assert_eq ! ( p[ kth] , 3 ) ;
3376+
3377+ // All elements to the left are less than or equal to the kth element
3378+ for i in 0 ..kth {
3379+ assert ! ( p[ i] <= p[ kth] ) ;
3380+ }
3381+
3382+ // All elements to the right are greater than or equal to the kth element
3383+ for i in ( kth + 1 ) ..p. len ( ) {
3384+ assert ! ( p[ i] >= p[ kth] ) ;
3385+ }
3386+ }
3387+
3388+ #[ test]
3389+ fn test_partition_2d ( )
3390+ {
3391+ let a = array ! [ [ 7 , 1 , 5 ] , [ 2 , 6 , 0 ] , [ 3 , 4 , 8 ] ] ;
3392+
3393+ // Partition along axis 0 (rows)
3394+ let p_axis0 = a. partition ( 1 , Axis ( 0 ) ) ;
3395+
3396+ // For each column, the middle row should be in its sorted position
3397+ for col in 0 ..3 {
3398+ assert ! ( p_axis0[ [ 0 , col] ] <= p_axis0[ [ 1 , col] ] ) ;
3399+ assert ! ( p_axis0[ [ 2 , col] ] >= p_axis0[ [ 1 , col] ] ) ;
3400+ }
3401+
3402+ // Partition along axis 1 (columns)
3403+ let p_axis1 = a. partition ( 1 , Axis ( 1 ) ) ;
3404+
3405+ // For each row, the middle column should be in its sorted position
3406+ for row in 0 ..3 {
3407+ assert ! ( p_axis1[ [ row, 0 ] ] <= p_axis1[ [ row, 1 ] ] ) ;
3408+ assert ! ( p_axis1[ [ row, 2 ] ] >= p_axis1[ [ row, 1 ] ] ) ;
3409+ }
3410+ }
3411+
3412+ #[ test]
3413+ fn test_partition_3d ( )
3414+ {
3415+ let a = arr3 ( & [ [ [ 9 , 2 ] , [ 3 , 4 ] ] , [ [ 5 , 6 ] , [ 7 , 8 ] ] ] ) ;
3416+
3417+ // Partition along the last axis
3418+ let p = a. partition ( 0 , Axis ( 2 ) ) ;
3419+
3420+ // Check the partitioning along the last axis
3421+ for i in 0 ..2 {
3422+ for j in 0 ..2 {
3423+ assert ! ( p[ [ i, j, 0 ] ] <= p[ [ i, j, 1 ] ] ) ;
3424+ }
3425+ }
3426+ }
3427+
3428+ #[ test]
3429+ #[ should_panic]
3430+ fn test_partition_invalid_kth ( )
3431+ {
3432+ let a = array ! [ 1 , 2 , 3 , 4 ] ;
3433+ // This should panic because kth=4 is out of bounds
3434+ let _ = a. partition ( 4 , Axis ( 0 ) ) ;
3435+ }
3436+
3437+ #[ test]
3438+ #[ should_panic]
3439+ fn test_partition_invalid_axis ( )
3440+ {
3441+ let a = array ! [ 1 , 2 , 3 , 4 ] ;
3442+ // This should panic because axis=1 is out of bounds for a 1D array
3443+ let _ = a. partition ( 0 , Axis ( 1 ) ) ;
3444+ }
32803445}
0 commit comments