From 37f8fdf2a81a7e6bf8b07f31aad0f3c1ed319ca3 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Tue, 24 Sep 2024 00:06:14 +0200 Subject: [PATCH] refactor: Categorical now stores normalized values norm_pmf (probabilities) was already normalized before storing, but cdf and sf weren't. Instead, they were normalized on every API call. The refactor also reduces the amount of vec/slice iterations in `new` from 4 to 2. --- src/distribution/categorical.rs | 156 ++++++++++---------------------- 1 file changed, 48 insertions(+), 108 deletions(-) diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index 008e89f9..c2e8e213 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -21,8 +21,8 @@ use std::f64; #[derive(Clone, PartialEq, Debug)] pub struct Categorical { norm_pmf: Vec, - cdf: Vec, - sf: Vec, + norm_cdf: Vec, + norm_sf: Vec, } /// Represents the errors that can occur when creating a [`Categorical`]. @@ -98,22 +98,25 @@ impl Categorical { return Err(CategoricalError::ProbMassSumZero); } - // extract un-normalized cdf - let cdf = prob_mass_to_cdf(prob_mass); - // extract un-normalized sf - let sf = cdf_to_sf(&cdf); - // extract normalized probability mass - let sum = cdf[cdf.len() - 1]; - let mut norm_pmf = vec![0.0; prob_mass.len()]; - norm_pmf - .iter_mut() - .zip(prob_mass.iter()) - .for_each(|(np, pm)| *np = *pm / sum); - Ok(Categorical { norm_pmf, cdf, sf }) - } + let mut cdf_sum = 0.0; + + let mut norm_cdf = Vec::with_capacity(prob_mass.len()); + let mut norm_sf = Vec::with_capacity(prob_mass.len()); + let mut norm_pmf = Vec::with_capacity(prob_mass.len()); - fn cdf_max(&self) -> f64 { - *self.cdf.last().unwrap() + for &prob in prob_mass { + cdf_sum += prob; + + norm_cdf.push(cdf_sum / prob_sum); + norm_sf.push((prob_sum - cdf_sum) / prob_sum); + norm_pmf.push(prob / prob_sum); + } + + Ok(Categorical { + norm_pmf, + norm_cdf, + norm_sf, + }) } } @@ -123,27 +126,31 @@ impl std::fmt::Display for Categorical { } } +#[cfg(feature = "rand")] +use rand::distributions::Distribution as RandDistribution; + #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] -impl ::rand::distributions::Distribution for Categorical { +impl RandDistribution for Categorical { fn sample(&self, rng: &mut R) -> usize { - sample_unchecked(rng, &self.cdf) + let draw = rng.gen::(); + self.norm_cdf.iter().position(|val| *val >= draw).unwrap() } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] -impl ::rand::distributions::Distribution for Categorical { +impl RandDistribution for Categorical { fn sample(&self, rng: &mut R) -> u64 { - sample_unchecked(rng, &self.cdf) as u64 + >::sample(&self, rng) as u64 } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] -impl ::rand::distributions::Distribution for Categorical { +impl RandDistribution for Categorical { fn sample(&self, rng: &mut R) -> f64 { - sample_unchecked(rng, &self.cdf) as f64 + >::sample(&self, rng) as f64 } } @@ -159,11 +166,7 @@ impl DiscreteCDF for Categorical { /// /// where `p_j` is the probability mass for the `j`th category fn cdf(&self, x: u64) -> f64 { - if x >= self.cdf.len() as u64 { - 1.0 - } else { - self.cdf.get(x as usize).unwrap() / self.cdf_max() - } + *self.norm_cdf.get(x as usize).unwrap_or(&1.0) } /// Calculates the survival function for the categorical distribution @@ -175,11 +178,7 @@ impl DiscreteCDF for Categorical { /// [ sum(p_j) from x..end ] /// ``` fn sf(&self, x: u64) -> f64 { - if x >= self.sf.len() as u64 { - 0.0 - } else { - self.sf.get(x as usize).unwrap() / self.cdf_max() - } + *self.norm_sf.get(x as usize).unwrap_or(&0.0) } /// Calculates the inverse cumulative distribution function for the @@ -203,8 +202,17 @@ impl DiscreteCDF for Categorical { if x >= 1.0 || x <= 0.0 { panic!("x must be in [0, 1]") } - let denorm_prob = x * self.cdf_max(); - binary_index(&self.cdf, denorm_prob) as u64 + + // `Vec::binary_search` will either return the index of a value equal to x + // or an index where x could be inserted into the sorted Vec. + // Both fit the description, so return either one. + match self + .norm_cdf + .binary_search_by(|v| v.partial_cmp(&x).unwrap()) + { + Ok(idx) => idx as u64, + Err(idx) => idx as u64, + } } } @@ -234,7 +242,7 @@ impl Max for Categorical { /// n /// ``` fn max(&self) -> u64 { - self.cdf.len() as u64 - 1 + self.norm_cdf.len() as u64 - 1 } } @@ -337,74 +345,6 @@ impl Discrete for Categorical { } } -/// Draws a sample from the categorical distribution described by `cdf` -/// without doing any bounds checking -#[cfg(feature = "rand")] -#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] -pub fn sample_unchecked(rng: &mut R, cdf: &[f64]) -> usize { - let draw = rng.gen::() * cdf.last().unwrap(); - cdf.iter().position(|val| *val >= draw).unwrap() -} - -/// Computes the cdf from the given probability masses. Performs -/// no parameter or bounds checking. -pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec { - let mut cdf = Vec::with_capacity(prob_mass.len()); - prob_mass.iter().fold(0.0, |s, p| { - let sum = s + p; - cdf.push(sum); - sum - }); - cdf -} - -/// Computes the sf from the given cumulative densities. -/// Performs no parameter or bounds checking. -pub fn cdf_to_sf(cdf: &[f64]) -> Vec { - let max = *cdf.last().unwrap(); - cdf.iter().map(|x| max - x).collect() -} - -// Returns the index of val if placed into the sorted search array. -// If val is greater than all elements, it therefore would return -// the length of the array (N). If val is less than all elements, it would -// return 0. Otherwise val returns the index of the first element larger than -// it within the search array. -fn binary_index(search: &[f64], val: f64) -> usize { - use std::cmp; - - let mut low = 0_isize; - let mut high = search.len() as isize - 1; - while low <= high { - let mid = low + ((high - low) / 2); - let el = *search.get(mid as usize).unwrap(); - if el > val { - high = mid - 1; - } else if el < val { - low = mid.saturating_add(1); - } else { - return mid as usize; - } - } - cmp::min(search.len(), cmp::max(low, 0) as usize) -} - -#[test] -fn test_prob_mass_to_cdf() { - let arr = [0.0, 0.5, 0.5, 3.0, 1.1]; - let res = prob_mass_to_cdf(&arr); - assert_eq!(res, [0.0, 0.5, 1.0, 4.0, 5.1]); -} - -#[test] -fn test_binary_index() { - let arr = [0.0, 3.0, 5.0, 9.0, 10.0]; - assert_eq!(0, binary_index(&arr, -1.0)); - assert_eq!(2, binary_index(&arr, 5.0)); - assert_eq!(3, binary_index(&arr, 5.2)); - assert_eq!(5, binary_index(&arr, 10.1)); -} - #[rustfmt::skip] #[cfg(test)] mod tests { @@ -541,10 +481,10 @@ mod tests { fn test_cdf_sf_mirror() { let mass = [4.0, 2.5, 2.5, 1.0]; let cat = Categorical::new(&mass).unwrap(); - assert_eq!(cat.cdf(0), 1.-cat.sf(0)); - assert_eq!(cat.cdf(1), 1.-cat.sf(1)); - assert_eq!(cat.cdf(2), 1.-cat.sf(2)); - assert_eq!(cat.cdf(3), 1.-cat.sf(3)); + assert_eq!(cat.cdf(0), 1. - cat.sf(0)); + assert_eq!(cat.cdf(1), 1. - cat.sf(1)); + assert_eq!(cat.cdf(2), 1. - cat.sf(2)); + assert_eq!(cat.cdf(3), 1. - cat.sf(3)); } #[test]