Skip to content

Commit b800bec

Browse files
fn partition, first draft
1 parent 2a5cae1 commit b800bec

File tree

1 file changed

+165
-0
lines changed

1 file changed

+165
-0
lines changed

src/impl_methods.rs

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3184,6 +3184,92 @@ impl<A, D: Dimension> ArrayRef<A, D>
31843184
f(&*prev, &mut *curr)
31853185
});
31863186
}
3187+
3188+
/// Return a partitioned copy of the array.
3189+
///
3190+
/// Creates a copy of the array and partially sorts it around the k-th element along the given axis.
3191+
/// The k-th element will be in its sorted position, with:
3192+
/// - All elements smaller than the k-th element to its left
3193+
/// - All elements equal or greater than the k-th element to its right
3194+
/// - The ordering within each partition is undefined
3195+
///
3196+
/// # Parameters
3197+
///
3198+
/// * `kth` - Index to partition by. The k-th element will be in its sorted position.
3199+
/// * `axis` - Axis along which to partition. Default is the last axis (`Axis(ndim-1)`).
3200+
///
3201+
/// # Returns
3202+
///
3203+
/// A new array of the same shape and type as the input array, with elements partitioned.
3204+
///
3205+
/// # Examples
3206+
///
3207+
/// ```
3208+
/// use ndarray::prelude::*;
3209+
///
3210+
/// let a = array![7, 1, 5, 2, 6, 0, 3, 4];
3211+
/// let p = a.partition(3, Axis(0));
3212+
///
3213+
/// // The element at position 3 is now 3, with smaller elements to the left
3214+
/// // and greater elements to the right
3215+
/// assert_eq!(p[3], 3);
3216+
/// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3));
3217+
/// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3));
3218+
/// ```
3219+
pub fn partition(&self, kth: usize, axis: Axis) -> Array<A, D>
3220+
where A: Clone + Ord
3221+
{
3222+
// Check if axis is valid
3223+
if axis.index() >= self.ndim() {
3224+
panic!("axis {} is out of bounds for array of dimension {}", axis.index(), self.ndim());
3225+
}
3226+
3227+
// Check if kth is valid
3228+
if kth >= self.len_of(axis) {
3229+
panic!("kth {} is out of bounds for axis {} with length {}", kth, axis.index(), self.len_of(axis));
3230+
}
3231+
3232+
// If the array is empty, return a copy
3233+
if self.is_empty() {
3234+
return self.to_owned();
3235+
}
3236+
3237+
// If the array is 1D, handle as a special case
3238+
if self.ndim() == 1 {
3239+
let mut result = self.to_owned();
3240+
if let Some(slice) = result.as_slice_mut() {
3241+
slice.select_nth_unstable(kth);
3242+
}
3243+
return result;
3244+
}
3245+
3246+
// For multi-dimensional arrays, partition along the specified axis
3247+
let mut result = self.to_owned();
3248+
3249+
// Use Zip to efficiently iterate over the lanes
3250+
Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
3251+
// For each lane, perform the partitioning operation
3252+
if let Some(slice) = lane.as_slice_mut() {
3253+
// If the lane's memory is contiguous, use select_nth_unstable directly
3254+
slice.select_nth_unstable(kth);
3255+
} else {
3256+
// For non-contiguous memory, create a temporary array with contiguous memory
3257+
let mut temp_arr = Array::from_iter(lane.iter().cloned());
3258+
3259+
// Partition the temporary array
3260+
if let Some(slice) = temp_arr.as_slice_mut() {
3261+
slice.select_nth_unstable(kth);
3262+
}
3263+
3264+
// Copy values back to original lane
3265+
Zip::from(&mut lane).and(&temp_arr).for_each(|dest, src| {
3266+
*dest = src.clone();
3267+
});
3268+
}
3269+
});
3270+
3271+
result
3272+
}
31873273
}
31883274

31893275
/// Transmute from A to B.
@@ -3277,4 +3363,83 @@ mod tests
32773363
let _a2 = a.clone();
32783364
assert_first!(a);
32793365
}
3366+
3367+
#[test]
3368+
fn test_partition_1d()
3369+
{
3370+
let a = array![7, 1, 5, 2, 6, 0, 3, 4];
3371+
let kth = 3;
3372+
let p = a.partition(kth, Axis(0));
3373+
3374+
// The element at position kth is in its sorted position
3375+
assert_eq!(p[kth], 3);
3376+
3377+
// All elements to the left are less than or equal to the kth element
3378+
for i in 0..kth {
3379+
assert!(p[i] <= p[kth]);
3380+
}
3381+
3382+
// All elements to the right are greater than or equal to the kth element
3383+
for i in (kth + 1)..p.len() {
3384+
assert!(p[i] >= p[kth]);
3385+
}
3386+
}
3387+
3388+
#[test]
3389+
fn test_partition_2d()
3390+
{
3391+
let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]];
3392+
3393+
// Partition along axis 0 (rows)
3394+
let p_axis0 = a.partition(1, Axis(0));
3395+
3396+
// For each column, the middle row should be in its sorted position
3397+
for col in 0..3 {
3398+
assert!(p_axis0[[0, col]] <= p_axis0[[1, col]]);
3399+
assert!(p_axis0[[2, col]] >= p_axis0[[1, col]]);
3400+
}
3401+
3402+
// Partition along axis 1 (columns)
3403+
let p_axis1 = a.partition(1, Axis(1));
3404+
3405+
// For each row, the middle column should be in its sorted position
3406+
for row in 0..3 {
3407+
assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]]);
3408+
assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]]);
3409+
}
3410+
}
3411+
3412+
#[test]
3413+
fn test_partition_3d()
3414+
{
3415+
let a = arr3(&[[[9, 2], [3, 4]], [[5, 6], [7, 8]]]);
3416+
3417+
// Partition along the last axis
3418+
let p = a.partition(0, Axis(2));
3419+
3420+
// Check the partitioning along the last axis
3421+
for i in 0..2 {
3422+
for j in 0..2 {
3423+
assert!(p[[i, j, 0]] <= p[[i, j, 1]]);
3424+
}
3425+
}
3426+
}
3427+
3428+
#[test]
3429+
#[should_panic]
3430+
fn test_partition_invalid_kth()
3431+
{
3432+
let a = array![1, 2, 3, 4];
3433+
// This should panic because kth=4 is out of bounds
3434+
let _ = a.partition(4, Axis(0));
3435+
}
3436+
3437+
#[test]
3438+
#[should_panic]
3439+
fn test_partition_invalid_axis()
3440+
{
3441+
let a = array![1, 2, 3, 4];
3442+
// This should panic because axis=1 is out of bounds for a 1D array
3443+
let _ = a.partition(0, Axis(1));
3444+
}
32803445
}

0 commit comments

Comments
 (0)