
Commit d547d4b

Add softmax function
1 parent c7391e9 commit d547d4b

File tree

1 file changed: +88 -0 lines changed


src/numeric/impl_float_maths.rs

Lines changed: 88 additions & 0 deletions
@@ -169,3 +169,91 @@ where
         self.mapv(|a| num_traits::clamp(a, min.clone(), max.clone()))
     }
 }
+
+impl<A, S, D> ArrayBase<S, D>
+where
+    A: Float + 'static,
+    S: Data<Elem = A>,
+    D: RemoveAxis,
+{
+    /// Compute the softmax function along the specified axis.
+    ///
+    /// The softmax function is defined as:
+    /// ```text
+    /// softmax(x_i) = exp(x_i) / sum(exp(x_j) for j in axis)
+    /// ```
+    ///
+    /// This function is usually used in machine learning to normalize the output of a neural network to a probability
+    /// distribution.
+    /// ```
+    /// use ndarray::{array, Axis};
+    ///
+    /// let a = array![[1., 2., 3.], [4., 5., 6.]];
+    /// let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+    /// let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+    /// ```
+    ///
+    /// # Arguments
+    ///
+    /// * `axis`: The axis along which to compute the softmax function (so every slice along the axis will sum to 1).
+    pub fn softmax(&self, axis: Axis) -> Array<A, D>
+    {
+        let mut res = Array::uninit(self.raw_dim());
+        for (arr, mut res) in self.lanes(axis).into_iter().zip(res.lanes_mut(axis)) {
+            let max = arr
+                .iter()
+                // If we have NaN and the comparison fails, the max can be arbitrary as the sum and the whole result
+                // will be NaN anyway, so we use an arbitrary ordering.
+                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
+            let max = match max {
+                Some(max) => *max,
+                None => continue,
+            };
+            let sum = arr.fold(A::zero(), |sum, x| sum + (*x - max).exp());
+            for (i, x) in res.indexed_iter_mut() {
+                x.write((arr[i] - max).exp() / sum);
+            }
+        }
+        unsafe { res.assume_init() }
+    }
+}
+
+#[cfg(test)]
+mod tests
+{
+    use super::*;
+    use crate::array;
+
+    #[test]
+    fn test_softmax()
+    {
+        let a = array![[1., 2., 3.], [4., 5., 6.]];
+        let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+        let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+
+        #[cfg(feature = "approx")]
+        {
+            // examples copied from scipy softmax documentation
+
+            use approx::assert_relative_eq;
+
+            let x = array![[1., 0.5, 0.2, 3.], [1., -1., 7., 3.], [2., 12., 13., 3.]];
+
+            let m = x.softmax(Axis(0));
+            let y = array![[0.211942, 0.00001013, 0.00000275, 0.333333],
+                [0.211942, 0.00000226, 0.00247262, 0.333333],
+                [0.576117, 0.999988, 0.997525, 0.333333]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+
+            let m = x.softmax(Axis(1));
+            let y = array![[ 1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
+                [ 2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
+                [ 1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+        }
+    }
+}
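
Note on the max subtraction in the body above: shifting every element of a lane by a constant leaves the softmax unchanged, since exp(x_i - m) / sum_j exp(x_j - m) = exp(x_i) / sum_j exp(x_j), but choosing m as the lane maximum keeps every exponent at or below zero, so exp() cannot overflow for large inputs. A minimal standalone sketch of the same trick on a plain slice (not part of this commit, purely illustrative):

```rust
// Numerically stable softmax on a plain slice, mirroring the trick used in the commit.
fn softmax_stable(xs: &[f64]) -> Vec<f64> {
    // Lane maximum; an empty slice simply yields an empty result below.
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    // Every argument to exp() is <= 0, so each term lies in (0, 1].
    let sum: f64 = xs.iter().map(|x| (x - max).exp()).sum();
    xs.iter().map(|x| (x - max).exp() / sum).collect()
}

fn main() {
    // A naive exp(1000.0) overflows to infinity; the shifted version does not.
    let probs = softmax_stable(&[1000.0, 1001.0, 1002.0]);
    println!("{probs:?}"); // approximately [0.090, 0.245, 0.665]
    assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-12);
}
```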

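The body also leans on ndarray's two-step initialization: `Array::uninit` allocates without initializing, each element is filled through `MaybeUninit::write`, and `assume_init` is only sound once every element has been written (here, each lane is either fully written or empty). A minimal sketch of that pattern, assuming ndarray 0.15+ where these APIs exist:

```rust
use ndarray::Array2;

fn main() {
    // Allocate a 2x3 array of MaybeUninit<f64> without initializing it.
    let mut res = Array2::<f64>::uninit((2, 3));
    // Write every element exactly once.
    for (idx, slot) in res.indexed_iter_mut() {
        slot.write((idx.0 + idx.1) as f64);
    }
    // Safety: the loop above wrote every element.
    let res = unsafe { res.assume_init() };
    assert_eq!(res[[1, 2]], 3.0);
}
```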
0 commit comments
