@@ -169,3 +169,98 @@ where
         self.mapv(|a| num_traits::clamp(a, min.clone(), max.clone()))
     }
 }
+
+impl<A, S, D> ArrayBase<S, D>
+where
+    A: Float + 'static,
+    S: Data<Elem = A>,
+    D: RemoveAxis,
+{
+    /// Compute the softmax function along the specified axis.
+    ///
+    /// The softmax function is defined as:
+    /// ```text
+    /// softmax(x_i) = exp(x_i) / sum(exp(x_j) for j in axis)
+    /// ```
+    ///
+    /// This function is commonly used in machine learning to normalize the output of a
+    /// neural network into a probability distribution.
+    ///
+    /// ```
+    /// use ndarray::{array, Axis};
+    ///
+    /// let a = array![[1., 2., 3.], [4., 5., 6.]];
+    /// let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+    /// let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+    /// ```
+    ///
+    /// # Arguments
+    ///
+    /// * `axis`: The axis along which to compute the softmax function (so every lane along the axis will sum to 1).
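+    ///
+    /// If a lane contains NaN, every output in that lane will be NaN, since the
+    /// NaN propagates through the sum in the denominator.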
+    pub fn softmax(&self, axis: Axis) -> Array<A, D>
+    {
+        let mut res = Array::uninit(self.raw_dim());
+        for (arr, mut res) in self.lanes(axis).into_iter().zip(res.lanes_mut(axis)) {
+            let max = arr
+                .iter()
+                // If we have NaN and the comparison fails, the max can be arbitrary as the sum and the whole result
+                // will be NaN anyway, so we use an arbitrary ordering.
+                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
+            let max = match max {
+                Some(max) => *max,
+                None => continue,
+            };
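+            // Subtract the lane maximum before exponentiating so exp() cannot
+            // overflow for large inputs; softmax is invariant under this shift,
+            // so the result is unchanged.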
+            let sum = arr.fold(A::zero(), |sum, x| sum + (*x - max).exp());
+            for (i, x) in res.indexed_iter_mut() {
+                x.write((arr[i] - max).exp() / sum);
+            }
+        }
+        unsafe { res.assume_init() }
+    }
+}
+
+#[cfg(test)]
+mod tests
+{
+    use super::*;
+    use crate::array;
+
+    #[test]
+    fn test_softmax()
+    {
+        let a = array![[1., 2., 3.], [4., 5., 6.]];
+        let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+        let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+
+        #[cfg(feature = "approx")]
+        {
+            // examples copied from scipy softmax documentation
+
+            use approx::assert_relative_eq;
+
+            let x = array![[1., 0.5, 0.2, 3.], [1., -1., 7., 3.], [2., 12., 13., 3.]];
+
+            let m = x.softmax(Axis(0));
+            let y = array![[0.211942, 0.00001013, 0.00000275, 0.333333],
+                           [0.211942, 0.00000226, 0.00247262, 0.333333],
+                           [0.576117, 0.999988, 0.997525, 0.333333]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+
+            let m = x.softmax(Axis(1));
+            let y = array![[1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
+                           [2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
+                           [1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+        }
+    }
+}