Skip to content

Commit 8e36dd6

Browse files
draft for inplace reverse, permute
1 parent da115c9 commit 8e36dd6

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

src/impl_methods.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2548,6 +2548,72 @@ where
25482548
self.layout.strides.slice_mut().reverse();
25492549
self
25502550
}
2551+
2552+
/// Reverse the axes of the array in-place.
2553+
///
2554+
/// This does not move any data, it just adjusts the array's dimensions
2555+
/// and strides.
2556+
pub fn reverse_axes(&mut self)
2557+
{
2558+
self.layout.dim.slice_mut().reverse();
2559+
self.layout.strides.slice_mut().reverse();
2560+
}
2561+
2562+
/// Permute the axes in-place.
2563+
///
2564+
/// This does not move any data, it just adjusts the array's dimensions
2565+
/// and strides.
2566+
///
2567+
/// *i* in the *j*-th place in the axes sequence means `self`'s *i*-th axis
2568+
/// becomes `self`'s *j*-th axis
2569+
///
2570+
/// **Panics** if any of the axes are out of bounds, if an axis is missing,
2571+
/// or if an axis is repeated more than once.
2572+
///
2573+
/// # Examples
2574+
///
2575+
/// ```rust
2576+
/// use ndarray::{arr2, Array3};
2577+
///
2578+
/// let mut a = arr2(&[[0, 1], [2, 3]]);
2579+
/// a.permute_axes([1, 0]);
2580+
/// assert_eq!(a, arr2(&[[0, 2], [1, 3]]));
2581+
///
2582+
/// let mut b = Array3::<u8>::zeros((1, 2, 3));
2583+
/// b.permute_axes([1, 0, 2]);
2584+
/// assert_eq!(b.shape(), &[2, 1, 3]);
2585+
/// ```
2586+
#[track_caller]
2587+
pub fn permute_axes<T>(&mut self, axes: T)
2588+
where T: IntoDimension<Dim = D>
2589+
{
2590+
let axes = axes.into_dimension();
2591+
// Ensure that each axis is used exactly once.
2592+
let mut usage_counts = D::zeros(self.ndim());
2593+
for axis in axes.slice() {
2594+
usage_counts[*axis] += 1;
2595+
}
2596+
for count in usage_counts.slice() {
2597+
assert_eq!(*count, 1, "each axis must be listed exactly once");
2598+
}
2599+
2600+
// Create temporary arrays for the new dimensions and strides
2601+
let mut new_dim = D::zeros(self.ndim());
2602+
let mut new_strides = D::zeros(self.ndim());
2603+
2604+
{
2605+
let dim = self.layout.dim.slice();
2606+
let strides = self.layout.strides.slice();
2607+
for (new_axis, &axis) in axes.slice().iter().enumerate() {
2608+
new_dim[new_axis] = dim[axis];
2609+
new_strides[new_axis] = strides[axis];
2610+
}
2611+
}
2612+
2613+
// Update the dimensions and strides in place
2614+
self.layout.dim.slice_mut().copy_from_slice(new_dim.slice());
2615+
self.layout.strides.slice_mut().copy_from_slice(new_strides.slice());
2616+
}
25512617
}
25522618

25532619
impl<A, D: Dimension> ArrayRef<A, D>

0 commit comments

Comments
 (0)