Skip to content

Commit 93979b1

Browse files
henryjacYeungOnion
authored andcommitted
tests: 3d matrices test cases for pdf.
Also improves documentation for multivariate t minorly.
1 parent 05b8437 commit 93979b1

File tree

1 file changed

+48
-47
lines changed

1 file changed

+48
-47
lines changed

src/distribution/multivariate_students_t.rs

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ use rand::Rng;
99
use std::f64::consts::PI;
1010

1111
/// Implements the [Multivariate Student's t-distribution](https://en.wikipedia.org/wiki/Multivariate_t-distribution)
12-
/// distribution using the "nalgebra" crate for matrix operations
12+
/// distribution using the "nalgebra" crate for matrix operations.
1313
///
14-
/// Assumes all the marginal distributions have the same degree of freedom, ν
14+
/// Assumes all the marginal distributions have the same degree of freedom, ν.
1515
///
1616
/// # Examples
1717
///
@@ -38,7 +38,7 @@ pub struct MultivariateStudent {
3838

3939
impl MultivariateStudent {
4040
/// Constructs a new multivariate students t distribution with a location of `location`,
41-
/// scale matrix `scale` and `freedom` degrees of freedom
41+
/// scale matrix `scale` and `freedom` degrees of freedom.
4242
///
4343
/// # Errors
4444
///
@@ -91,34 +91,34 @@ impl MultivariateStudent {
9191
}
9292
}
9393

94-
/// Returns the dimension of the distribution
94+
/// Returns the dimension of the distribution.
9595
pub fn dim(&self) -> usize {
9696
self.dim
9797
}
98-
/// Returns the cholesky decomposiiton matrix of the scale matrix
98+
/// Returns the cholesky decomposiiton matrix of the scale matrix.
9999
///
100-
/// Returns A where Σ = AAᵀ
100+
/// Returns A where Σ = AAᵀ.
101101
pub fn scale_chol_decomp(&self) -> DMatrix<f64> {
102102
self.scale_chol_decomp.clone()
103103
}
104-
/// Returns the location of the distribution
104+
/// Returns the location of the distribution.
105105
pub fn location(&self) -> DVector<f64> {
106106
self.location.clone()
107107
}
108-
/// Returns the scale matrix of the distribution
108+
/// Returns the scale matrix of the distribution.
109109
pub fn scale(&self) -> DMatrix<f64> {
110110
self.scale.clone()
111111
}
112-
/// Returns the degrees of freedom of the distribution
112+
/// Returns the degrees of freedom of the distribution.
113113
pub fn freedom(&self) -> f64 {
114114
self.freedom
115115
}
116-
/// Returns the inverse of the cholesky decomposition matrix
116+
/// Returns the inverse of the cholesky decomposition matrix.
117117
pub fn precision(&self) -> DMatrix<f64> {
118118
self.precision.clone()
119119
}
120120
/// Returns the logarithmed constant part of the probability
121-
/// distribution function
121+
/// distribution function.
122122
pub fn ln_pdf_const(&self) -> f64 {
123123
self.ln_pdf_const
124124
}
@@ -129,8 +129,8 @@ impl ::rand::distributions::Distribution<DVector<f64>> for MultivariateStudent {
129129
///
130130
/// # Formula
131131
///
132-
///```ignore
133-
/// W * L * Z + μ
132+
///```math
133+
/// W L Z + μ
134134
///```
135135
///
136136
/// where `W` has √(ν/Sν) distribution, Sν has Chi-squared
@@ -164,35 +164,32 @@ impl Max<DVector<f64>> for MultivariateStudent {
164164
}
165165

166166
impl MeanN<DVector<f64>> for MultivariateStudent {
167-
/// Returns the mean of the student distribution
167+
/// Returns the mean of the student distribution.
168168
///
169169
/// # Remarks
170170
///
171171
/// This is the same mean used to construct the distribution if
172172
/// the degrees of freedom is larger than 1.
173173
fn mean(&self) -> Option<DVector<f64>> {
174174
if self.freedom > 1. {
175-
let mut vec = vec![];
176-
for elt in self.location.clone().into_iter() {
177-
vec.push(*elt);
178-
}
179-
Some(DVector::from_vec(vec))
175+
Some(self.location.clone())
180176
} else {
181177
None
182178
}
183179
}
184180
}
185181

186182
impl VarianceN<DMatrix<f64>> for MultivariateStudent {
187-
/// Returns the covariance matrix of the multivariate student distribution
183+
/// Returns the covariance matrix of the multivariate student distribution.
188184
///
189185
/// # Formula
190-
/// ```ignore
186+
///
187+
/// ```math
191188
/// Σ ⋅ ν / (ν - 2)
192189
/// ```
193190
///
194191
/// where `Σ` is the scale matrix and `ν` is the degrees of freedom.
195-
/// Only defined if freedom is larger than 2
192+
/// Only defined if freedom is larger than 2.
196193
fn variance(&self) -> Option<DMatrix<f64>> {
197194
if self.freedom > 2. {
198195
Some(self.scale.clone() * self.freedom / (self.freedom - 2.))
@@ -203,39 +200,34 @@ impl VarianceN<DMatrix<f64>> for MultivariateStudent {
203200
}
204201

205202
impl Mode<DVector<f64>> for MultivariateStudent {
206-
/// Returns the mode of the multivariate student distribution
203+
/// Returns the mode of the multivariate student distribution.
207204
///
208205
/// # Formula
209206
///
210-
/// ```ignore
207+
/// ```math
211208
/// μ
212209
/// ```
213210
///
214-
/// where `μ` is the location
211+
/// where `μ` is the location.
215212
fn mode(&self) -> DVector<f64> {
216213
self.location.clone()
217214
}
218215
}
219216

220217
impl<'a> Continuous<&'a DVector<f64>, f64> for MultivariateStudent {
221-
/// Calculates the probability density function for the multivariate
222-
/// student distribution at `x`
218+
/// Calculates the probability density function for the multivariate.
219+
/// student distribution at `x`.
223220
///
224221
/// # Formula
225222
///
226-
/// ```ignore
227-
/// Gamma[(ν+p)/2] / [Gamma(ν/2) ((ν * π)^p det(Σ))^(1 / 2)] * [1 + 1/ν transpose(x - μ) inv(Σ) (x - μ)]^(-(ν+p)/2)
223+
/// ```math
224+
/// (ν+p)/2] / [Γ(ν/2) ((ν * π)^p det(Σ))^(1 / 2)] * [1 + 1/ν (x - μ) inv(Σ) (x - μ)]^(-(ν+p)/2)
228225
/// ```
229226
///
230-
/// where `ν` is the degrees of freedom, `μ` is the mean, `Gamma`
227+
/// where `ν` is the degrees of freedom, `μ` is the mean, `Γ`
231228
/// is the Gamma function, `inv(Σ)`
232229
/// is the precision matrix, `det(Σ)` is the determinant
233-
/// of the scale matrix, and `k` is the dimension of the distribution
234-
///
235-
/// TODO: Make this converge for large degrees of freedom
236-
/// Current commented code beneath fails since `MultivariateNormal::new` accepts Vec<f64> and
237-
/// not DVector or DMatrix. Should implement that instead of changing back to Vec<f64>, or
238-
/// even have a constructor `MultivariateNormal::from_student`.
230+
/// of the scale matrix, and `k` is the dimension of the distribution.
239231
fn pdf(&self, x: &'a DVector<f64>) -> f64 {
240232
if self.freedom == f64::INFINITY {
241233
let mvn = MultivariateNormal::from_students(self.clone()).unwrap();
@@ -267,7 +259,6 @@ impl<'a> Continuous<&'a DVector<f64>, f64> for MultivariateStudent {
267259
}
268260
}
269261

270-
// TODO: Add more tests for other matrices than really straightforward symmetric positive
271262
#[rustfmt::skip]
272263
#[cfg(test)]
273264
mod tests {
@@ -364,18 +355,19 @@ mod tests {
364355

365356
#[test]
366357
fn test_bad_create() {
367-
// scale not symmetric
358+
// scale not symmetric.
368359
bad_create_case(vec![0., 0.], vec![1., 1., 0., 1.], 1.);
369-
// scale not positive-definite
360+
// scale not positive-definite.
370361
bad_create_case(vec![0., 0.], vec![1., 2., 2., 1.], 1.);
371-
// NaN in location
362+
// NaN in location.
372363
bad_create_case(vec![0., f64::NAN], vec![1., 0., 0., 1.], 1.);
373-
// NaN in scale Matrix
364+
// NaN in scale Matrix.
374365
bad_create_case(vec![0., 0.], vec![1., 0., 0., f64::NAN], 1.);
375-
// NaN in freedom
366+
// NaN in freedom.
376367
bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], f64::NAN);
377-
// Non-positive freedom
368+
// Non-positive freedom.
378369
bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], 0.);
370+
bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], -1.);
379371
}
380372

381373
#[test]
@@ -385,6 +377,7 @@ mod tests {
385377
test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 3., mat2![f64::INFINITY, 0., 0., f64::INFINITY], variance);
386378
}
387379

380+
// Variance is only defined for freedom > 2.
388381
#[test]
389382
fn test_bad_variance() {
390383
let variance = |x: MultivariateStudent| x.variance();
@@ -405,6 +398,7 @@ mod tests {
405398
test_case(vec![-1., 1., 3.], vec![1., 0., 0.5, 0., 2.0, 0., 0.5, 0., 3.0], 2., dvec![-1., 1., 3.], mean);
406399
}
407400

401+
// Mean is only defined if freedom > 1.
408402
#[test]
409403
fn test_bad_mean() {
410404
let mean = |x: MultivariateStudent| x.mean();
@@ -425,9 +419,14 @@ mod tests {
425419
fn test_pdf() {
426420
let pdf = |arg: DVector<f64>| move |x: MultivariateStudent| x.pdf(&arg);
427421
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.047157020175376416, 1e-15, pdf(dvec![1., 1.]));
422+
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.013972450422333741737457302178882, 1e-15, pdf(dvec![1., 2.]));
428423
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., 0.012992240252399619, 1e-17, pdf(dvec![1., 2.]));
429424
test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, 2.639780816598878e-5, 1e-19, pdf(dvec![1., 10.]));
430425
test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, 6.438051574348526e-5, 1e-19, pdf(dvec![10., 10.]));
426+
// These three are crossed checked against both python's scipy.multivariate_t.pdf and octave's mvtpdf.
427+
test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 6.960998836915657e-16, 1e-30, pdf(dvec![0.9718, 0.1298, 0.8134]));
428+
test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 7.369987979187023e-16, 1e-30, pdf(dvec![0.4922, 0.5522, 0.7185]));
429+
test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8.,6.952297846610382e-16, 1e-30, pdf(dvec![0.3010, 0.1491, 0.5008]));
431430
test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., 0., pdf(dvec![10., 10.]));
432431
}
433432

@@ -447,14 +446,16 @@ mod tests {
447446
// let pdf_mvn = |mv: MultivariateNormal, arg: DVector<f64>| mv.pdf(&arg);
448447
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-6, dvec![1., 1.], pdf_mvs, pdf_mvn);
449448
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-7, dvec![1., 1.], pdf_mvs, pdf_mvn);
450-
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-15, dvec![1., 1.], pdf_mvs, pdf_mvn);
449+
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
450+
// test_almost_multivariate_normal(vec![5., -1.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![5., 1.], pdf_mvs, pdf_mvn);
451451
// }
452-
453452
// #[test]
454453
// fn test_ln_pdf_freedom_large() {
455454
// let pdf_mvs = |mv: MultivariateStudent, arg: DVector<f64>| mv.ln_pdf(&arg);
456455
// let pdf_mvn = |mv: MultivariateNormal, arg: DVector<f64>| mv.ln_pdf(&arg);
457-
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn);
458-
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-50, dvec![1., 1.], pdf_mvs, pdf_mvn);
456+
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn);
457+
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 5e-6, dvec![1., 1.], pdf_mvs, pdf_mvn);
458+
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
459+
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
459460
// }
460461
}

0 commit comments

Comments
 (0)