diff --git a/src/distribution/inverse_wishart.rs b/src/distribution/inverse_wishart.rs new file mode 100644 index 00000000..68da83c9 --- /dev/null +++ b/src/distribution/inverse_wishart.rs @@ -0,0 +1,439 @@ +use std::f64::consts::LN_2; +use nalgebra::{Cholesky, DMatrix, Dynamic}; +use rand::Rng; +use crate::{Result, StatsError}; +use crate::distribution::Continuous; +use crate::distribution::wishart::Wishart; +use crate::function::gamma::{mvgamma, mvlgamma}; +use crate::statistics::{MeanN, Mode, VarianceN}; + +/// Implements the [Inverse Wishart distribution](http://en.wikipedia.org/wiki/Inverse-Wishart_distribution) +/// +/// # Example +/// ``` +/// use nalgebra::DMatrix; +/// use statrs::distribution::{InverseWishart, Continuous}; +/// use statrs::statistics::Distribution; +/// use statrs::prec; +/// +/// let result = InverseWishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); +/// assert!(result.is_ok()); +/// +/// let result = InverseWishart::new(1.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); +/// assert!(result.is_err()); +/// ``` +#[derive(Debug, Clone)] +pub struct InverseWishart { + freedom: f64, + scale: DMatrix, + chol: Cholesky, +} + +impl InverseWishart { + /// Constructs a new Inverse Wishart distribution with a degrees of freedom (ν) of `freedom` + /// and a scale matrix (ψ) of `scale` + /// + /// # Errors + /// + /// Returns an error if `scale` matrix is not square. + /// Returns an error if `freedom` is `NaN`. + /// Returns an error if `freedom <= rows(scale) - 1` + /// Returns an error if `scale` is not positive definite. + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::InverseWishart; + /// + /// let result = InverseWishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// assert!(result.is_ok()); + /// + /// let result = InverseWishart::new(1.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// assert!(result.is_err()); + /// ``` + pub fn new(freedom: f64, scale: DMatrix) -> Result { + if scale.nrows() != scale.ncols() { + return Err(StatsError::BadParams); + } + if freedom <= 0.0 || freedom.is_nan() { + return Err(StatsError::ArgMustBePositive("degree of freedom must be positive")); + } + if freedom <= scale.nrows() as f64 - 1.0 { + return Err(StatsError::ArgGt("degree of freedom must be greater than p-1", freedom)); + } + + match Cholesky::new(scale.clone()) { + None => Err(StatsError::BadParams), + Some(chol) => { + Ok(InverseWishart { freedom, scale, chol }) + } + } + } + + /// Returns the degrees of freedom of + /// the Inverse Wishart distribution. + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::InverseWishart; + /// + /// let w = InverseWishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(w.freedom(), 2.0); + /// ``` + pub fn freedom(&self) -> f64 { + self.freedom + } + + /// Returns the scale of the Inverse Wishart distribution + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::InverseWishart; + /// + /// let w = InverseWishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(w.scale(), &DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// ``` + pub fn scale(&self) -> &DMatrix { + &self.scale + } + + /// Returns the dimensionality of the Inverse Wishart distribution + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::InverseWishart; + /// + /// let w = InverseWishart::new(3.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(w.p(), 2); + /// ``` + pub fn p(&self) -> usize { + self.scale.nrows() + } +} + +impl MeanN> for InverseWishart { + /// Returns the mean of the Inverse Wishart distribution + /// + /// # Formula + /// + /// ```ignore + /// ψ/(ν - p - 1) + /// ``` + /// + /// where `ν` is the degree of freedom, `ψ` is the scale matrix, and `p` is the dimensionality of the distribution. + fn mean(&self) -> Option> { + if self.freedom > self.p() as f64 + 1.0 { + Some(&self.scale / (self.freedom - self.p() as f64 - 1.0)) + } else { + None + } + } +} + +impl VarianceN> for InverseWishart { + /// Returns the variance of the Inverse Wishart distribution + /// See [formula](https://en.wikipedia.org/wiki/Inverse-Wishart_distribution#Moments) + fn variance(&self) -> Option> { + let p = self.p() as f64; + + Some(self.scale.map_with_location(|i, j, x| { + let n1 = ((self.freedom - p + 1.0) * x.powi(2)) + ((self.freedom - p - 1.0) * self.scale[(i, i)] * self.scale[(j, j)]); + let n2 = (self.freedom - p) * (self.freedom - p - 1.0) * (self.freedom - p - 1.0) * (self.freedom - p - 3.0); + n1 / n2 + })) + } +} + +impl Mode>> for InverseWishart { + /// Returns the median of the Inverse Wishart distribution + /// + /// # Formula + /// + /// ```ignore + /// ψ/(ν + p + 1) + /// ``` + /// + /// where `ν` is the degree of freedom, `ψ` is the scale matrix, and `p` is the dimensionality of the distribution. + fn mode(&self) -> Option> { + if self.freedom > self.p() as f64 + 1.0 { + Some(&self.scale * (1.0 / (self.freedom + self.p() as f64 + 1.0))) + } else { + None + } + } +} + +impl ::rand::distributions::Distribution> for InverseWishart { + fn sample(&self, rng: &mut R) -> DMatrix { + let w = Wishart::new( + self.freedom, + self.scale.clone().try_inverse().unwrap(), // We already know S is positive definite + ).unwrap(); + let s = w.sample(rng); + + s.cholesky().unwrap().inverse().symmetric_part() + } +} + +impl Continuous, f64> for InverseWishart { + /// Calculates the probability density function for the Inverse Wishart + /// distribution at `x` + fn pdf(&self, x: DMatrix) -> f64 { + let p = self.p() as f64; + let chol = Cholesky::new(x).expect("x is not positive definite"); + let x_det = chol.determinant(); + let sxi = chol.solve(&self.scale); + + x_det.powf(-(self.freedom + p + 1.0) / 2.0) + * (-0.5 * sxi.trace()).exp() + * self.chol.determinant().powf(self.freedom / 2.0) + / (2.0f64).powf(self.freedom * p / 2.0) + / mvgamma(p as i64, self.freedom / 2.0) + } + + /// Calculates the log probability density function for the Inverse Wishart + /// distribution at `x` + fn ln_pdf(&self, x: DMatrix) -> f64 { + let p = self.p() as f64; + let chol = Cholesky::new(x).expect("x is not positive definite"); + let x_lndet = chol.determinant().ln(); + let sxi = chol.solve(&self.scale); + + x_lndet * -(self.freedom + p + 1.0) / 2.0 + - 0.5 * sxi.trace() + + self.chol.determinant().ln() * (self.freedom / 2.0) + - LN_2 * (self.freedom * p / 2.0) + - mvlgamma(p as i64, self.freedom / 2.0) + } +} + + +#[rustfmt::skip] +#[cfg(test)] +mod tests { + use nalgebra::DMatrix; + use rand::distributions::Distribution; + use rand::rngs::StdRng; + use rand::SeedableRng; + use crate::distribution::Continuous; + use crate::distribution::inverse_wishart::InverseWishart; + use crate::statistics::{MeanN, Mode, VarianceN}; + + fn try_create(freedom: f64, scale: DMatrix) -> InverseWishart + { + let w = InverseWishart::new(freedom, scale); + assert!(w.is_ok()); + w.unwrap() + } + + fn test_almost( + freedom: f64, + scale: DMatrix, + expected: f64, + acc: f64, + eval: F, + ) where + F: FnOnce(InverseWishart) -> f64, + { + let mvn = try_create(freedom, scale); + let x = eval(mvn); + assert_almost_eq!(expected, x, acc); + } + + fn test_almost_mat( + freedom: f64, + scale: DMatrix, + expected: DMatrix, + acc: f64, + eval: F, + ) where + F: FnOnce(InverseWishart) -> DMatrix, + { + let mvn = try_create(freedom, scale); + let x = eval(mvn); + + for i in 0..x.nrows() { + for j in 0..x.ncols() { + assert_almost_eq!(expected[(i, j)], x[(i, j)], acc); + } + } + } + + #[test] + fn test_mean() { + test_almost_mat( + 6.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[ + 0.333333, 0.0, + 0.0, 0.333333 + ]), + 1e-5, |w| w.mean().unwrap(), + ); + test_almost_mat( + 7.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + DMatrix::from_row_slice(3, 3, &[ + 0.3381, 0.0, 0.0, + 0.0, 0.0678, 0.0, + 0.0, 0.0, 0.3165, + ]), + 1e-5, |w| w.mean().unwrap(), + ); + } + + #[test] + fn test_mode() { + test_almost_mat( + 6.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[ + 0.111111, 0.0, + 0.0, 0.111111 + ]), + 1e-5, |w| w.mode().unwrap(), + ); + test_almost_mat( + 7.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + DMatrix::from_row_slice(3, 3, &[ + 0.0922091, 0.0, 0.0, + 0.0, 0.0184909, 0.0, + 0.0, 0.0, 0.0863182, + ]), + 1e-5, |w| w.mode().unwrap(), + ); + } + + #[test] + fn test_variance() { + test_almost_mat( + 6.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[ + 0.222222, 0.0833333, + 0.0833333, 0.222222 + ]), + 1e-5, |w| w.variance().unwrap(), + ); + test_almost_mat( + 7.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + DMatrix::from_row_slice(3, 3, &[ + 0.228623, 0.0171924, 0.0802565, + 0.0171924, 0.00919368, 0.016094, + 0.0802565, 0.016094, 0.200344, + ]), + 1e-5, |w| w.variance().unwrap(), + ); + } + + #[test] + fn test_pdf() { + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + 0.02927491576215958, + 1e-15, |w| w.pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + -3.5310242469692907, + 1e-15, |w| w.ln_pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + + test_almost( + 6.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + 0.0012197881567566496, + 1e-15, |w| w.pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + test_almost( + 6.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + -6.709078077317236, + 1e-13, |w| w.ln_pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + 0.04313330476055326, + 1e-12, |w| w.pdf(DMatrix::from_row_slice(3, 3, &[ + 0.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 0.5627, + ])), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + -3.1434598480128795, + 1e-12, |w| w.ln_pdf(DMatrix::from_row_slice(3, 3, &[ + 0.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 0.5627, + ])), + ); + + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + 3.3174174014586637e-7, + 1e-12, |w| w.pdf(DMatrix::from_row_slice(3, 3, &[ + 1.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 9.5627, + ])), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + -14.91890906187113, + 1e-12, |w| w.ln_pdf(DMatrix::from_row_slice(3, 3, &[ + 1.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 9.5627, + ])), + ); + } + + #[test] + fn test_sample() { + let w = try_create(4.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ])); + + let mut rng = StdRng::seed_from_u64(42); + + for _ in 0..100 { + let sample = w.sample(&mut rng); + let l_prob = w.ln_pdf(sample); + assert!(l_prob.is_finite()); + } + } +} diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 23ae7ec5..fc1e9c6a 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -34,6 +34,8 @@ pub use self::students_t::StudentsT; pub use self::triangular::Triangular; pub use self::uniform::Uniform; pub use self::weibull::Weibull; +pub use self::wishart::Wishart; +pub use self::inverse_wishart::InverseWishart; mod bernoulli; mod beta; @@ -69,6 +71,8 @@ mod uniform; mod weibull; mod ziggurat; mod ziggurat_tables; +mod wishart; +mod inverse_wishart; use crate::Result; diff --git a/src/distribution/wishart.rs b/src/distribution/wishart.rs new file mode 100644 index 00000000..ab2615cf --- /dev/null +++ b/src/distribution/wishart.rs @@ -0,0 +1,452 @@ +use std::f64::consts::LN_2; +use nalgebra::{Cholesky, DMatrix, DVector, Dynamic}; +use num_traits::Float; +use num_traits::real::Real; +use rand::Rng; +use crate::{Result, StatsError}; +use crate::consts::LN_PI; +use crate::distribution::{ChiSquared, Continuous, Normal, ziggurat}; +use crate::function::gamma::{digamma, mvgamma}; +use crate::statistics::{MeanN, Mode, VarianceN}; +use crate::function::gamma::mvlgamma; + +fn mvdigamma(p: i64, a: f64) -> f64 { + let mut sum = 0.0; + for i in 0..p { + sum += digamma(a - (i as f64) / 2.0); + } + sum +} + + +/// Implements the [Wishart distribution](http://en.wikipedia.org/wiki/Wishart_distribution) +/// +/// # Example +/// ``` +/// use nalgebra::DMatrix; +/// use statrs::distribution::{Wishart, Continuous}; +/// use statrs::statistics::Distribution; +/// use statrs::prec; +/// +/// let result = Wishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); +/// assert!(result.is_ok()); +/// +/// let result = Wishart::new(1.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); +/// assert!(result.is_err()); +/// ``` +#[derive(Debug, Clone)] +pub struct Wishart { + freedom: f64, + scale: DMatrix, + chol: Cholesky, +} + +impl Wishart { + /// Constructs a new Wishart distribution with a degrees of freedom (ν) of `freedom` + /// and a scale matrix (ψ) of `scale` + /// + /// # Errors + /// + /// Returns an error if `scale` matrix is not square. + /// Returns an error if `freedom` is `NaN`. + /// Returns an error if `freedom <= rows(scale) - 1` + /// Returns an error if `scale` is not positive definite. + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::Wishart; + /// + /// let result = Wishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// assert!(result.is_ok()); + /// + /// let result = Wishart::new(1.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// assert!(result.is_err()); + /// ``` + pub fn new(freedom: f64, scale: DMatrix) -> Result { + if scale.nrows() != scale.ncols() { + return Err(StatsError::BadParams); + } + if freedom <= 0.0 || freedom.is_nan() { + return Err(StatsError::ArgMustBePositive("degree of freedom must be positive")); + } + if freedom <= scale.nrows() as f64 - 1.0 { + return Err(StatsError::ArgGt("degree of freedom must be greater than p-1", freedom)); + } + + match Cholesky::new(scale.clone()) { + None => Err(StatsError::BadParams), + Some(chol) => { + Ok(Wishart { freedom, scale, chol }) + } + } + } + + /// Returns the degrees of freedom of + /// the Wishart distribution. + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::Wishart; + /// + /// let w = Wishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(w.freedom(), 2.0); + /// ``` + pub fn freedom(&self) -> f64 { + self.freedom + } + + /// Returns the scale of the Wishart distribution + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::Wishart; + /// + /// let w = Wishart::new(2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(w.scale(), &DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])); + /// ``` + pub fn scale(&self) -> &DMatrix { + &self.scale + } + + /// Returns the dimensionality of the Wishart distribution + /// + /// # Examples + /// + /// ``` + /// use nalgebra::DMatrix; + /// use statrs::distribution::Wishart; + /// + /// let w = Wishart::new(3.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])).unwrap(); + /// assert_eq!(w.p(), 2); + /// ``` + pub fn p(&self) -> usize { + self.scale.nrows() + } + + /// Returns the entropy of the Wishart distribution. + /// See [formula](https://en.wikipedia.org/wiki/Wishart_distribution#Entropy) + pub fn entropy(&self) -> Option { + let p = self.p() as f64; + Some( + (p + 1.0) / 2.0 * self.chol.determinant().ln() + + 0.5 * p * (p + 1.0) * LN_2 + + mvlgamma(p as i64, self.freedom / 2.0) + - (self.freedom - p - 1.0) / 2.0 * mvdigamma(p as i64, self.freedom / 2.0) + + self.freedom * p / 2.0 + ) + } +} + +impl MeanN> for Wishart { + /// Returns the mean of the Wishart distribution + /// + /// # Formula + /// + /// ```ignore + /// νψ + /// ``` + /// + /// where `ν` is the degree of freedom, `ψ` is the scale matrix + fn mean(&self) -> Option> { + Some(self.freedom * &self.scale) + } +} + +impl VarianceN> for Wishart { + /// Returns the variance of the Wishart distribution + /// + /// # Formula + /// + /// ```ignore + /// Var(x_ij) = ν(ψ_ij^2 + ψ_ii ψ_jj) + /// ``` + /// + /// where `ν` is the degree of freedom, `ψ` is the scale matrix + fn variance(&self) -> Option> { + Some(self.scale.map_with_location(|i, j, x| { + self.freedom * (self.scale[(i, i)] * self.scale[(j, j)] + x.powi(2)) + })) + } +} + +impl Mode>> for Wishart { + /// Returns the median of the Wishart distribution + /// + /// # Formula + /// + /// ```ignore + /// if k == 1 { + /// 0 + /// } else { + /// λ((k - 1) / k)^(1 / k) + /// } + /// ``` + /// + /// where `ν` is the degree of freedom, `ψ` is the scale matrix + fn mode(&self) -> Option> { + if self.freedom >= self.p() as f64 + 1.0 { + Some((self.freedom - self.p() as f64 - 1.0) * &self.scale) + } else { + None + } + } +} + +impl ::rand::distributions::Distribution> for Wishart { + fn sample(&self, rng: &mut R) -> DMatrix { + let p = self.p(); + let mut a = DMatrix::zeros(p, p); + + for i in 0..p { + a[(i, i)] = ChiSquared::new(self.freedom - i as f64).unwrap().sample(rng).sqrt(); + } + + for i in 1..p { + for j in 0..i { + a[(i, j)] = ziggurat::sample_std_normal(rng); + } + } + + let l = self.chol.l() * &a; + &l * &l.transpose() + } +} + +impl Continuous, f64> for Wishart { + /// Calculates the probability density function for the Wishart + /// distribution at `x` + fn pdf(&self, x: DMatrix) -> f64 { + let p = self.p() as f64; + let x_det = x.determinant(); + let six = self.chol.solve(&x); + + x_det.powf((self.freedom - p - 1.0) / 2.0) + * (-0.5 * six.trace()).exp() + / (2.0f64).powf(self.freedom * p / 2.0) + / self.chol.determinant().powf(self.freedom / 2.0) + / mvgamma(p as i64, self.freedom / 2.0) + } + + /// Calculates the log probability density function for the Wishart + /// distribution at `x` + fn ln_pdf(&self, x: DMatrix) -> f64 { + let p = self.p() as f64; + let x_lndet = x.determinant().ln(); + let six = self.chol.solve(&x); + + x_lndet * (self.freedom - p - 1.0) / 2.0 + - 0.5 * six.trace() + - LN_2 * (self.freedom * p / 2.0) + - self.chol.determinant().ln() * (self.freedom / 2.0) + - mvlgamma(p as i64, self.freedom / 2.0) + } +} + +#[rustfmt::skip] +#[cfg(test)] +mod tests { + use nalgebra::DMatrix; + use rand::distributions::Distribution; + use rand::rngs::StdRng; + use rand::SeedableRng; + use crate::distribution::Continuous; + use crate::distribution::wishart::Wishart; + use crate::statistics::{MeanN, Mode, VarianceN}; + + fn try_create(freedom: f64, scale: DMatrix) -> Wishart + { + let w = Wishart::new(freedom, scale); + assert!(w.is_ok()); + w.unwrap() + } + + fn test_almost( + freedom: f64, + scale: DMatrix, + expected: f64, + acc: f64, + eval: F, + ) where + F: FnOnce(Wishart) -> f64, + { + let mvn = try_create(freedom, scale); + let x = eval(mvn); + assert_almost_eq!(expected, x, acc); + } + + fn test_almost_mat( + freedom: f64, + scale: DMatrix, + expected: DMatrix, + acc: f64, + eval: F, + ) where + F: FnOnce(Wishart) -> DMatrix, + { + let mvn = try_create(freedom, scale); + let x = eval(mvn); + + for i in 0..x.nrows() { + for j in 0..x.ncols() { + assert_almost_eq!(expected[(i, j)], x[(i, j)], acc); + } + } + } + + #[test] + fn test_mean() { + test_almost_mat( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[2.0, 0.0, 0.0, 2.0]), + 1e-15, |w| w.mean().unwrap(), + ); + } + + #[test] + fn test_mode() { + test_almost_mat( + 7.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[4.0, 0.0, 0.0, 4.0]), + 1e-15, |w| w.mode().unwrap(), + ); + } + + #[test] + fn test_variance() { + test_almost_mat( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + DMatrix::from_row_slice(2, 2, &[4.0, 2.0, 2.0, 4.0]), + 1e-15, |w| w.variance().unwrap(), + ); + test_almost_mat( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + DMatrix::from_row_slice(3, 3, &[ + 6.17283, 0.618926, 2.88923, + 0.618926, 0.248229, 0.579385, + 2.88923, 0.579385, 5.4093, + ]), + 1e-4, |w| w.variance().unwrap(), + ); + } + + #[test] + fn test_entropy() { + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + 3.9538085820677584, + 1e-14, |w| w.entropy().unwrap(), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), 6.315039109555716, + 1e-12, |w| w.entropy().unwrap(), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 9.9495 + ]), 11.013723204318477, + 1e-12, |w| w.entropy().unwrap(), + ); + } + + #[test] + fn test_pdf() { + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + 0.02927491576215958, + 1e-15, |w| w.pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + test_almost( + 2.0, DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]), + -3.5310242469692907, + 1e-15, |w| w.ln_pdf(DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0])), + ); + + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + 0.028397507846420644, + 1e-12, |w| w.pdf(DMatrix::from_row_slice(3, 3, &[ + 0.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 0.5627, + ])), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + -3.561453889551996, + 1e-12, |w| w.ln_pdf(DMatrix::from_row_slice(3, 3, &[ + 0.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 0.5627, + ])), + ); + + + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + 2.3729077174800438e-5, + 1e-12, |w| w.pdf(DMatrix::from_row_slice(3, 3, &[ + 1.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 9.5627, + ])), + ); + test_almost( + 3.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ]), + -10.648809376818907, + 1e-12, |w| w.ln_pdf(DMatrix::from_row_slice(3, 3, &[ + 1.7121, 0.0000, 0.0000, + 0.0000, 0.4010, 0.0000, + 0.0000, 0.0000, 9.5627, + ])), + ); + } + + #[test] + fn test_sample() { + let w = try_create(4.0, DMatrix::from_row_slice(3, 3, &[ + 1.0143, 0.0000, 0.0000, + 0.0000, 0.2034, 0.0000, + 0.0000, 0.0000, 0.9495 + ])); + + let mut rng = StdRng::seed_from_u64(42); + + for _ in 0..100 { + let sample = w.sample(&mut rng); + let l_prob = w.ln_pdf(sample); + assert!(l_prob.is_finite()); + } + } +} diff --git a/src/function/gamma.rs b/src/function/gamma.rs index 9d5124f9..f043c2e2 100644 --- a/src/function/gamma.rs +++ b/src/function/gamma.rs @@ -421,6 +421,24 @@ fn signum(x: f64) -> f64 { } } +/// Computes the [multivariate gamma function](https://en.wikipedia.org/wiki/Multivariate_gamma_function). +pub fn mvgamma(p: i64, a: f64) -> f64 { + let mut res = std::f64::consts::PI.powf((p * (p - 1)) as f64 / 4.0); + for ii in 1..=p { + res *= gamma(a + (1 - ii) as f64 / 2.0); + } + res +} + +/// Computes the log [multivariate gamma function](https://en.wikipedia.org/wiki/Multivariate_gamma_function). +pub fn mvlgamma(p: i64, a: f64) -> f64 { + let mut res = (p * (p - 1)) as f64 * consts::LN_PI / 4.0; + for ii in 1..=p { + res += ln_gamma(a + (1 - ii) as f64 / 2.0); + } + res +} + #[rustfmt::skip] #[cfg(test)] mod tests { @@ -805,4 +823,14 @@ mod tests { assert_almost_eq!(super::inv_digamma(1.6110931485817511237336268416044190359814435699427405), 5.5, 1e-14); assert_almost_eq!(super::inv_digamma(2.2622143570941481235561593642219403924532310597356171), 10.1, 1e-13); } + + #[test] + fn test_ln_mvlgamma() { + assert!(super::mvlgamma(2, f64::NAN).is_nan()); + assert_almost_eq!(super::mvlgamma(0, 1.0), 0.0, 1e-12); + assert_almost_eq!(super::mvlgamma(2, 1.0), 1.1447298858494002, 1e-12); + assert_almost_eq!(super::mvlgamma(3, 1.5), 2.1686775340635553, 1e-12); + assert_almost_eq!(super::mvlgamma(3, 150.0 + 1.0e-12), 1794.2387481112528, 1e-12); + assert_almost_eq!(super::mvlgamma(11, 14.5), 225.2071467353132, 1e-12); + } }