@@ -67,7 +67,12 @@ impl Slice {
6767 }
6868}
6969
70- /// A slice (range with step) or an index.
70+ /// Token to represent a new axis in a slice description.
71+ ///
72+ /// See also the [`s![]`](macro.s!.html) macro.
73+ pub struct NewAxis ;
74+
75+ /// A slice (range with step), an index, or a new axis token.
7176///
7277/// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a
7378/// `&SliceInfo<[AxisSliceInfo; n], Di, Do>`.
@@ -91,6 +96,10 @@ impl Slice {
9196/// from `a` until the end, in reverse order. It can also be created with
9297/// `AxisSliceInfo::from(a..).step_by(-1)`. The Python equivalent is `[a::-1]`.
9398/// The macro equivalent is `s![a..;-1]`.
99+ ///
100+ /// `AxisSliceInfo::NewAxis` is a new axis of length 1. It can also be created
101+ /// with `AxisSliceInfo::from(NewAxis)`. The Python equivalent is
102+ /// `[np.newaxis]`. The macro equivalent is `s![NewAxis]`.
94103#[ derive( Debug , PartialEq , Eq , Hash ) ]
95104pub enum AxisSliceInfo {
96105 /// A range with step size. `end` is an exclusive index. Negative `begin`
@@ -103,6 +112,8 @@ pub enum AxisSliceInfo {
103112 } ,
104113 /// A single index.
105114 Index ( isize ) ,
115+ /// A new axis of length 1.
116+ NewAxis ,
106117}
107118
108119copy_and_clone ! { AxisSliceInfo }
@@ -124,6 +135,14 @@ impl AxisSliceInfo {
124135 }
125136 }
126137
138+ /// Returns `true` if `self` is a `NewAxis` value.
139+ pub fn is_new_axis ( & self ) -> bool {
140+ match self {
141+ & AxisSliceInfo :: NewAxis => true ,
142+ _ => false ,
143+ }
144+ }
145+
127146 /// Returns a new `AxisSliceInfo` with the given step size (multiplied with
128147 /// the previous step size).
129148 ///
@@ -143,6 +162,7 @@ impl AxisSliceInfo {
143162 step : orig_step * step,
144163 } ,
145164 AxisSliceInfo :: Index ( s) => AxisSliceInfo :: Index ( s) ,
165+ AxisSliceInfo :: NewAxis => AxisSliceInfo :: NewAxis ,
146166 }
147167 }
148168}
@@ -163,6 +183,7 @@ impl fmt::Display for AxisSliceInfo {
163183 write ! ( f, ";{}" , step) ?;
164184 }
165185 }
186+ AxisSliceInfo :: NewAxis => write ! ( f, "NewAxis" ) ?,
166187 }
167188 Ok ( ( ) )
168189 }
@@ -282,6 +303,13 @@ impl_sliceorindex_from_index!(isize);
282303impl_sliceorindex_from_index ! ( usize ) ;
283304impl_sliceorindex_from_index ! ( i32 ) ;
284305
306+ impl From < NewAxis > for AxisSliceInfo {
307+ #[ inline]
308+ fn from ( _: NewAxis ) -> AxisSliceInfo {
309+ AxisSliceInfo :: NewAxis
310+ }
311+ }
312+
285313/// A type that can slice an array of dimension `D`.
286314///
287315/// This trait is unsafe to implement because the implementation must ensure
@@ -402,12 +430,12 @@ where
402430 /// Errors if `Di` or `Do` is not consistent with `indices`.
403431 pub fn new ( indices : T ) -> Result < SliceInfo < T , Di , Do > , ShapeError > {
404432 if let Some ( ndim) = Di :: NDIM {
405- if ndim != indices. as_ref ( ) . len ( ) {
433+ if ndim != indices. as_ref ( ) . iter ( ) . filter ( |s| !s . is_new_axis ( ) ) . count ( ) {
406434 return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
407435 }
408436 }
409437 if let Some ( ndim) = Do :: NDIM {
410- if ndim != indices. as_ref ( ) . iter ( ) . filter ( |s| s . is_slice ( ) ) . count ( ) {
438+ if ndim != indices. as_ref ( ) . iter ( ) . filter ( |s| !s . is_index ( ) ) . count ( ) {
411439 return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
412440 }
413441 }
@@ -427,8 +455,18 @@ where
427455{
428456 /// Returns the number of dimensions of the input array for
429457 /// [`.slice()`](struct.ArrayBase.html#method.slice).
458+ ///
459+ /// If `Di` is a fixed-size dimension type, then this is equivalent to
460+ /// `Di::NDIM.unwrap()`. Otherwise, the value is calculated by iterating
461+ /// over the `AxisSliceInfo` elements.
430462 pub fn in_ndim ( & self ) -> usize {
431- Di :: NDIM . unwrap_or_else ( || self . indices . as_ref ( ) . len ( ) )
463+ Di :: NDIM . unwrap_or_else ( || {
464+ self . indices
465+ . as_ref ( )
466+ . iter ( )
467+ . filter ( |s| !s. is_new_axis ( ) )
468+ . count ( )
469+ } )
432470 }
433471
434472 /// Returns the number of dimensions after calling
@@ -443,7 +481,7 @@ where
443481 self . indices
444482 . as_ref ( )
445483 . iter ( )
446- . filter ( |s| s . is_slice ( ) )
484+ . filter ( |s| !s . is_index ( ) )
447485 . count ( )
448486 } )
449487 }
@@ -506,6 +544,12 @@ pub trait SliceNextInDim<D1, D2> {
506544 fn next_dim ( & self , PhantomData < D1 > ) -> PhantomData < D2 > ;
507545}
508546
547+ impl < D1 : Dimension > SliceNextInDim < D1 , D1 > for NewAxis {
548+ fn next_dim ( & self , _: PhantomData < D1 > ) -> PhantomData < D1 > {
549+ PhantomData
550+ }
551+ }
552+
509553macro_rules! impl_slicenextindim_larger {
510554 ( ( $( $generics: tt) * ) , $self: ty) => {
511555 impl <D1 : Dimension , $( $generics) ,* > SliceNextInDim <D1 , D1 :: Larger > for $self {
@@ -560,12 +604,13 @@ impl_slicenextoutdim_larger!((T), RangeTo<T>);
560604impl_slicenextoutdim_larger ! ( ( T ) , RangeToInclusive <T >) ;
561605impl_slicenextoutdim_larger ! ( ( ) , RangeFull ) ;
562606impl_slicenextoutdim_larger ! ( ( ) , Slice ) ;
607+ impl_slicenextoutdim_larger ! ( ( ) , NewAxis ) ;
563608
564609/// Slice argument constructor.
565610///
566- /// `s![]` takes a list of ranges/slices/indices, separated by comma, with
567- /// optional step sizes that are separated from the range by a semicolon. It is
568- /// converted into a [`&SliceInfo`] instance.
611+ /// `s![]` takes a list of ranges/slices/indices/new-axes , separated by comma,
612+ /// with optional step sizes that are separated from the range by a semicolon.
613+ /// It is converted into a [`&SliceInfo`] instance.
569614///
570615/// [`&SliceInfo`]: struct.SliceInfo.html
571616///
@@ -584,22 +629,25 @@ impl_slicenextoutdim_larger!((), Slice);
584629/// * *slice*: a [`Slice`] instance to use for slicing that axis.
585630/// * *slice* `;` *step*: a range constructed from the start and end of a [`Slice`]
586631/// instance, with new step size *step*, to use for slicing that axis.
632+ /// * *new-axis*: a [`NewAxis`] instance that represents the creation of a new axis.
587633///
588634/// [`Slice`]: struct.Slice.html
635+ /// [`NewAxis`]: struct.NewAxis.html
589636///
590- /// The number of *axis-slice-info* must match the number of axes in the array.
591- /// *index*, *range*, *slice*, and *step* can be expressions. *index* must be
592- /// of type `isize`, `usize`, or `i32` . *range * must be of type `Range<I>`,
593- /// `RangeTo <I>`, `RangeFrom <I>`, or `RangeFull` where `I` is `isize`, `usize`,
594- /// or `i32`. *step* must be a type that can be converted to `isize` with the
595- /// `as` keyword.
637+ /// The number of *axis-slice-info*, not including *new-axis*, must match the
638+ /// number of axes in the array. *index*, *range*, *slice*, *step*, and
639+ /// *new-axis* can be expressions . *index * must be of type `isize`, `usize`, or
640+ /// `i32`. *range* must be of type `Range <I>`, `RangeTo <I>`, `RangeFrom<I>`, or
641+ /// `RangeFull` where `I` is `isize`, `usize`, or `i32`. *step* must be a type
642+ /// that can be converted to `isize` with the `as` keyword.
596643///
597- /// For example `s![0..4;2, 6, 1..5]` is a slice of the first axis for 0..4
598- /// with step size 2, a subview of the second axis at index 6, and a slice of
599- /// the third axis for 1..5 with default step size 1. The input array must have
600- /// 3 dimensions. The resulting slice would have shape `[2, 4]` for
601- /// [`.slice()`], [`.slice_mut()`], and [`.slice_move()`], and shape
602- /// `[2, 1, 4]` for [`.slice_collapse()`].
644+ /// For example `s![0..4;2, 6, 1..5, NewAxis]` is a slice of the first axis for
645+ /// 0..4 with step size 2, a subview of the second axis at index 6, a slice of
646+ /// the third axis for 1..5 with default step size 1, and a new axis of length
647+ /// 1 at the end of the shape. The input array must have 3 dimensions. The
648+ /// resulting slice would have shape `[2, 4, 1]` for [`.slice()`],
649+ /// [`.slice_mut()`], and [`.slice_move()`], and shape `[2, 1, 4]` for
650+ /// [`.slice_collapse()`].
603651///
604652/// [`.slice()`]: struct.ArrayBase.html#method.slice
605653/// [`.slice_mut()`]: struct.ArrayBase.html#method.slice_mut
@@ -726,11 +774,11 @@ macro_rules! s(
726774 }
727775 }
728776 } ;
729- // convert range/index into AxisSliceInfo
777+ // convert range/index/new-axis into AxisSliceInfo
730778 ( @convert $r: expr) => {
731779 <$crate:: AxisSliceInfo as :: std:: convert:: From <_>>:: from( $r)
732780 } ;
733- // convert range/index and step into AxisSliceInfo
781+ // convert range/index/new-axis and step into AxisSliceInfo
734782 ( @convert $r: expr, $s: expr) => {
735783 <$crate:: AxisSliceInfo as :: std:: convert:: From <_>>:: from( $r) . step_by( $s as isize )
736784 } ;
0 commit comments