Skip to content
2 changes: 2 additions & 0 deletions src/distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub use self::laplace::Laplace;
pub use self::log_normal::LogNormal;
pub use self::multinomial::Multinomial;
pub use self::multivariate_normal::MultivariateNormal;
pub use self::multivariate_students_t::MultivariateStudent;
pub use self::negative_binomial::NegativeBinomial;
pub use self::normal::Normal;
pub use self::pareto::Pareto;
Expand Down Expand Up @@ -60,6 +61,7 @@ mod laplace;
mod log_normal;
mod multinomial;
mod multivariate_normal;
mod multivariate_students_t;
mod negative_binomial;
mod normal;
mod pareto;
Expand Down
28 changes: 27 additions & 1 deletion src/distribution/multivariate_normal.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::distribution::Continuous;
use crate::distribution::Normal;
use crate::distribution::{MultivariateStudent, Normal};
use crate::statistics::{Max, MeanN, Min, Mode, VarianceN};
use crate::{Result, StatsError};
use nalgebra::{Cholesky, Const, DMatrix, DVector, Dim, DimMin, Dyn, OMatrix, OVector};
Expand Down Expand Up @@ -49,6 +49,25 @@ impl MultivariateNormal<Dyn> {
let cov = DMatrix::from_vec(mean.len(), mean.len(), cov);
MultivariateNormal::new_from_nalgebra(mean, cov)
}

/// Constructs a new multivariate normal distribution from a
/// multivariate students t distribution, which have equal variables
/// when `mvs.freedom == f64::INFINITY`
pub fn from_students(mvs: MultivariateStudent) -> Result<Self> {
let mu = mvs.location();
let scale = mvs.scale();
let cov_det = scale.determinant();
let pdf_const = ((2. * PI).powi(mu.nrows() as i32) * cov_det.abs())
.recip()
.sqrt();
Ok(MultivariateNormal {
cov_chol_decomp: mvs.scale_chol_decomp(),
mu: mvs.location(),
cov: mvs.scale(),
precision: mvs.precision(),
pdf_const,
})
}
}

impl<D> MultivariateNormal<D>
Expand Down Expand Up @@ -609,4 +628,11 @@ mod tests {
ln_pdf(dvector![100., 100.]),
);
}

#[test]
#[should_panic]
fn test_pdf_mismatched_arg_size() {
let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.,]).unwrap();
mvn.pdf(vec![1.]); // x.size != mu.size
}
}
Loading