@@ -3217,55 +3217,44 @@ impl<A, D: Dimension> ArrayRef<A, D>
32173217 /// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3));
32183218 /// ```
32193219 pub fn partition ( & self , kth : usize , axis : Axis ) -> Array < A , D >
3220- where A : Clone + Ord
3220+ where
3221+ A : Clone + Ord ,
3222+ D : Dimension ,
32213223 {
3222- // Check if axis is valid
3223- if axis. index ( ) >= self . ndim ( ) {
3224- panic ! ( "axis {} is out of bounds for array of dimension {}" , axis. index( ) , self . ndim( ) ) ;
3225- }
3226-
3227- // Check if kth is valid
3228- if kth >= self . len_of ( axis) {
3229- panic ! ( "kth {} is out of bounds for axis {} with length {}" , kth, axis. index( ) , self . len_of( axis) ) ;
3230- }
3231-
3232- // If the array is empty, return a copy
3233- if self . is_empty ( ) {
3234- return self . to_owned ( ) ;
3224+ // Bounds checking
3225+ let axis_len = self . len_of ( axis) ;
3226+ if kth >= axis_len {
3227+ panic ! ( "partition index {} is out of bounds for axis of length {}" , kth, axis_len) ;
32353228 }
32363229
3237- // If the array is 1D, handle as a special case
3238- if self . ndim ( ) == 1 {
3239- let mut result = self . to_owned ( ) ;
3240- if let Some ( slice) = result. as_slice_mut ( ) {
3241- slice. select_nth_unstable ( kth) ;
3242- }
3243- return result;
3244- }
3245-
3246- // For multi-dimensional arrays, partition along the specified axis
32473230 let mut result = self . to_owned ( ) ;
3248-
3249- // Process each lane with partitioning
3250- Zip :: from ( result. lanes_mut ( axis) ) . for_each ( |mut lane| {
3251- // For each lane, we need a contiguous slice to partition
3252- if let Some ( slice) = lane. as_slice_mut ( ) {
3253- // If the lane's memory is contiguous, use select_nth_unstable directly
3254- slice. select_nth_unstable ( kth) ;
3255- } else {
3256- // For non-contiguous memory, create a temporary vector
3257- let mut values = lane. iter ( ) . cloned ( ) . collect :: < Vec < _ > > ( ) ;
3258-
3259- // Partition the vector
3260- values. select_nth_unstable ( kth) ;
3261-
3262- // Copy values back to the lane
3263- Zip :: from ( & mut lane) . and ( & values) . for_each ( |dest, src| {
3231+
3232+ // Check if the first lane is contiguous
3233+ let is_contiguous = result. lanes_mut ( axis)
3234+ . into_iter ( )
3235+ . next ( )
3236+ . map ( |lane| lane. is_contiguous ( ) )
3237+ . unwrap_or ( false ) ;
3238+
3239+ if is_contiguous {
3240+ Zip :: from ( result. lanes_mut ( axis) ) . for_each ( |mut lane| {
3241+ lane. as_slice_mut ( ) . unwrap ( ) . select_nth_unstable ( kth) ;
3242+ } ) ;
3243+ } else {
3244+ let mut temp_vec = Vec :: with_capacity ( axis_len) ;
3245+
3246+ Zip :: from ( result. lanes_mut ( axis) ) . for_each ( |mut lane| {
3247+ temp_vec. clear ( ) ;
3248+ temp_vec. extend ( lane. iter ( ) . cloned ( ) ) ;
3249+
3250+ temp_vec. select_nth_unstable ( kth) ;
3251+
3252+ Zip :: from ( & mut lane) . and ( & temp_vec) . for_each ( |dest, src| {
32643253 * dest = src. clone ( ) ;
32653254 } ) ;
3266- }
3267- } ) ;
3268-
3255+ } ) ;
3256+ }
3257+
32693258 result
32703259 }
32713260}
@@ -3363,64 +3352,47 @@ mod tests
33633352 }
33643353
33653354 #[ test]
3366- fn test_partition_1d ( )
3367- {
3368- let a = array ! [ 7 , 1 , 5 , 2 , 6 , 0 , 3 , 4 ] ;
3369- let kth = 3 ;
3370- let p = a. partition ( kth, Axis ( 0 ) ) ;
3371-
3372- // The element at position kth is in its sorted position
3373- assert_eq ! ( p[ kth] , 3 ) ;
3374-
3375- // All elements to the left are less than or equal to the kth element
3376- for i in 0 ..kth {
3377- assert ! ( p[ i] <= p[ kth] ) ;
3378- }
3379-
3380- // All elements to the right are greater than or equal to the kth element
3381- for i in ( kth + 1 ) ..p. len ( ) {
3382- assert ! ( p[ i] >= p[ kth] ) ;
3383- }
3355+ fn test_partition_1d ( ) {
3356+ // Test partitioning a 1D array
3357+ let array = arr1 ( & [ 3 , 1 , 4 , 1 , 5 , 9 , 2 , 6 ] ) ;
3358+ let result = array. partition ( 3 , Axis ( 0 ) ) ;
3359+ // After partitioning, the element at index 3 should be in its final sorted position
3360+ assert ! ( result. slice( s![ ..3 ] ) . iter( ) . all( |& x| x <= result[ 3 ] ) ) ;
3361+ assert ! ( result. slice( s![ 4 ..] ) . iter( ) . all( |& x| x >= result[ 3 ] ) ) ;
33843362 }
33853363
33863364 #[ test]
3387- fn test_partition_2d ( )
3388- {
3389- let a = array ! [ [ 7 , 1 , 5 ] , [ 2 , 6 , 0 ] , [ 3 , 4 , 8 ] ] ;
3390-
3365+ fn test_partition_2d ( ) {
3366+ // Test partitioning a 2D array along both axes
3367+ let array = arr2 ( & [ [ 3 , 1 , 4 ] , [ 1 , 5 , 9 ] , [ 2 , 6 , 5 ] ] ) ;
3368+
33913369 // Partition along axis 0 (rows)
3392- let p_axis0 = a. partition ( 1 , Axis ( 0 ) ) ;
3393-
3394- // For each column, the middle row should be in its sorted position
3395- for col in 0 ..3 {
3396- assert ! ( p_axis0[ [ 0 , col] ] <= p_axis0[ [ 1 , col] ] ) ;
3397- assert ! ( p_axis0[ [ 2 , col] ] >= p_axis0[ [ 1 , col] ] ) ;
3398- }
3399-
3370+ let result0 = array. partition ( 1 , Axis ( 0 ) ) ;
3371+ // After partitioning along axis 0, each column should have its middle element in the correct position
3372+ assert ! ( result0[ [ 0 , 0 ] ] <= result0[ [ 1 , 0 ] ] && result0[ [ 2 , 0 ] ] >= result0[ [ 1 , 0 ] ] ) ;
3373+ assert ! ( result0[ [ 0 , 1 ] ] <= result0[ [ 1 , 1 ] ] && result0[ [ 2 , 1 ] ] >= result0[ [ 1 , 1 ] ] ) ;
3374+ assert ! ( result0[ [ 0 , 2 ] ] <= result0[ [ 1 , 2 ] ] && result0[ [ 2 , 2 ] ] >= result0[ [ 1 , 2 ] ] ) ;
3375+
34003376 // Partition along axis 1 (columns)
3401- let p_axis1 = a. partition ( 1 , Axis ( 1 ) ) ;
3402-
3403- // For each row, the middle column should be in its sorted position
3404- for row in 0 ..3 {
3405- assert ! ( p_axis1[ [ row, 0 ] ] <= p_axis1[ [ row, 1 ] ] ) ;
3406- assert ! ( p_axis1[ [ row, 2 ] ] >= p_axis1[ [ row, 1 ] ] ) ;
3407- }
3377+ let result1 = array. partition ( 1 , Axis ( 1 ) ) ;
3378+ // After partitioning along axis 1, each row should have its middle element in the correct position
3379+ assert ! ( result1[ [ 0 , 0 ] ] <= result1[ [ 0 , 1 ] ] && result1[ [ 0 , 2 ] ] >= result1[ [ 0 , 1 ] ] ) ;
3380+ assert ! ( result1[ [ 1 , 0 ] ] <= result1[ [ 1 , 1 ] ] && result1[ [ 1 , 2 ] ] >= result1[ [ 1 , 1 ] ] ) ;
3381+ assert ! ( result1[ [ 2 , 0 ] ] <= result1[ [ 2 , 1 ] ] && result1[ [ 2 , 2 ] ] >= result1[ [ 2 , 1 ] ] ) ;
34083382 }
34093383
34103384 #[ test]
3411- fn test_partition_3d ( )
3412- {
3413- let a = arr3 ( & [ [ [ 9 , 2 ] , [ 3 , 4 ] ] , [ [ 5 , 6 ] , [ 7 , 8 ] ] ] ) ;
3414-
3415- // Partition along the last axis
3416- let p = a. partition ( 0 , Axis ( 2 ) ) ;
3417-
3418- // Check the partitioning along the last axis
3419- for i in 0 ..2 {
3420- for j in 0 ..2 {
3421- assert ! ( p[ [ i, j, 0 ] ] <= p[ [ i, j, 1 ] ] ) ;
3422- }
3423- }
3385+ fn test_partition_3d ( ) {
3386+ // Test partitioning a 3D array
3387+ let array = arr3 ( & [ [ [ 3 , 1 ] , [ 4 , 1 ] ] , [ [ 5 , 9 ] , [ 2 , 6 ] ] ] ) ;
3388+
3389+ // Partition along axis 0
3390+ let result = array. partition ( 0 , Axis ( 0 ) ) ;
3391+ // After partitioning, each 2x2 slice should have its first element in the correct position
3392+ assert ! ( result[ [ 0 , 0 , 0 ] ] <= result[ [ 1 , 0 , 0 ] ] ) ;
3393+ assert ! ( result[ [ 0 , 0 , 1 ] ] <= result[ [ 1 , 0 , 1 ] ] ) ;
3394+ assert ! ( result[ [ 0 , 1 , 0 ] ] <= result[ [ 1 , 1 , 0 ] ] ) ;
3395+ assert ! ( result[ [ 0 , 1 , 1 ] ] <= result[ [ 1 , 1 , 1 ] ] ) ;
34243396 }
34253397
34263398 #[ test]
@@ -3440,4 +3412,56 @@ mod tests
34403412 // This should panic because axis=1 is out of bounds for a 1D array
34413413 let _ = a. partition ( 0 , Axis ( 1 ) ) ;
34423414 }
3415+
3416+ #[ test]
3417+ fn test_partition_contiguous_or_not ( )
3418+ {
3419+ // Test contiguous case (C-order)
3420+ let a = array ! [
3421+ [ 7 , 1 , 5 ] ,
3422+ [ 2 , 6 , 0 ] ,
3423+ [ 3 , 4 , 8 ]
3424+ ] ;
3425+
3426+ // Partition along axis 0 (contiguous)
3427+ let p_axis0 = a. partition ( 1 , Axis ( 0 ) ) ;
3428+
3429+ // For each column, verify the partitioning:
3430+ // - First row should be <= middle row (kth element)
3431+ // - Last row should be >= middle row (kth element)
3432+ for col in 0 ..3 {
3433+ let kth = p_axis0[ [ 1 , col] ] ;
3434+ assert ! ( p_axis0[ [ 0 , col] ] <= kth,
3435+ "Column {}: First row {} should be <= middle row {}" ,
3436+ col, p_axis0[ [ 0 , col] ] , kth) ;
3437+ assert ! ( p_axis0[ [ 2 , col] ] >= kth,
3438+ "Column {}: Last row {} should be >= middle row {}" ,
3439+ col, p_axis0[ [ 2 , col] ] , kth) ;
3440+ }
3441+
3442+ // Test non-contiguous case (F-order)
3443+ let a = array ! [
3444+ [ 7 , 1 , 5 ] ,
3445+ [ 2 , 6 , 0 ] ,
3446+ [ 3 , 4 , 8 ]
3447+ ] ;
3448+
3449+ // Make array non-contiguous by transposing
3450+ let a = a. t ( ) . to_owned ( ) ;
3451+
3452+ // Partition along axis 1 (non-contiguous)
3453+ let p_axis1 = a. partition ( 1 , Axis ( 1 ) ) ;
3454+
3455+ // For each row, verify the partitioning:
3456+ // - First column should be <= middle column
3457+ // - Last column should be >= middle column
3458+ for row in 0 ..3 {
3459+ assert ! ( p_axis1[ [ row, 0 ] ] <= p_axis1[ [ row, 1 ] ] ,
3460+ "Row {}: First column {} should be <= middle column {}" ,
3461+ row, p_axis1[ [ row, 0 ] ] , p_axis1[ [ row, 1 ] ] ) ;
3462+ assert ! ( p_axis1[ [ row, 2 ] ] >= p_axis1[ [ row, 1 ] ] ,
3463+ "Row {}: Last column {} should be >= middle column {}" ,
3464+ row, p_axis1[ [ row, 2 ] ] , p_axis1[ [ row, 1 ] ] ) ;
3465+ }
3466+ }
34433467}
0 commit comments