diff --git a/benches/varstd.rs b/benches/varstd.rs new file mode 100644 index 000000000..9ac79887c --- /dev/null +++ b/benches/varstd.rs @@ -0,0 +1,108 @@ +#![feature(test)] + +extern crate test; +use test::Bencher; + +use ndarray::arr3; +use ndarray::prelude::*; +use std::iter::FromIterator; + +#[rustfmt::skip] +fn big_array() -> Array2 { + arr2(&[ + [ 92., 53., 51., 94., 27., 69., 11., 13., 62., 42., 73., 83., 2. , 53., 77. ], + [ 65., 56., 11., 32., 95., 66., 88., 10., 37., 8. , 12., 2. , 59., 78., 48. ], + [ 20., 86., 71., 99., 1. , 76., 29., 53., 87., 88., 61., 84., 2. , 87., 90. ], + [ 19., 22., 44., 38., 85., 12., 8. , 38., 53., 46., 80., 70., 62., 14., 8. ], + [ 51., 70., 71., 21., 14., 48., 34., 4. , 27., 55., 60., 95., 1. , 79., 1. ], + [ 13., 23., 78., 97., 57., 16., 81., 31., 88., 15., 78., 95., 93., 9. , 6. ], + [ 68., 58., 4. , 11., 91., 56., 61., 15., 60., 92., 29., 27., 22., 30., 2. ], + [ 53., 70., 89., 42., 59., 79., 63., 61., 86., 48., 40., 50., 23., 18., 55. ], + [ 14., 96., 68., 16., 52., 16., 70., 12., 16., 60., 28., 52., 56., 12., 37. ], + [ 68., 73., 6. , 51., 54., 51., 97., 88., 36., 32., 83., 52., 53., 86., 4. ], + [ 88., 11., 86., 91., 83., 71., 18., 60., 95., 59., 85., 92., 34., 76., 93. ], + [ 81., 18., 47., 26., 53., 64., 53., 12., 55., 92., 76., 22., 81., 80., 21. ], + [ 86., 48., 42., 19., 94., 86., 16., 37., 74., 85., 11., 9. , 80., 2. , 80. ], + [ 51., 43., 55., 56., 49., 77., 78., 94., 80., 23., 72., 67., 58., 95., 95. ], + [ 92., 24., 45., 41., 33., 64., 89., 8. , 75., 42., 32., 61., 19., 11., 61. ], + [ 81., 35., 75., 67., 73., 30., 95., 17., 24., 48., 72., 2. , 46., 14., 50. ], + [ 99., 87., 41., 87., 68., 22., 94., 73., 82., 87., 86., 46., 36., 26., 57. ], + [ 96., 69., 28., 44., 32., 70., 94., 13., 85., 5. , 13., 44., 60., 79., 76. ], + [ 81., 92., 42., 93., 99., 41., 13., 8. , 68., 92., 89., 83., 16., 82., 92. ], + [ 29., 18., 10., 71., 4. , 20., 99., 10., 91., 51., 90., 78., 20., 25., 44. ], + [ 57., 56., 96., 81., 87., 57., 32., 22., 29., 63., 76., 39., 52., 77., 96. ], + [ 88., 2. , 56., 75., 72., 53., 0. , 57., 42., 83., 77., 85., 14., 15., 19. ], + ]) +} + +#[bench] +fn var_into_shape_use_var_axis(bench: &mut Bencher) { + let a = arr3(&[ + [[1., 2.], [1., 3.]], + [[3., 5.], [1., 3.]], + [[5., 7.], [1., 3.]], + ]); + + let len = a.len(); + bench.iter(|| { + let flattened = a + .view() + .into_shape(len) + .expect("into_shape to a.len() can not fail."); + flattened.var_axis(Axis(0), 1.) + }); +} + +#[bench] +fn var_into_shape_use_var_axis_big(bench: &mut Bencher) { + let a = big_array(); + + let len = a.len(); + bench.iter(|| { + let flattened = a + .view() + .into_shape(len) + .expect("into_shape to a.len() can not fail."); + flattened.var_axis(Axis(0), 1.) + }); +} + +#[bench] +fn var_flatten_user_var_axis(bench: &mut Bencher) { + let a = arr3(&[ + [[1., 2.], [1., 3.]], + [[3., 5.], [1., 3.]], + [[5., 7.], [1., 3.]], + ]); + + bench.iter(|| { + let flattened = Array::from_iter(a.iter().map(|&x| x)); + flattened.var_axis(Axis(0), 1.) + }) +} + +#[bench] +fn var_flatten_user_var_axis_big(bench: &mut Bencher) { + let a = big_array(); + bench.iter(|| { + let flattened = Array::from_iter(a.iter().map(|&x| x)); + flattened.var_axis(Axis(0), 1.) + }) +} + +#[bench] +fn var_new(bench: &mut Bencher) { + let a = arr3(&[ + [[1., 2.], [1., 3.]], + [[3., 5.], [1., 3.]], + [[5., 7.], [1., 3.]], + ]); + + bench.iter(|| a.var(1.)) +} + +#[bench] +fn var_new_big(bench: &mut Bencher) { + let a = big_array(); + bench.iter(|| a.var(1.)) +} diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 85f69444d..1a54b471b 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -235,25 +235,13 @@ where A: Float + FromPrimitive, D: RemoveAxis, { - let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); - let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail."); - assert!( - !(ddof < zero || ddof > n), - "`ddof` must not be less than zero or greater than the length of \ - the axis", - ); - let dof = n - ddof; - let mut mean = Array::::zeros(self.dim.remove_axis(axis)); - let mut sum_sq = Array::::zeros(self.dim.remove_axis(axis)); - for (i, subview) in self.axis_iter(axis).enumerate() { - let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail."); - azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) { - let delta = x - *mean; - *mean = *mean + delta / count; - *sum_sq = (x - *mean).mul_add(delta, *sum_sq); + let mut output = Array::zeros(self.dim.remove_axis(axis)); + Zip::from(output.view_mut()) + .and(self.lanes(axis)) + .apply(|o, l| { + *o = l.var(ddof); }); - } - sum_sq.mapv_into(|s| s / dof) + output } /// Return standard deviation along `axis`. @@ -306,6 +294,67 @@ where self.var_axis(axis, ddof).mapv_into(|x| x.sqrt()) } + /// Return variance for the flattened array. + /// + /// This uses the same method as var_axis. + /// + /// # Example + /// + /// ``` + /// use ndarray::{arr2, Axis}; + /// + /// let a = arr2(&[[1., 2.], + /// [3., 4.], + /// [5., 6.]]); + /// + /// let a_flat = a.view().into_shape(6).expect("This must not fail."); + /// assert_eq!(a.var(1.), a_flat.var_axis(Axis(0), 1.).into_scalar()); + /// ``` + pub fn var(&self, ddof: A) -> A + where + A: Float + FromPrimitive, + { + let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); + let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail."); + assert!( + !(ddof < zero || ddof > n), + "`ddof` must not be less than zero or greater than the length of \ + the axis", + ); + let dof = n - ddof; + let mut mean = A::from_usize(0).expect("Converting 0 to `A` must not fail."); + let mut sum_sq = A::from_usize(0).expect("Converting 0 to `A` must not fail."); + for (count, x) in self.iter().enumerate() { + let delta = *x - mean; + mean = mean + delta / A::from_usize(count + 1).unwrap(); + sum_sq = (*x - mean).mul_add(delta, sum_sq); + } + sum_sq / dof + } + + /// Return standard deviation for the flattened array. + /// + /// The standard deviation is computed from the variance. + /// + /// # Example + /// + /// ``` + /// use ndarray::{arr2, Axis}; + /// + /// let a = arr2(&[[1., 2.], + /// [3., 4.], + /// [5., 6.]]); + /// + /// let a_flat = a.view().into_shape(6).expect("This must not fail."); + /// assert_eq!(a.std(1.), a_flat.std_axis(Axis(0), 1.).into_scalar()); + /// ``` + pub fn std(&self, ddof: A) -> A + where + A: Float + FromPrimitive, + { + self.var(ddof).sqrt() + } + /// Return `true` if the arrays' elementwise differences are all within /// the given absolute tolerance, `false` otherwise. /// diff --git a/tests/numeric.rs b/tests/numeric.rs index 7c6f1441e..e1189b047 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -7,7 +7,7 @@ )] use approx::assert_abs_diff_eq; -use ndarray::{arr0, arr1, arr2, array, aview1, Array, Array1, Array2, Array3, Axis}; +use ndarray::{arr0, arr1, arr2, arr3, array, aview1, Array, Array1, Array2, Array3, Axis}; use std::f64; #[test] @@ -225,3 +225,35 @@ fn std_axis_empty_axis() { assert_eq!(v.shape(), &[2]); v.mapv(|x| assert!(x.is_nan())); } + +#[test] +fn var_var_axis() { + let a = arr3(&[ + [[1., 2.], [1., 3.]], + [[3., 5.], [1., 3.]], + [[5., 7.], [1., 3.]], + ]); + + let a_flat = a + .view() + .into_shape(a.len().clone()) + .expect("into_shape to a.len() must not fail."); + + assert_eq!(a.var(1.), a_flat.var_axis(Axis(0), 1.).into_scalar()); +} + +#[test] +fn std_std_axis() { + let a = arr3(&[ + [[1., 2.], [1., 3.]], + [[3., 5.], [1., 3.]], + [[5., 7.], [1., 3.]], + ]); + + let a_flat = a + .view() + .into_shape(a.len().clone()) + .expect("into_shape to a.len() must not fail."); + + assert_eq!(a.std(1.), a_flat.std_axis(Axis(0), 1.).into_scalar()); +}