Skip to content

Commit 8f72117

Browse files
include review: not to allocate n-vector, reuse a single vector
1 parent 8034819 commit 8f72117

File tree

1 file changed

+117
-93
lines changed

1 file changed

+117
-93
lines changed

src/impl_methods.rs

Lines changed: 117 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3217,55 +3217,44 @@ impl<A, D: Dimension> ArrayRef<A, D>
32173217
/// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3));
32183218
/// ```
32193219
pub fn partition(&self, kth: usize, axis: Axis) -> Array<A, D>
3220-
where A: Clone + Ord
3220+
where
3221+
A: Clone + Ord,
3222+
D: Dimension,
32213223
{
3222-
// Check if axis is valid
3223-
if axis.index() >= self.ndim() {
3224-
panic!("axis {} is out of bounds for array of dimension {}", axis.index(), self.ndim());
3225-
}
3226-
3227-
// Check if kth is valid
3228-
if kth >= self.len_of(axis) {
3229-
panic!("kth {} is out of bounds for axis {} with length {}", kth, axis.index(), self.len_of(axis));
3230-
}
3231-
3232-
// If the array is empty, return a copy
3233-
if self.is_empty() {
3234-
return self.to_owned();
3224+
// Bounds checking
3225+
let axis_len = self.len_of(axis);
3226+
if kth >= axis_len {
3227+
panic!("partition index {} is out of bounds for axis of length {}", kth, axis_len);
32353228
}
32363229

3237-
// If the array is 1D, handle as a special case
3238-
if self.ndim() == 1 {
3239-
let mut result = self.to_owned();
3240-
if let Some(slice) = result.as_slice_mut() {
3241-
slice.select_nth_unstable(kth);
3242-
}
3243-
return result;
3244-
}
3245-
3246-
// For multi-dimensional arrays, partition along the specified axis
32473230
let mut result = self.to_owned();
3248-
3249-
// Process each lane with partitioning
3250-
Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
3251-
// For each lane, we need a contiguous slice to partition
3252-
if let Some(slice) = lane.as_slice_mut() {
3253-
// If the lane's memory is contiguous, use select_nth_unstable directly
3254-
slice.select_nth_unstable(kth);
3255-
} else {
3256-
// For non-contiguous memory, create a temporary vector
3257-
let mut values = lane.iter().cloned().collect::<Vec<_>>();
3258-
3259-
// Partition the vector
3260-
values.select_nth_unstable(kth);
3261-
3262-
// Copy values back to the lane
3263-
Zip::from(&mut lane).and(&values).for_each(|dest, src| {
3231+
3232+
// Check if the first lane is contiguous
3233+
let is_contiguous = result.lanes_mut(axis)
3234+
.into_iter()
3235+
.next()
3236+
.map(|lane| lane.is_contiguous())
3237+
.unwrap_or(false);
3238+
3239+
if is_contiguous {
3240+
Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
3241+
lane.as_slice_mut().unwrap().select_nth_unstable(kth);
3242+
});
3243+
} else {
3244+
let mut temp_vec = Vec::with_capacity(axis_len);
3245+
3246+
Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
3247+
temp_vec.clear();
3248+
temp_vec.extend(lane.iter().cloned());
3249+
3250+
temp_vec.select_nth_unstable(kth);
3251+
3252+
Zip::from(&mut lane).and(&temp_vec).for_each(|dest, src| {
32643253
*dest = src.clone();
32653254
});
3266-
}
3267-
});
3268-
3255+
});
3256+
}
3257+
32693258
result
32703259
}
32713260
}
@@ -3363,64 +3352,47 @@ mod tests
33633352
}
33643353

33653354
#[test]
3366-
fn test_partition_1d()
3367-
{
3368-
let a = array![7, 1, 5, 2, 6, 0, 3, 4];
3369-
let kth = 3;
3370-
let p = a.partition(kth, Axis(0));
3371-
3372-
// The element at position kth is in its sorted position
3373-
assert_eq!(p[kth], 3);
3374-
3375-
// All elements to the left are less than or equal to the kth element
3376-
for i in 0..kth {
3377-
assert!(p[i] <= p[kth]);
3378-
}
3379-
3380-
// All elements to the right are greater than or equal to the kth element
3381-
for i in (kth + 1)..p.len() {
3382-
assert!(p[i] >= p[kth]);
3383-
}
3355+
fn test_partition_1d() {
3356+
// Test partitioning a 1D array
3357+
let array = arr1(&[3, 1, 4, 1, 5, 9, 2, 6]);
3358+
let result = array.partition(3, Axis(0));
3359+
// After partitioning, the element at index 3 should be in its final sorted position
3360+
assert!(result.slice(s![..3]).iter().all(|&x| x <= result[3]));
3361+
assert!(result.slice(s![4..]).iter().all(|&x| x >= result[3]));
33843362
}
33853363

33863364
#[test]
3387-
fn test_partition_2d()
3388-
{
3389-
let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]];
3390-
3365+
fn test_partition_2d() {
3366+
// Test partitioning a 2D array along both axes
3367+
let array = arr2(&[[3, 1, 4], [1, 5, 9], [2, 6, 5]]);
3368+
33913369
// Partition along axis 0 (rows)
3392-
let p_axis0 = a.partition(1, Axis(0));
3393-
3394-
// For each column, the middle row should be in its sorted position
3395-
for col in 0..3 {
3396-
assert!(p_axis0[[0, col]] <= p_axis0[[1, col]]);
3397-
assert!(p_axis0[[2, col]] >= p_axis0[[1, col]]);
3398-
}
3399-
3370+
let result0 = array.partition(1, Axis(0));
3371+
// After partitioning along axis 0, each column should have its middle element in the correct position
3372+
assert!(result0[[0, 0]] <= result0[[1, 0]] && result0[[2, 0]] >= result0[[1, 0]]);
3373+
assert!(result0[[0, 1]] <= result0[[1, 1]] && result0[[2, 1]] >= result0[[1, 1]]);
3374+
assert!(result0[[0, 2]] <= result0[[1, 2]] && result0[[2, 2]] >= result0[[1, 2]]);
3375+
34003376
// Partition along axis 1 (columns)
3401-
let p_axis1 = a.partition(1, Axis(1));
3402-
3403-
// For each row, the middle column should be in its sorted position
3404-
for row in 0..3 {
3405-
assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]]);
3406-
assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]]);
3407-
}
3377+
let result1 = array.partition(1, Axis(1));
3378+
// After partitioning along axis 1, each row should have its middle element in the correct position
3379+
assert!(result1[[0, 0]] <= result1[[0, 1]] && result1[[0, 2]] >= result1[[0, 1]]);
3380+
assert!(result1[[1, 0]] <= result1[[1, 1]] && result1[[1, 2]] >= result1[[1, 1]]);
3381+
assert!(result1[[2, 0]] <= result1[[2, 1]] && result1[[2, 2]] >= result1[[2, 1]]);
34083382
}
34093383

34103384
#[test]
3411-
fn test_partition_3d()
3412-
{
3413-
let a = arr3(&[[[9, 2], [3, 4]], [[5, 6], [7, 8]]]);
3414-
3415-
// Partition along the last axis
3416-
let p = a.partition(0, Axis(2));
3417-
3418-
// Check the partitioning along the last axis
3419-
for i in 0..2 {
3420-
for j in 0..2 {
3421-
assert!(p[[i, j, 0]] <= p[[i, j, 1]]);
3422-
}
3423-
}
3385+
fn test_partition_3d() {
3386+
// Test partitioning a 3D array
3387+
let array = arr3(&[[[3, 1], [4, 1]], [[5, 9], [2, 6]]]);
3388+
3389+
// Partition along axis 0
3390+
let result = array.partition(0, Axis(0));
3391+
// After partitioning, each 2x2 slice should have its first element in the correct position
3392+
assert!(result[[0, 0, 0]] <= result[[1, 0, 0]]);
3393+
assert!(result[[0, 0, 1]] <= result[[1, 0, 1]]);
3394+
assert!(result[[0, 1, 0]] <= result[[1, 1, 0]]);
3395+
assert!(result[[0, 1, 1]] <= result[[1, 1, 1]]);
34243396
}
34253397

34263398
#[test]
@@ -3440,4 +3412,56 @@ mod tests
34403412
// This should panic because axis=1 is out of bounds for a 1D array
34413413
let _ = a.partition(0, Axis(1));
34423414
}
3415+
3416+
#[test]
3417+
fn test_partition_contiguous_or_not()
3418+
{
3419+
// Test contiguous case (C-order)
3420+
let a = array![
3421+
[7, 1, 5],
3422+
[2, 6, 0],
3423+
[3, 4, 8]
3424+
];
3425+
3426+
// Partition along axis 0 (contiguous)
3427+
let p_axis0 = a.partition(1, Axis(0));
3428+
3429+
// For each column, verify the partitioning:
3430+
// - First row should be <= middle row (kth element)
3431+
// - Last row should be >= middle row (kth element)
3432+
for col in 0..3 {
3433+
let kth = p_axis0[[1, col]];
3434+
assert!(p_axis0[[0, col]] <= kth,
3435+
"Column {}: First row {} should be <= middle row {}",
3436+
col, p_axis0[[0, col]], kth);
3437+
assert!(p_axis0[[2, col]] >= kth,
3438+
"Column {}: Last row {} should be >= middle row {}",
3439+
col, p_axis0[[2, col]], kth);
3440+
}
3441+
3442+
// Test non-contiguous case (F-order)
3443+
let a = array![
3444+
[7, 1, 5],
3445+
[2, 6, 0],
3446+
[3, 4, 8]
3447+
];
3448+
3449+
// Make array non-contiguous by transposing
3450+
let a = a.t().to_owned();
3451+
3452+
// Partition along axis 1 (non-contiguous)
3453+
let p_axis1 = a.partition(1, Axis(1));
3454+
3455+
// For each row, verify the partitioning:
3456+
// - First column should be <= middle column
3457+
// - Last column should be >= middle column
3458+
for row in 0..3 {
3459+
assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]],
3460+
"Row {}: First column {} should be <= middle column {}",
3461+
row, p_axis1[[row, 0]], p_axis1[[row, 1]]);
3462+
assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]],
3463+
"Row {}: Last column {} should be >= middle column {}",
3464+
row, p_axis1[[row, 2]], p_axis1[[row, 1]]);
3465+
}
3466+
}
34433467
}

0 commit comments

Comments
 (0)