From ed8071170ebdc0e8fb54675f646c0d91016b0b70 Mon Sep 17 00:00:00 2001 From: Egor Dmitriev Date: Wed, 6 Jul 2022 01:08:27 +0200 Subject: [PATCH 01/10] feat: Added implementation for multivariate gamma function `mvgamma` and its log variant `mvlgamma` --- src/function/gamma.rs | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/function/gamma.rs b/src/function/gamma.rs index 9d5124f9..f043c2e2 100644 --- a/src/function/gamma.rs +++ b/src/function/gamma.rs @@ -421,6 +421,24 @@ fn signum(x: f64) -> f64 { } } +/// Computes the [multivariate gamma function](https://en.wikipedia.org/wiki/Multivariate_gamma_function). +pub fn mvgamma(p: i64, a: f64) -> f64 { + let mut res = std::f64::consts::PI.powf((p * (p - 1)) as f64 / 4.0); + for ii in 1..=p { + res *= gamma(a + (1 - ii) as f64 / 2.0); + } + res +} + +/// Computes the log [multivariate gamma function](https://en.wikipedia.org/wiki/Multivariate_gamma_function). +pub fn mvlgamma(p: i64, a: f64) -> f64 { + let mut res = (p * (p - 1)) as f64 * consts::LN_PI / 4.0; + for ii in 1..=p { + res += ln_gamma(a + (1 - ii) as f64 / 2.0); + } + res +} + #[rustfmt::skip] #[cfg(test)] mod tests { @@ -805,4 +823,14 @@ mod tests { assert_almost_eq!(super::inv_digamma(1.6110931485817511237336268416044190359814435699427405), 5.5, 1e-14); assert_almost_eq!(super::inv_digamma(2.2622143570941481235561593642219403924532310597356171), 10.1, 1e-13); } + + #[test] + fn test_ln_mvlgamma() { + assert!(super::mvlgamma(2, f64::NAN).is_nan()); + assert_almost_eq!(super::mvlgamma(0, 1.0), 0.0, 1e-12); + assert_almost_eq!(super::mvlgamma(2, 1.0), 1.1447298858494002, 1e-12); + assert_almost_eq!(super::mvlgamma(3, 1.5), 2.1686775340635553, 1e-12); + assert_almost_eq!(super::mvlgamma(3, 150.0 + 1.0e-12), 1794.2387481112528, 1e-12); + assert_almost_eq!(super::mvlgamma(11, 14.5), 225.2071467353132, 1e-12); + } } From 933b754da0ea31369afd93ae3280cd240d125536 Mon Sep 17 00:00:00 2001 From: Egor Dmitriev Date: Wed, 6 Jul 2022 
01:12:53 +0200 Subject: [PATCH 02/10] feat: Added implementation for Wishart distribution --- src/distribution/mod.rs | 1 + src/distribution/wishart.rs | 330 ++++++++++++++++++++++++++++++++++++ 2 files changed, 331 insertions(+) create mode 100644 src/distribution/wishart.rs diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index e0bec706..be0f6ed2 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -69,6 +69,7 @@ mod uniform; mod weibull; mod ziggurat; mod ziggurat_tables; +mod wishart; use crate::Result; diff --git a/src/distribution/wishart.rs b/src/distribution/wishart.rs new file mode 100644 index 00000000..7577e65f --- /dev/null +++ b/src/distribution/wishart.rs @@ -0,0 +1,330 @@ +use std::f64::consts::LN_2; +use nalgebra::{Cholesky, DMatrix, DVector, Dynamic}; +use num_traits::Float; +use num_traits::real::Real; +use rand::Rng; +use crate::{Result, StatsError}; +use crate::consts::LN_PI; +use crate::distribution::{ChiSquared, Continuous, Normal, ziggurat}; +use crate::function::gamma::{digamma, mvgamma}; +use crate::statistics::{MeanN, Mode, VarianceN}; +use crate::function::gamma::mvlgamma; + +fn mvdigamma(p: i64, a: f64) -> f64 { + let mut sum = 0.0; + for i in 0..p { + sum += digamma(a - (i as f64) / 2.0); + } + sum +} + + +/// Implements the [Wishart distribution](http://en.wikipedia.org/wiki/Wishart_distribution) +#[derive(Debug, Clone)] +pub struct Wishart { + df: f64, + S: DMatrix, + chol: Cholesky, +} + +impl Wishart { + pub fn new(df: f64, S: DMatrix) -> Result { + if S.nrows() != S.ncols() { + return Err(StatsError::BadParams); + } + if df <= 0.0 || df.is_nan() { + return Err(StatsError::ArgMustBePositive("df must be positive")); + } + if df <= S.nrows() as f64 - 1.0 { + return Err(StatsError::ArgGt("df must be greater than p-1", df)); + } + + match Cholesky::new(S.clone()) { + None => Err(StatsError::BadParams), + Some(chol) => { + Ok(Wishart { df, S, chol }) + } + } + } + + pub fn p(&self) -> usize { + 
self.S.nrows() + } + + pub fn entropy(&self) -> Option { + let p = self.p() as f64; + Some( + (p + 1.0) / 2.0 * self.chol.determinant().ln() + + 0.5 * p * (p + 1.0) * LN_2 + + mvlgamma(p as i64, self.df / 2.0) + - (self.df - p - 1.0) / 2.0 * mvdigamma(p as i64, self.df / 2.0) + + self.df * p / 2.0 + ) + } +} + +impl MeanN> for Wishart { + fn mean(&self) -> Option> { + Some(self.df * &self.S) + } +} + +impl VarianceN> for Wishart { + fn variance(&self) -> Option> { + Some(self.S.map_with_location(|i, j, x| { + self.df * (self.S[(i, i)] * self.S[(j, j)] + x.powi(2)) + })) + } +} + +impl Mode> for Wishart { + fn mode(&self) -> DMatrix { + (self.df - self.p() as f64 - 1.0) * &self.S + } +} + +impl ::rand::distributions::Distribution> for Wishart { + fn sample(&self, rng: &mut R) -> DMatrix { + let p = self.p(); + let mut A = DMatrix::zeros(p, p); + + for i in 0..p { + A[(i, i)] = ChiSquared::new(self.df - i as f64).unwrap().sample(rng); + } + + for i in 1..p { + for j in 0..i { + A[(i, j)] = ziggurat::sample_std_normal(rng); + } + } + + let L = self.chol.l(); + &L * &A * A.transpose() * L.transpose() + } +} + +impl Continuous, f64> for Wishart { + fn pdf(&self, x: DMatrix) -> f64 { + let p = self.p() as f64; + let x_det = x.determinant(); + let x_sol = self.chol.solve(&x); + + x_det.powf((self.df - p - 1.0) / 2.0) + * (-0.5 * x_sol.trace()).exp() + / (2.0f64).powf(self.df * p / 2.0) + / self.chol.determinant().powf(self.df / 2.0) + / mvgamma(p as i64, self.df / 2.0) + } + + fn ln_pdf(&self, x: DMatrix) -> f64 { + let p = self.p() as f64; + let x_lndet = x.determinant().ln(); + let x_sol = self.chol.solve(&x); + + x_lndet * (self.df - p - 1.0) / 2.0 + - 0.5 * x_sol.trace() + - LN_2 * (self.df * p / 2.0) + - self.chol.determinant().ln() * (self.df / 2.0) + - mvlgamma(p as i64, self.df / 2.0) + } +} + +#[rustfmt::skip] +#[cfg(test)] +mod tests { + use nalgebra::DMatrix; + use rand::distributions::Distribution; + use rand::rngs::StdRng; + use rand::SeedableRng; + use 
crate::distribution::Continuous; + use crate::distribution::wishart::Wishart; + use crate::statistics::{MeanN, VarianceN}; + + fn try_create(df: f64, S: DMatrix) -> Wishart + { + let w = Wishart::new(df, S); + assert!(w.is_ok()); + w.unwrap() + } + + fn test_almost( + df: f64, + S: DMatrix, + expected: f64, + acc: f64, + eval: F, + ) where + F: FnOnce(Wishart) -> f64, + { + let mvn = try_create(df, S); + let x = eval(mvn); + assert_almost_eq!(expected, x, acc); + } + + fn test_almost_mat( + df: f64, + S: DMatrix, + expected: DMatrix, + acc: f64, + eval: F, + ) where + F: FnOnce(Wishart) -> DMatrix, + { + let mvn = try_create(df, S); + let x = eval(mvn); + + for i in 0..x.nrows() { + for j in 0..x.ncols() { + assert_almost_eq!(expected[(i, j)], x[(i, j)], acc); + } + } + } + + #[test] + fn test_variance() { + test_almost_mat( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[4.0, 2.0, 2.0, 4.0]), + 1e-15, |w| w.variance().unwrap(), + ); + test_almost_mat( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + DMatrix::from_row_slice(3, 3, &[ + 6.17283, 0.618926, 2.88923, + 0.618926, 0.248229, 0.579385, + 2.88923, 0.579385, 5.4093, + ]), + 1e-4, |w| w.variance().unwrap(), + ); + } + + #[test] + fn test_entropy() { + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + 3.9538085820677584, + 1e-14, |w| w.entropy().unwrap(), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), 6.315039109555716, + 1e-12, |w| w.entropy().unwrap(), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 9.9495 + ]), 11.013723204318477, + 1e-12, |w| w.entropy().unwrap(), + ); + } + + #[test] + fn test_mode() { + test_almost_mat( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + 
DMatrix::from_row_slice(2, 2, &[2.0, 0.0, 0.0, 2.0]), + 1e-15, |w| w.mean().unwrap(), + ); + } + + #[test] + fn test_pdf() { + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + 0.02927491576215958, + 1e-15, |w| w.pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + 0.02927491576215958, + 1e-15, |w| w.pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + 0.028397507846420644, + 1e-12, |w| w.pdf(DMatrix::from_row_slice(3, 3, &[ + 0.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 0.5627, + ])), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + 2.3729077174800438e-5, + 1e-12, |w| w.pdf(DMatrix::from_row_slice(3, 3, &[ + 1.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 9.5627, + ])), + ); + } + + #[test] + fn test_ln_pdf() { + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + -3.5310242469692907, + 1e-15, |w| w.ln_pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + -3.561453889551996, + 1e-12, |w| w.ln_pdf(DMatrix::from_row_slice(3, 3, &[ + 0.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 0.5627, + ])), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + -10.648809376818907, + 1e-12, |w| w.ln_pdf(DMatrix::from_row_slice(3, 3, &[ + 1.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 9.5627, + ])), + ); + } + + #[test] + fn test_sample() { + let w = try_create(4.0, 
DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ])); + + let mut rng = StdRng::seed_from_u64(42); + + for i in 0..100 { + let sample = w.sample(&mut rng); + let l_prob = w.ln_pdf(sample); + assert!(l_prob.is_finite()); + } + } +} \ No newline at end of file From 50bc08bfcaae4ba8ad1df405f5364e712097e62d Mon Sep 17 00:00:00 2001 From: Egor Dmitriev Date: Wed, 6 Jul 2022 23:26:30 +0200 Subject: [PATCH 03/10] feat: Added implementation for Inverse Wishart distribution --- src/distribution/inverse_wishart.rs | 328 ++++++++++++++++++++++++++++ src/distribution/mod.rs | 1 + 2 files changed, 329 insertions(+) create mode 100644 src/distribution/inverse_wishart.rs diff --git a/src/distribution/inverse_wishart.rs b/src/distribution/inverse_wishart.rs new file mode 100644 index 00000000..330751d5 --- /dev/null +++ b/src/distribution/inverse_wishart.rs @@ -0,0 +1,328 @@ +use std::f64::consts::LN_2; +use nalgebra::{Cholesky, DMatrix, Dynamic}; +use rand::Rng; +use crate::{Result, StatsError}; +use crate::distribution::Continuous; +use crate::distribution::wishart::Wishart; +use crate::function::gamma::{mvgamma, mvlgamma}; +use crate::statistics::{MeanN, Mode, VarianceN}; + +/// Implements the [Inverse Wishart distribution](http://en.wikipedia.org/wiki/Inverse-Wishart_distribution) +#[derive(Debug, Clone)] +pub struct InverseWishart { + freedom: f64, + scale: DMatrix, + chol: Cholesky, +} + +impl InverseWishart { + pub fn new(freedom: f64, scale: DMatrix) -> Result { + if scale.nrows() != scale.ncols() { + return Err(StatsError::BadParams); + } + if freedom <= 0.0 || freedom.is_nan() { + return Err(StatsError::ArgMustBePositive("degree of freedom must be positive")); + } + if freedom <= scale.nrows() as f64 - 1.0 { + return Err(StatsError::ArgGt("degree of freedom must be greater than p-1", freedom)); + } + + match Cholesky::new(scale.clone()) { + None => Err(StatsError::BadParams), + Some(chol) => { + 
Ok(InverseWishart { freedom, scale, chol }) + } + } + } + + pub fn p(&self) -> usize { + self.scale.nrows() + } +} + +impl MeanN> for InverseWishart { + fn mean(&self) -> Option> { + Some(&self.scale * (1.0 / (self.freedom - self.p() as f64 - 1.0))) + } +} + +impl VarianceN> for InverseWishart { + fn variance(&self) -> Option> { + let p = self.p() as f64; + + Some(self.scale.map_with_location(|i, j, x| { + let n1 = ((self.freedom - p + 1.0) * x.powi(2)) + ((self.freedom - p - 1.0) * self.scale[(i, i)] * self.scale[(j, j)]); + let n2 = (self.freedom - p) * (self.freedom - p - 1.0) * (self.freedom - p - 1.0) * (self.freedom - p - 3.0); + n1 / n2 + })) + } +} + +impl Mode> for InverseWishart { + fn mode(&self) -> DMatrix { + &self.scale * (1.0 / (self.freedom + self.p() as f64 + 1.0)) + } +} + +impl ::rand::distributions::Distribution> for InverseWishart { + fn sample(&self, rng: &mut R) -> DMatrix { + let w = Wishart::new( + self.freedom, + self.scale.clone().try_inverse().unwrap(), // We already know S is positive definite + ).unwrap(); + + w.sample(rng).pseudo_inverse(1e-4).unwrap() + } +} + +impl Continuous, f64> for InverseWishart { + fn pdf(&self, x: DMatrix) -> f64 { + let p = self.p() as f64; + let chol = Cholesky::new(x).expect("x is not positive definite"); + let x_det = chol.determinant(); + let sxi = chol.solve(&self.scale); + + x_det.powf(-(self.freedom + p + 1.0) / 2.0) + * (-0.5 * sxi.trace()).exp() + * self.chol.determinant().powf(self.freedom / 2.0) + / (2.0f64).powf(self.freedom * p / 2.0) + / mvgamma(p as i64, self.freedom / 2.0) + } + + fn ln_pdf(&self, x: DMatrix) -> f64 { + let p = self.p() as f64; + let chol = Cholesky::new(x).expect("x is not positive definite"); + let x_lndet = chol.determinant().ln(); + let sxi = chol.solve(&self.scale); + + x_lndet * -(self.freedom + p + 1.0) / 2.0 + - 0.5 * sxi.trace() + + self.chol.determinant().ln() * (self.freedom / 2.0) + - LN_2 * (self.freedom * p / 2.0) + - mvlgamma(p as i64, self.freedom / 2.0) + } 
+} + + +#[rustfmt::skip] +#[cfg(test)] +mod tests { + use nalgebra::DMatrix; + use rand::distributions::Distribution; + use rand::rngs::StdRng; + use rand::SeedableRng; + use crate::distribution::Continuous; + use crate::distribution::inverse_wishart::InverseWishart; + use crate::statistics::{MeanN, Mode, VarianceN}; + + fn try_create(freedom: f64, scale: DMatrix) -> InverseWishart + { + let w = InverseWishart::new(freedom, scale); + assert!(w.is_ok()); + w.unwrap() + } + + fn test_almost( + freedom: f64, + scale: DMatrix, + expected: f64, + acc: f64, + eval: F, + ) where + F: FnOnce(InverseWishart) -> f64, + { + let mvn = try_create(freedom, scale); + let x = eval(mvn); + assert_almost_eq!(expected, x, acc); + } + + fn test_almost_mat( + freedom: f64, + scale: DMatrix, + expected: DMatrix, + acc: f64, + eval: F, + ) where + F: FnOnce(InverseWishart) -> DMatrix, + { + let mvn = try_create(freedom, scale); + let x = eval(mvn); + + for i in 0..x.nrows() { + for j in 0..x.ncols() { + assert_almost_eq!(expected[(i, j)], x[(i, j)], acc); + } + } + } + + #[test] + fn test_mean() { + test_almost_mat( + 6.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[ + 0.333333, 0.0, + 0.0, 0.333333 + ]), + 1e-5, |w| w.mean().unwrap(), + ); + test_almost_mat( + 7.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + DMatrix::from_row_slice(3, 3, &[ + 0.3381, 0.0, 0.0, + 0.0, 0.0678, 0.0, + 0.0, 0.0, 0.3165, + ]), + 1e-5, |w| w.mean().unwrap(), + ); + } + + #[test] + fn test_mode() { + test_almost_mat( + 6.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[ + 0.111111, 0.0, + 0.0, 0.111111 + ]), + 1e-5, |w| w.mode(), + ); + test_almost_mat( + 7.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + DMatrix::from_row_slice(3, 3, &[ + 0.0922091, 0.0, 0.0, + 0.0, 0.0184909, 0.0, 
+ 0.0, 0.0, 0.0863182, + ]), + 1e-5, |w| w.mode(), + ); + } + + #[test] + fn test_variance() { + test_almost_mat( + 6.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[ + 0.222222, 0.0833333, + 0.0833333, 0.222222 + ]), + 1e-5, |w| w.variance().unwrap(), + ); + test_almost_mat( + 7.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + DMatrix::from_row_slice(3, 3, &[ + 0.228623, 0.0171924, 0.0802565, + 0.0171924, 0.00919368, 0.016094, + 0.0802565, 0.016094, 0.200344, + ]), + 1e-5, |w| w.variance().unwrap(), + ); + } + + #[test] + fn test_pdf() { + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + 0.02927491576215958, + 1e-15, |w| w.pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + -3.5310242469692907, + 1e-15, |w| w.ln_pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + + test_almost( + 6.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + 0.0012197881567566496, + 1e-15, |w| w.pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + test_almost( + 6.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + -6.709078077317236, + 1e-13, |w| w.ln_pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + 0.04313330476055326, + 1e-12, |w| w.pdf(DMatrix::from_row_slice(3, 3, &[ + 0.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 0.5627, + ])), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + -3.1434598480128795, + 1e-12, |w| w.ln_pdf(DMatrix::from_row_slice(3, 3, &[ + 0.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 0.5627, + ])), + ); + + 
test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + 3.3174174014586637e-7, + 1e-12, |w| w.pdf(DMatrix::from_row_slice(3, 3, &[ + 1.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 9.5627, + ])), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + -14.91890906187113, + 1e-12, |w| w.ln_pdf(DMatrix::from_row_slice(3, 3, &[ + 1.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 9.5627, + ])), + ); + } + + #[test] + fn test_sample() { + let w = try_create(4.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ])); + + let mut rng = StdRng::seed_from_u64(42); + + for _ in 0..100 { + let sample = w.sample(&mut rng); + let l_prob = w.ln_pdf(sample); + assert!(l_prob.is_finite()); + } + } +} \ No newline at end of file diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index be0f6ed2..c97a4159 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -70,6 +70,7 @@ mod weibull; mod ziggurat; mod ziggurat_tables; mod wishart; +mod inverse_wishart; use crate::Result; From 016a2114603e8580c5f988a7eb49e0f7be81745f Mon Sep 17 00:00:00 2001 From: Egor Dmitriev Date: Wed, 6 Jul 2022 23:27:22 +0200 Subject: [PATCH 04/10] refactor: Renamed df and S to freedom and scale, respectively. 
In wishart distribution implementation --- src/distribution/wishart.rs | 147 ++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 72 deletions(-) diff --git a/src/distribution/wishart.rs b/src/distribution/wishart.rs index 7577e65f..43c470cb 100644 --- a/src/distribution/wishart.rs +++ b/src/distribution/wishart.rs @@ -22,33 +22,33 @@ fn mvdigamma(p: i64, a: f64) -> f64 { /// Implements the [Wishart distribution](http://en.wikipedia.org/wiki/Wishart_distribution) #[derive(Debug, Clone)] pub struct Wishart { - df: f64, - S: DMatrix, + freedom: f64, + scale: DMatrix, chol: Cholesky, } impl Wishart { - pub fn new(df: f64, S: DMatrix) -> Result { - if S.nrows() != S.ncols() { + pub fn new(freedom: f64, scale: DMatrix) -> Result { + if scale.nrows() != scale.ncols() { return Err(StatsError::BadParams); } - if df <= 0.0 || df.is_nan() { - return Err(StatsError::ArgMustBePositive("df must be positive")); + if freedom <= 0.0 || freedom.is_nan() { + return Err(StatsError::ArgMustBePositive("degree of freedom must be positive")); } - if df <= S.nrows() as f64 - 1.0 { - return Err(StatsError::ArgGt("df must be greater than p-1", df)); + if freedom <= scale.nrows() as f64 - 1.0 { + return Err(StatsError::ArgGt("degree of freedom must be greater than p-1", freedom)); } - match Cholesky::new(S.clone()) { + match Cholesky::new(scale.clone()) { None => Err(StatsError::BadParams), Some(chol) => { - Ok(Wishart { df, S, chol }) + Ok(Wishart { freedom, scale, chol }) } } } pub fn p(&self) -> usize { - self.S.nrows() + self.scale.nrows() } pub fn entropy(&self) -> Option { @@ -56,50 +56,50 @@ impl Wishart { Some( (p + 1.0) / 2.0 * self.chol.determinant().ln() + 0.5 * p * (p + 1.0) * LN_2 - + mvlgamma(p as i64, self.df / 2.0) - - (self.df - p - 1.0) / 2.0 * mvdigamma(p as i64, self.df / 2.0) - + self.df * p / 2.0 + + mvlgamma(p as i64, self.freedom / 2.0) + - (self.freedom - p - 1.0) / 2.0 * mvdigamma(p as i64, self.freedom / 2.0) + + self.freedom * p / 2.0 ) } } impl 
MeanN> for Wishart { fn mean(&self) -> Option> { - Some(self.df * &self.S) + Some(self.freedom * &self.scale) } } impl VarianceN> for Wishart { fn variance(&self) -> Option> { - Some(self.S.map_with_location(|i, j, x| { - self.df * (self.S[(i, i)] * self.S[(j, j)] + x.powi(2)) + Some(self.scale.map_with_location(|i, j, x| { + self.freedom * (self.scale[(i, i)] * self.scale[(j, j)] + x.powi(2)) })) } } impl Mode> for Wishart { fn mode(&self) -> DMatrix { - (self.df - self.p() as f64 - 1.0) * &self.S + (self.freedom - self.p() as f64 - 1.0) * &self.scale } } impl ::rand::distributions::Distribution> for Wishart { fn sample(&self, rng: &mut R) -> DMatrix { let p = self.p(); - let mut A = DMatrix::zeros(p, p); + let mut a = DMatrix::zeros(p, p); for i in 0..p { - A[(i, i)] = ChiSquared::new(self.df - i as f64).unwrap().sample(rng); + a[(i, i)] = ChiSquared::new(self.freedom - i as f64).unwrap().sample(rng); } for i in 1..p { for j in 0..i { - A[(i, j)] = ziggurat::sample_std_normal(rng); + a[(i, j)] = ziggurat::sample_std_normal(rng); } } - let L = self.chol.l(); - &L * &A * A.transpose() * L.transpose() + let l = self.chol.l(); + &l * &a * a.transpose() * l.transpose() } } @@ -107,25 +107,25 @@ impl Continuous, f64> for Wishart { fn pdf(&self, x: DMatrix) -> f64 { let p = self.p() as f64; let x_det = x.determinant(); - let x_sol = self.chol.solve(&x); + let six = self.chol.solve(&x); - x_det.powf((self.df - p - 1.0) / 2.0) - * (-0.5 * x_sol.trace()).exp() - / (2.0f64).powf(self.df * p / 2.0) - / self.chol.determinant().powf(self.df / 2.0) - / mvgamma(p as i64, self.df / 2.0) + x_det.powf((self.freedom - p - 1.0) / 2.0) + * (-0.5 * six.trace()).exp() + / (2.0f64).powf(self.freedom * p / 2.0) + / self.chol.determinant().powf(self.freedom / 2.0) + / mvgamma(p as i64, self.freedom / 2.0) } fn ln_pdf(&self, x: DMatrix) -> f64 { let p = self.p() as f64; let x_lndet = x.determinant().ln(); - let x_sol = self.chol.solve(&x); + let six = self.chol.solve(&x); - x_lndet * 
(self.df - p - 1.0) / 2.0 - - 0.5 * x_sol.trace() - - LN_2 * (self.df * p / 2.0) - - self.chol.determinant().ln() * (self.df / 2.0) - - mvlgamma(p as i64, self.df / 2.0) + x_lndet * (self.freedom - p - 1.0) / 2.0 + - 0.5 * six.trace() + - LN_2 * (self.freedom * p / 2.0) + - self.chol.determinant().ln() * (self.freedom / 2.0) + - mvlgamma(p as i64, self.freedom / 2.0) } } @@ -138,39 +138,39 @@ mod tests { use rand::SeedableRng; use crate::distribution::Continuous; use crate::distribution::wishart::Wishart; - use crate::statistics::{MeanN, VarianceN}; + use crate::statistics::{MeanN, Mode, VarianceN}; - fn try_create(df: f64, S: DMatrix) -> Wishart + fn try_create(freedom: f64, scale: DMatrix) -> Wishart { - let w = Wishart::new(df, S); + let w = Wishart::new(freedom, scale); assert!(w.is_ok()); w.unwrap() } fn test_almost( - df: f64, - S: DMatrix, + freedom: f64, + scale: DMatrix, expected: f64, acc: f64, eval: F, ) where F: FnOnce(Wishart) -> f64, { - let mvn = try_create(df, S); + let mvn = try_create(freedom, scale); let x = eval(mvn); assert_almost_eq!(expected, x, acc); } fn test_almost_mat( - df: f64, - S: DMatrix, + freedom: f64, + scale: DMatrix, expected: DMatrix, acc: f64, eval: F, ) where F: FnOnce(Wishart) -> DMatrix, { - let mvn = try_create(df, S); + let mvn = try_create(freedom, scale); let x = eval(mvn); for i in 0..x.nrows() { @@ -180,6 +180,24 @@ mod tests { } } + #[test] + fn test_mean() { + test_almost_mat( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[2.0, 0.0, 0.0, 2.0]), + 1e-15, |w| w.mean().unwrap(), + ); + } + + #[test] + fn test_mode() { + test_almost_mat( + 7.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[4.0, 0.0, 0.0, 4.0]), + 1e-15, |w| w.mode(), + ); + } + #[test] fn test_variance() { test_almost_mat( @@ -227,15 +245,6 @@ mod tests { ); } - #[test] - fn test_mode() { - test_almost_mat( - 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 
1.0]), - DMatrix::from_row_slice(2, 2, &[2.0, 0.0, 0.0, 2.0]), - 1e-15, |w| w.mean().unwrap(), - ); - } - #[test] fn test_pdf() { test_almost( @@ -245,9 +254,10 @@ mod tests { ); test_almost( 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), - 0.02927491576215958, - 1e-15, |w| w.pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + -3.5310242469692907, + 1e-15, |w| w.ln_pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), ); + test_almost( 3.0, DMatrix::from_row_slice(3, 3, &[ 1.0143, 0.0000, 0.0000, @@ -267,33 +277,26 @@ mod tests { 0.0000, 0.2034, 0.0000, 0.0000, 0.0000, 0.9495 ]), - 2.3729077174800438e-5, - 1e-12, |w| w.pdf(DMatrix::from_row_slice(3, 3, &[ - 1.7121, 0.0000, 0.0000, + -3.561453889551996, + 1e-12, |w| w.ln_pdf(DMatrix::from_row_slice(3, 3, &[ + 0.7121, 0.0000, 0.0000, 0.0000, 0.4010, 0.0000, - 0.0000, 0.0000, 9.5627, + 0.0000, 0.0000, 0.5627, ])), ); - } - #[test] - fn test_ln_pdf() { - test_almost( - 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), - -3.5310242469692907, - 1e-15, |w| w.ln_pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), - ); + test_almost( 3.0, DMatrix::from_row_slice(3, 3, &[ 1.0143, 0.0000, 0.0000, 0.0000, 0.2034, 0.0000, 0.0000, 0.0000, 0.9495 ]), - -3.561453889551996, - 1e-12, |w| w.ln_pdf(DMatrix::from_row_slice(3, 3, &[ - 0.7121, 0.0000, 0.0000, + 2.3729077174800438e-5, + 1e-12, |w| w.pdf(DMatrix::from_row_slice(3, 3, &[ + 1.7121, 0.0000, 0.0000, 0.0000, 0.4010, 0.0000, - 0.0000, 0.0000, 0.5627, + 0.0000, 0.0000, 9.5627, ])), ); test_almost( @@ -321,7 +324,7 @@ mod tests { let mut rng = StdRng::seed_from_u64(42); - for i in 0..100 { + for _ in 0..100 { let sample = w.sample(&mut rng); let l_prob = w.ln_pdf(sample); assert!(l_prob.is_finite()); From 5f6468d66bef846417b9063b831b71822d7da86f Mon Sep 17 00:00:00 2001 From: Egor Dmitriev Date: Wed, 6 Jul 2022 23:32:10 +0200 Subject: [PATCH 05/10] change: Added freedom and scale accessors for wishart and inverse wishart distribution 
implementation --- src/distribution/inverse_wishart.rs | 8 ++++++++ src/distribution/wishart.rs | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/src/distribution/inverse_wishart.rs b/src/distribution/inverse_wishart.rs index 330751d5..1d32a84f 100644 --- a/src/distribution/inverse_wishart.rs +++ b/src/distribution/inverse_wishart.rs @@ -35,6 +35,14 @@ impl InverseWishart { } } + pub fn freedom(&self) -> f64 { + self.freedom + } + + pub fn scale(&self) -> &DMatrix { + &self.scale + } + pub fn p(&self) -> usize { self.scale.nrows() } diff --git a/src/distribution/wishart.rs b/src/distribution/wishart.rs index 43c470cb..beff2111 100644 --- a/src/distribution/wishart.rs +++ b/src/distribution/wishart.rs @@ -47,6 +47,14 @@ impl Wishart { } } + pub fn freedom(&self) -> f64 { + self.freedom + } + + pub fn scale(&self) -> &DMatrix { + &self.scale + } + pub fn p(&self) -> usize { self.scale.nrows() } From b1e91924309828b1aa095628501e811853b7a11c Mon Sep 17 00:00:00 2001 From: Egor Dmitriev Date: Thu, 7 Jul 2022 00:20:17 +0200 Subject: [PATCH 06/10] docs: Added docstrings to Wishart and Inverse Wishart distributions. 
--- src/distribution/inverse_wishart.rs | 114 ++++++++++++++++++++++++-- src/distribution/mod.rs | 2 + src/distribution/wishart.rs | 119 +++++++++++++++++++++++++++- 3 files changed, 225 insertions(+), 10 deletions(-) diff --git a/src/distribution/inverse_wishart.rs b/src/distribution/inverse_wishart.rs index 1d32a84f..96aefea2 100644 --- a/src/distribution/inverse_wishart.rs +++ b/src/distribution/inverse_wishart.rs @@ -8,6 +8,20 @@ use crate::function::gamma::{mvgamma, mvlgamma}; use crate::statistics::{MeanN, Mode, VarianceN}; /// Implements the [Inverse Wishart distribution](http://en.wikipedia.org/wiki/Inverse-Wishart_distribution) +/// +/// # Example +/// ``` +/// use nalgebra::DMatrix; +/// use statrs::distribution::{InverseWishart, Continuous}; +/// use statrs::statistics::Distribution; +/// use statrs::prec; +/// +/// let result = InverseWishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); +/// assert!(result.is_ok()); +/// +/// let result = InverseWishart::new(1.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); +/// assert!(result.is_err()); +/// ``` #[derive(Debug, Clone)] pub struct InverseWishart { freedom: f64, @@ -16,6 +30,28 @@ pub struct InverseWishart { } impl InverseWishart { + /// Constructs a new Inverse Wishart distribution with a degrees of freedom (ν) of `freedom` + /// and a scale matrix (ψ) of `scale` + /// + /// # Errors + /// + /// Returns an error if `scale` matrix is not square. + /// Returns an error if `freedom` is `NaN`. + /// Returns an error if `freedom <= rows(scale) - 1` + /// Returns an error if `scale` is not positive definite. 
+ /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::InverseWishart; + /// + /// let result = InverseWishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// assert!(result.is_ok()); + /// + /// let result = InverseWishart::new(1.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// assert!(result.is_err()); + /// ``` pub fn new(freedom: f64, scale: DMatrix) -> Result { if scale.nrows() != scale.ncols() { return Err(StatsError::BadParams); @@ -35,26 +71,75 @@ impl InverseWishart { } } + /// Returns the degrees of freedom of + /// the Inverse Wishart distribution. + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::InverseWishart; + /// + /// let w = InverseWishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(n.freedom(), 2.0); + /// ``` pub fn freedom(&self) -> f64 { self.freedom } + /// Returns the scale of the Inverse Wishart distribution + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::InverseWishart; + /// + /// let w = InverseWishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(n.scale(), DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// ``` pub fn scale(&self) -> &DMatrix { &self.scale } + /// Returns the dimensionality of the Inverse Wishart distribution + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::InverseWishart; + /// + /// let w = InverseWishart::new(3.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(n.p(), 2); + /// ``` pub fn p(&self) -> usize { self.scale.nrows() } } impl MeanN> for InverseWishart { + /// Returns the mean of the Inverse Wishart distribution + /// + /// # Formula + /// + /// ```ignore + /// ψ/(ν - p - 1) + /// ``` + /// + /// where `ν` is the degree of freedom, `ψ` is the 
scale matrix, and `p` is the dimensionality of the distribution. fn mean(&self) -> Option> { - Some(&self.scale * (1.0 / (self.freedom - self.p() as f64 - 1.0))) + if self.freedom > self.p() as f64 + 1.0 { + Some(&self.scale / (self.freedom - self.p() as f64 - 1.0)) + } else { + None + } } } impl VarianceN> for InverseWishart { + /// Returns the variance of the Inverse Wishart distribution + /// See [formula](https://en.wikipedia.org/wiki/Inverse-Wishart_distribution#Moments) fn variance(&self) -> Option> { let p = self.p() as f64; @@ -66,9 +151,22 @@ impl VarianceN> for InverseWishart { } } -impl Mode> for InverseWishart { - fn mode(&self) -> DMatrix { - &self.scale * (1.0 / (self.freedom + self.p() as f64 + 1.0)) +impl Mode>> for InverseWishart { + /// Returns the mode of the Inverse Wishart distribution + /// + /// # Formula + /// + /// ```ignore + /// ψ/(ν + p + 1) + /// ``` + /// + /// where `ν` is the degree of freedom, `ψ` is the scale matrix, and `p` is the dimensionality of the distribution. 
+ fn mode(&self) -> Option> { + if self.freedom > self.p() as f64 + 1.0 { + Some(&self.scale * (1.0 / (self.freedom + self.p() as f64 + 1.0))) + } else { + None + } } } @@ -84,6 +182,8 @@ impl ::rand::distributions::Distribution> for InverseWishart { } impl Continuous, f64> for InverseWishart { + /// Calculates the probability density function for the Inverse Wishart + /// distribution at `x` fn pdf(&self, x: DMatrix) -> f64 { let p = self.p() as f64; let chol = Cholesky::new(x).expect("x is not positive definite"); @@ -97,6 +197,8 @@ impl Continuous, f64> for InverseWishart { / mvgamma(p as i64, self.freedom / 2.0) } + /// Calculates the log probability density function for the Inverse Wishart + /// distribution at `x` fn ln_pdf(&self, x: DMatrix) -> f64 { let p = self.p() as f64; let chol = Cholesky::new(x).expect("x is not positive definite"); @@ -196,7 +298,7 @@ mod tests { 0.111111, 0.0, 0.0, 0.111111 ]), - 1e-5, |w| w.mode(), + 1e-5, |w| w.mode().unwrap(), ); test_almost_mat( 7.0, DMatrix::from_row_slice(3, 3, &[ @@ -209,7 +311,7 @@ mod tests { 0.0, 0.0184909, 0.0, 0.0, 0.0, 0.0863182, ]), - 1e-5, |w| w.mode(), + 1e-5, |w| w.mode().unwrap(), ); } diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index c97a4159..91008f58 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -34,6 +34,8 @@ pub use self::students_t::StudentsT; pub use self::triangular::Triangular; pub use self::uniform::Uniform; pub use self::weibull::Weibull; +pub use self::wishart::Wishart; +pub use self::inverse_wishart::InverseWishart; mod bernoulli; mod beta; diff --git a/src/distribution/wishart.rs b/src/distribution/wishart.rs index beff2111..681e6abe 100644 --- a/src/distribution/wishart.rs +++ b/src/distribution/wishart.rs @@ -20,6 +20,20 @@ fn mvdigamma(p: i64, a: f64) -> f64 { /// Implements the [Wishart distribution](http://en.wikipedia.org/wiki/Wishart_distribution) +/// +/// # Example +/// ``` +/// use nalgebra::DMatrix; +/// use 
statrs::distribution::{Wishart, Continuous}; +/// use statrs::statistics::Distribution; +/// use statrs::prec; +/// +/// let result = Wishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); +/// assert!(result.is_ok()); +/// +/// let result = Wishart::new(1.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); +/// assert!(result.is_err()); +/// ``` #[derive(Debug, Clone)] pub struct Wishart { freedom: f64, @@ -28,6 +42,28 @@ pub struct Wishart { } impl Wishart { + /// Constructs a new Wishart distribution with a degrees of freedom (ν) of `freedom` + /// and a scale matrix (ψ) of `scale` + /// + /// # Errors + /// + /// Returns an error if `scale` matrix is not square. + /// Returns an error if `freedom` is `NaN`. + /// Returns an error if `freedom <= rows(scale) - 1` + /// Returns an error if `scale` is not positive definite. + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::Wishart; + /// + /// let result = Wishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// assert!(result.is_ok()); + /// + /// let result = Wishart::new(1.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// assert!(result.is_err()); + /// ``` pub fn new(freedom: f64, scale: DMatrix) -> Result { if scale.nrows() != scale.ncols() { return Err(StatsError::BadParams); @@ -47,18 +83,54 @@ impl Wishart { } } + /// Returns the degrees of freedom of + /// the Wishart distribution. 
+ /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::Wishart; + /// + /// let w = Wishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(n.freedom(), 2.0); + /// ``` pub fn freedom(&self) -> f64 { self.freedom } + /// Returns the scale of the Wishart distribution + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::Wishart; + /// + /// let w = Wishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(n.scale(), DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// ``` pub fn scale(&self) -> &DMatrix { &self.scale } + /// Returns the dimensionality of the Wishart distribution + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::Wishart; + /// + /// let w = Wishart::new(3.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(n.p(), 2); + /// ``` pub fn p(&self) -> usize { self.scale.nrows() } + /// Returns the entropy of the Wishart distribution. 
+ /// See [formula](https://en.wikipedia.org/wiki/Wishart_distribution#Entropy) pub fn entropy(&self) -> Option { let p = self.p() as f64; Some( @@ -72,12 +144,30 @@ impl Wishart { } impl MeanN> for Wishart { + /// Returns the mean of the Wishart distribution + /// + /// # Formula + /// + /// ```ignore + /// νψ + /// ``` + /// + /// where `ν` is the degree of freedom, `ψ` is the scale matrix fn mean(&self) -> Option> { Some(self.freedom * &self.scale) } } impl VarianceN> for Wishart { + /// Returns the variance of the Wishart distribution + /// + /// # Formula + /// + /// ```ignore + /// Var(x_ij) = ν(ψ_ij^2 + ψ_ii ψ_jj) + /// ``` + /// + /// where `ν` is the degree of freedom, `ψ` is the scale matrix fn variance(&self) -> Option> { Some(self.scale.map_with_location(|i, j, x| { self.freedom * (self.scale[(i, i)] * self.scale[(j, j)] + x.powi(2)) @@ -85,9 +175,26 @@ impl VarianceN> for Wishart { } } -impl Mode> for Wishart { - fn mode(&self) -> DMatrix { - (self.freedom - self.p() as f64 - 1.0) * &self.scale +impl Mode>> for Wishart { + /// Returns the mode of the Wishart distribution + /// + /// # Formula + /// + /// ```ignore + /// if ν >= p + 1 { + /// (ν - p - 1)ψ + /// } else { + /// None + /// } + /// ``` + /// + /// where `ν` is the degree of freedom, `ψ` is the scale matrix, and `p` is the dimensionality of the distribution + fn mode(&self) -> Option> { + if self.freedom >= self.p() as f64 + 1.0 { + Some((self.freedom - self.p() as f64 - 1.0) * &self.scale) + } else { + None + } } } @@ -112,6 +219,8 @@ impl ::rand::distributions::Distribution> for Wishart { } impl Continuous, f64> for Wishart { + /// Calculates the probability density function for the Wishart + /// distribution at `x` fn pdf(&self, x: DMatrix) -> f64 { let p = self.p() as f64; let x_det = x.determinant(); @@ -124,6 +233,8 @@ impl Continuous, f64> for Wishart { / mvgamma(p as i64, self.freedom / 2.0) } + /// Calculates the log probability density function for the Wishart + /// distribution at `x` fn ln_pdf(&self, x: DMatrix) -> f64 { let 
p = self.p() as f64; let x_lndet = x.determinant().ln(); @@ -202,7 +313,7 @@ mod tests { test_almost_mat( 7.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), DMatrix::from_row_slice(2, 2, &[4.0, 0.0, 0.0, 4.0]), - 1e-15, |w| w.mode(), + 1e-15, |w| w.mode().unwrap(), ); } From fa31a22aec42684b98d2151defc2894bc0d6deb3 Mon Sep 17 00:00:00 2001 From: Egor Dmitriev Date: Fri, 8 Jul 2022 00:16:09 +0200 Subject: [PATCH 07/10] fix: Updated inverse wishart sampling to make the matrix symmetric as pseudo inverse doesn't do this due to approximation --- src/distribution/inverse_wishart.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distribution/inverse_wishart.rs b/src/distribution/inverse_wishart.rs index 96aefea2..a63b7c3b 100644 --- a/src/distribution/inverse_wishart.rs +++ b/src/distribution/inverse_wishart.rs @@ -177,7 +177,7 @@ impl ::rand::distributions::Distribution> for InverseWishart { self.scale.clone().try_inverse().unwrap(), // We already know S is positive definite ).unwrap(); - w.sample(rng).pseudo_inverse(1e-4).unwrap() + w.sample(rng).pseudo_inverse(1e-4).unwrap().symmetric_part() } } From db1aad11a94d3142646d62863d8d899af3685300 Mon Sep 17 00:00:00 2001 From: Egor Dmitriev Date: Sat, 17 Sep 2022 15:43:40 +0200 Subject: [PATCH 08/10] fix: Fixed use of chi^2 instead of chi in wishart sampling --- src/distribution/wishart.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/distribution/wishart.rs b/src/distribution/wishart.rs index 681e6abe..bc0cce41 100644 --- a/src/distribution/wishart.rs +++ b/src/distribution/wishart.rs @@ -204,7 +204,7 @@ impl ::rand::distributions::Distribution> for Wishart { let mut a = DMatrix::zeros(p, p); for i in 0..p { - a[(i, i)] = ChiSquared::new(self.freedom - i as f64).unwrap().sample(rng); + a[(i, i)] = ChiSquared::new(self.freedom - i as f64).unwrap().sample(rng).sqrt(); } for i in 1..p { @@ -213,8 +213,8 @@ impl ::rand::distributions::Distribution> for Wishart { } } - 
let l = self.chol.l(); - &l * &a * a.transpose() * l.transpose() + let l = self.chol.l() * &a; + &l * &l.transpose() } } From 85e98987f4d3a7978692f59942288577a4dfb3ce Mon Sep 17 00:00:00 2001 From: Egor Dmitriev Date: Sat, 17 Sep 2022 16:29:56 +0200 Subject: [PATCH 09/10] fix: Fixed wishart distribution examples --- src/distribution/inverse_wishart.rs | 8 ++++---- src/distribution/wishart.rs | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/distribution/inverse_wishart.rs b/src/distribution/inverse_wishart.rs index a63b7c3b..684435b4 100644 --- a/src/distribution/inverse_wishart.rs +++ b/src/distribution/inverse_wishart.rs @@ -81,7 +81,7 @@ impl InverseWishart { /// use statrs::distribution::InverseWishart; /// /// let w = InverseWishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); - /// assert_eq!(n.freedom(), 2.0); + /// assert_eq!(w.freedom(), 2.0); /// ``` pub fn freedom(&self) -> f64 { self.freedom @@ -96,7 +96,7 @@ impl InverseWishart { /// use statrs::distribution::InverseWishart; /// /// let w = InverseWishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); - /// assert_eq!(n.scale(), DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// assert_eq!(w.scale(), &DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); /// ``` pub fn scale(&self) -> &DMatrix { &self.scale @@ -111,7 +111,7 @@ impl InverseWishart { /// use statrs::distribution::InverseWishart; /// /// let w = InverseWishart::new(3.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); - /// assert_eq!(n.p(), 2); + /// assert_eq!(w.p(), 2); /// ``` pub fn p(&self) -> usize { self.scale.nrows() @@ -435,4 +435,4 @@ mod tests { assert!(l_prob.is_finite()); } } -} \ No newline at end of file +} diff --git a/src/distribution/wishart.rs b/src/distribution/wishart.rs index bc0cce41..ab2615cf 100644 --- a/src/distribution/wishart.rs +++ b/src/distribution/wishart.rs @@ -93,7 +93,7 @@ impl Wishart { /// 
use statrs::distribution::Wishart; /// /// let w = Wishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); - /// assert_eq!(n.freedom(), 2.0); + /// assert_eq!(w.freedom(), 2.0); /// ``` pub fn freedom(&self) -> f64 { self.freedom @@ -108,7 +108,7 @@ impl Wishart { /// use statrs::distribution::Wishart; /// /// let w = Wishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); - /// assert_eq!(n.scale(), DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// assert_eq!(w.scale(), &DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); /// ``` pub fn scale(&self) -> &DMatrix { &self.scale @@ -123,7 +123,7 @@ impl Wishart { /// use statrs::distribution::Wishart; /// /// let w = Wishart::new(3.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); - /// assert_eq!(n.p(), 2); + /// assert_eq!(w.p(), 2); /// ``` pub fn p(&self) -> usize { self.scale.nrows() @@ -449,4 +449,4 @@ mod tests { assert!(l_prob.is_finite()); } } -} \ No newline at end of file +} From fbed08955c6c6ef4f762b18f4bf8131d55b50d3f Mon Sep 17 00:00:00 2001 From: Egor Dmitriev Date: Mon, 19 Sep 2022 12:15:39 +0200 Subject: [PATCH 10/10] fix: Updated inverse wishart to compute chol inverse instead of an approximation --- src/distribution/inverse_wishart.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/distribution/inverse_wishart.rs b/src/distribution/inverse_wishart.rs index 684435b4..68da83c9 100644 --- a/src/distribution/inverse_wishart.rs +++ b/src/distribution/inverse_wishart.rs @@ -176,8 +176,9 @@ impl ::rand::distributions::Distribution> for InverseWishart { self.freedom, self.scale.clone().try_inverse().unwrap(), // We already know S is positive definite ).unwrap(); + let s = w.sample(rng); - w.sample(rng).pseudo_inverse(1e-4).unwrap().symmetric_part() + s.cholesky().unwrap().inverse().symmetric_part() } }