Skip to content

Commit df59179

Browse files
handle non-contiguous with create and copy back
1 parent b800bec commit df59179

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

src/impl_methods.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3196,7 +3196,7 @@ impl<A, D: Dimension> ArrayRef<A, D>
31963196
/// # Parameters
31973197
///
31983198
/// * `kth` - Index to partition by. The k-th element will be in its sorted position.
3199-
/// * `axis` - Axis along which to partition. Default is the last axis (`Axis(ndim-1)`).
3199+
/// * `axis` - Axis along which to partition.
32003200
///
32013201
/// # Returns
32023202
///
@@ -3246,23 +3246,21 @@ impl<A, D: Dimension> ArrayRef<A, D>
32463246
// For multi-dimensional arrays, partition along the specified axis
32473247
let mut result = self.to_owned();
32483248

3249-
// Use Zip to efficiently iterate over the lanes
3249+
// Process each lane with partitioning
32503250
Zip::from(result.lanes_mut(axis)).for_each(|mut lane| {
3251-
// For each lane, perform the partitioning operation
3251+
// For each lane, we need a contiguous slice to partition
32523252
if let Some(slice) = lane.as_slice_mut() {
32533253
// If the lane's memory is contiguous, use select_nth_unstable directly
32543254
slice.select_nth_unstable(kth);
32553255
} else {
3256-
// For non-contiguous memory, create a temporary array with contiguous memory
3257-
let mut temp_arr = Array::from_iter(lane.iter().cloned());
3256+
// For non-contiguous memory, create a temporary vector
3257+
let mut values = lane.iter().cloned().collect::<Vec<_>>();
32583258

3259-
// Partition the temporary array
3260-
if let Some(slice) = temp_arr.as_slice_mut() {
3261-
slice.select_nth_unstable(kth);
3262-
}
3259+
// Partition the vector
3260+
values.select_nth_unstable(kth);
32633261

3264-
// Copy values back to original lane
3265-
Zip::from(&mut lane).and(&temp_arr).for_each(|dest, src| {
3262+
// Copy values back to the lane
3263+
Zip::from(&mut lane).and(&values).for_each(|dest, src| {
32663264
*dest = src.clone();
32673265
});
32683266
}

0 commit comments

Comments
 (0)