Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 48 additions & 108 deletions src/distribution/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use std::f64;
#[derive(Clone, PartialEq, Debug)]
pub struct Categorical {
norm_pmf: Vec<f64>,
cdf: Vec<f64>,
sf: Vec<f64>,
norm_cdf: Vec<f64>,
norm_sf: Vec<f64>,
}

/// Represents the errors that can occur when creating a [`Categorical`].
Expand Down Expand Up @@ -98,22 +98,25 @@ impl Categorical {
return Err(CategoricalError::ProbMassSumZero);
}

// extract un-normalized cdf
let cdf = prob_mass_to_cdf(prob_mass);
// extract un-normalized sf
let sf = cdf_to_sf(&cdf);
// extract normalized probability mass
let sum = cdf[cdf.len() - 1];
let mut norm_pmf = vec![0.0; prob_mass.len()];
norm_pmf
.iter_mut()
.zip(prob_mass.iter())
.for_each(|(np, pm)| *np = *pm / sum);
Ok(Categorical { norm_pmf, cdf, sf })
}
let mut cdf_sum = 0.0;

let mut norm_cdf = Vec::with_capacity(prob_mass.len());
let mut norm_sf = Vec::with_capacity(prob_mass.len());
let mut norm_pmf = Vec::with_capacity(prob_mass.len());

fn cdf_max(&self) -> f64 {
*self.cdf.last().unwrap()
for &prob in prob_mass {
cdf_sum += prob;

norm_cdf.push(cdf_sum / prob_sum);
norm_sf.push((prob_sum - cdf_sum) / prob_sum);
norm_pmf.push(prob / prob_sum);
}

Ok(Categorical {
norm_pmf,
norm_cdf,
norm_sf,
})
}
}

Expand All @@ -123,27 +126,31 @@ impl std::fmt::Display for Categorical {
}
}

#[cfg(feature = "rand")]
use rand::distributions::Distribution as RandDistribution;

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<usize> for Categorical {
impl RandDistribution<usize> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> usize {
sample_unchecked(rng, &self.cdf)
let draw = rng.gen::<f64>();
self.norm_cdf.iter().position(|val| *val >= draw).unwrap()
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Categorical {
impl RandDistribution<u64> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
sample_unchecked(rng, &self.cdf) as u64
<Self as RandDistribution<usize>>::sample(&self, rng) as u64
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Categorical {
impl RandDistribution<f64> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
sample_unchecked(rng, &self.cdf) as f64
<Self as RandDistribution<usize>>::sample(&self, rng) as f64
}
}

Expand All @@ -159,11 +166,7 @@ impl DiscreteCDF<u64, f64> for Categorical {
///
/// where `p_j` is the probability mass for the `j`th category
fn cdf(&self, x: u64) -> f64 {
if x >= self.cdf.len() as u64 {
1.0
} else {
self.cdf.get(x as usize).unwrap() / self.cdf_max()
}
*self.norm_cdf.get(x as usize).unwrap_or(&1.0)
}

/// Calculates the survival function for the categorical distribution
Expand All @@ -175,11 +178,7 @@ impl DiscreteCDF<u64, f64> for Categorical {
/// [ sum(p_j) from x..end ]
/// ```
fn sf(&self, x: u64) -> f64 {
if x >= self.sf.len() as u64 {
0.0
} else {
self.sf.get(x as usize).unwrap() / self.cdf_max()
}
*self.norm_sf.get(x as usize).unwrap_or(&0.0)
}

/// Calculates the inverse cumulative distribution function for the
Expand All @@ -203,8 +202,17 @@ impl DiscreteCDF<u64, f64> for Categorical {
if x >= 1.0 || x <= 0.0 {
panic!("x must be in [0, 1]")
}
let denorm_prob = x * self.cdf_max();
binary_index(&self.cdf, denorm_prob) as u64

// `binary_search_by` returns `Ok(idx)` when it finds an element equal to x,
// and `Err(idx)` with the index where x could be inserted to keep the Vec
// sorted otherwise. Both fit the description, so return the index either way.
match self
.norm_cdf
.binary_search_by(|v| v.partial_cmp(&x).unwrap())
{
Ok(idx) => idx as u64,
Err(idx) => idx as u64,
}
}
}

Expand Down Expand Up @@ -234,7 +242,7 @@ impl Max<u64> for Categorical {
/// n
/// ```
fn max(&self) -> u64 {
self.cdf.len() as u64 - 1
self.norm_cdf.len() as u64 - 1
}
}

Expand Down Expand Up @@ -337,74 +345,6 @@ impl Discrete<u64, f64> for Categorical {
}
}

/// Draws a category index using the un-normalized cumulative weights in `cdf`.
///
/// A random `f64` is drawn from `rng` and scaled by the last entry of `cdf`;
/// the result is the index of the first entry that is greater than or equal
/// to that scaled draw.
///
/// # Panics
///
/// No validation is performed: this panics if `cdf` is empty, or if no
/// entry of `cdf` is `>=` the scaled draw.
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> usize {
    // Draw first, then read the total, matching the original evaluation order.
    let uniform = rng.gen::<f64>();
    let total = *cdf.last().unwrap();
    let threshold = uniform * total;
    cdf.iter()
        .position(|&cumulative| cumulative >= threshold)
        .unwrap()
}

/// Builds the un-normalized cumulative sums of `prob_mass`.
///
/// Element `i` of the returned vector is the left-to-right sum of
/// `prob_mass[0..=i]`. The input is not validated and the result is not
/// normalized; an empty input yields an empty vector.
pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec<f64> {
    let mut cumulative = Vec::with_capacity(prob_mass.len());
    let mut running = 0.0;
    for &mass in prob_mass {
        running += mass;
        cumulative.push(running);
    }
    cumulative
}

/// Computes the un-normalized survival function from cumulative sums.
///
/// Each output element is the last entry of `cdf` (the total) minus the
/// corresponding entry of `cdf`, so the final element is always `0.0`.
///
/// The values are not validated. An empty `cdf` yields an empty vector
/// rather than panicking.
pub fn cdf_to_sf(cdf: &[f64]) -> Vec<f64> {
    match cdf.last() {
        Some(&total) => cdf.iter().map(|&cumulative| total - cumulative).collect(),
        // Nothing to subtract from: an empty cdf has an empty sf.
        None => Vec::new(),
    }
}

// Binary-searches the ascending-sorted slice `search` for `val`.
//
// If a probed element compares equal to `val`, that element's index is
// returned. Otherwise the result is the position at which `val` would be
// inserted to keep the slice sorted: 0 when `val` is smaller than every
// element, `search.len()` when it is larger than every element, and the
// index of the first larger element in between. A probed element that is
// neither greater nor less than `val` (e.g. a NaN comparison) is treated
// like a match.
fn binary_index(search: &[f64], val: f64) -> usize {
    use std::cmp::Ordering;

    // Half-open window `[lo, hi)` of positions still under consideration.
    let mut lo = 0_usize;
    let mut hi = search.len();
    while lo < hi {
        // Midpoint of the inclusive range `lo..=hi - 1`, rounded down.
        let mid = lo + (hi - 1 - lo) / 2;
        match search[mid].partial_cmp(&val) {
            Some(Ordering::Greater) => hi = mid,
            Some(Ordering::Less) => lo = mid + 1,
            // Equal or unordered: report the probed position.
            _ => return mid,
        }
    }
    lo
}

// Checks that `prob_mass_to_cdf` produces the running sums of its input.
#[test]
fn test_prob_mass_to_cdf() {
    let masses = [0.0, 0.5, 0.5, 3.0, 1.1];
    let expected = [0.0, 0.5, 1.0, 4.0, 5.1];
    assert_eq!(prob_mass_to_cdf(&masses), expected);
}

// Checks `binary_index` below the range, on an exact hit, between two
// elements, and above the range.
#[test]
fn test_binary_index() {
    let sorted = [0.0, 3.0, 5.0, 9.0, 10.0];
    let cases = [(-1.0, 0), (5.0, 2), (5.2, 3), (10.1, 5)];
    for &(needle, expected) in cases.iter() {
        assert_eq!(expected, binary_index(&sorted, needle));
    }
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -541,10 +481,10 @@ mod tests {
fn test_cdf_sf_mirror() {
let mass = [4.0, 2.5, 2.5, 1.0];
let cat = Categorical::new(&mass).unwrap();
assert_eq!(cat.cdf(0), 1.-cat.sf(0));
assert_eq!(cat.cdf(1), 1.-cat.sf(1));
assert_eq!(cat.cdf(2), 1.-cat.sf(2));
assert_eq!(cat.cdf(3), 1.-cat.sf(3));
assert_eq!(cat.cdf(0), 1. - cat.sf(0));
assert_eq!(cat.cdf(1), 1. - cat.sf(1));
assert_eq!(cat.cdf(2), 1. - cat.sf(2));
assert_eq!(cat.cdf(3), 1. - cat.sf(3));
}

#[test]
Expand Down