diff --git a/Cargo.toml b/Cargo.toml
index cd3125e0..7da6804b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,6 +24,11 @@ name = "order_statistics"
 harness = false
 required-features = ["rand", "std"]
 
+[[bench]]
+name = "density"
+harness = false
+required-features = ["rand", "std", "nalgebra"]
+
 [features]
 default = ["std", "nalgebra", "rand"]
 std = ["nalgebra?/std", "rand?/std"]
@@ -34,6 +39,8 @@ rand = ["dep:rand", "nalgebra?/rand"]
 [dependencies]
 approx = "0.5.0"
 num-traits = "0.2.14"
+kdtree = "0.7.0"
+thiserror = "2.0"
 
 [dependencies.rand]
 version = "0.9.0"
diff --git a/benches/density.rs b/benches/density.rs
new file mode 100644
index 00000000..4cac74b2
--- /dev/null
+++ b/benches/density.rs
@@ -0,0 +1,48 @@
+extern crate criterion;
+extern crate rand;
+extern crate statrs;
+use criterion::{Criterion, criterion_group, criterion_main};
+use nalgebra::{Vector1, Vector3};
+use rand::distr::StandardUniform;
+
+fn generate<T>(n_samples: usize) -> Vec<T>
+where
+    StandardUniform: rand::distr::Distribution<T>,
+{
+    (0..n_samples).map(|_| rand::random()).collect()
+}
+
+fn bench_density(c: &mut Criterion) {
+    let samples = generate(100_000);
+    let mut group = c.benchmark_group("density");
+    group.bench_function("knn_density_1d", |b| {
+        b.iter(|| {
+            let _f = statrs::density::knn::knn_pdf(&[0.], &samples, None);
+        });
+    });
+
+    let samples = generate(100_000);
+    group.bench_function("knn_density_3d", |b| {
+        b.iter(|| {
+            let _f = statrs::density::knn::knn_pdf(&[0., 0., 0.], &samples, None);
+        });
+    });
+
+    let samples = generate(100_000);
+    group.bench_function("kde_density_1d", |b| {
+        b.iter(|| {
+            let _f = statrs::density::kde::kde_pdf(&Vector1::new(0.), &samples, None);
+        });
+    });
+
+    let samples = generate(100_000);
+    group.bench_function("kde_density_3d", |b| {
+        b.iter(|| {
+            let _f = statrs::density::kde::kde_pdf(&Vector3::new(0., 0., 0.), &samples, None);
+        });
+    });
+}
+
+criterion_group!(benches, bench_density);
+
+criterion_main!(benches);
diff --git a/src/density/kde.rs b/src/density/kde.rs
new file mode 100644
index 00000000..3e5d8dc1
--- /dev/null
+++ b/src/density/kde.rs
@@ -0,0 +1,93 @@
+use kdtree::distance::squared_euclidean;
+
+use crate::{
+    density::{Container, DensityError, nearest_neighbors},
+    function::kernel::{Gaussian, Kernel},
+};
+
+/// Computes a kernel density estimate at the point `x` from `samples`,
+/// using a Gaussian kernel.
+///
+/// The kernel radius is the distance from `x` to the last neighbor returned by
+/// the nearest-neighbor query. When `bandwidth` is `None`, the number of
+/// neighbors `k` is chosen with Orava's formula; otherwise `bandwidth` is
+/// forwarded as the radius of a k-d tree range query evaluated with the
+/// squared Euclidean distance.
+///
+/// Reference: K-nearest neighbour kernel density estimation, the choice of
+/// optimal k; Jan Orava 2012.
+///
+/// # Errors
+///
+/// Returns [`DensityError::EmptySample`] when `samples` is empty,
+/// [`DensityError::EmptyNeighborhood`] when the query yields no neighbor, and
+/// [`DensityError::KdTree`] when the k-d tree reports an error.
+pub fn kde_pdf<S, X>(x: &X, samples: &S, bandwidth: Option<f64>) -> Result<f64, DensityError>
+where
+    S: AsRef<[X]> + Container,
+    X: AsRef<[f64]> + Container + PartialEq,
+{
+    let n_samples = samples.length() as f64;
+    let (neighbors, _) = nearest_neighbors(x, samples, bandwidth)?;
+    // NOTE(review): `last()` is treated as the farthest neighbor, which assumes the
+    // k-d tree returns matches ordered by increasing distance -- confirm.
+    match neighbors.last() {
+        None => Err(DensityError::EmptyNeighborhood),
+        Some(squared_radius) => {
+            // The neighbor list holds distances under the squared Euclidean metric.
+            let radius = squared_radius.sqrt();
+            let d = x.length() as i32;
+            let kernel_sum = samples
+                .as_ref()
+                .iter()
+                .map(|xi| {
+                    Gaussian.evaluate(squared_euclidean(x.as_ref(), xi.as_ref()).sqrt() / radius)
+                        / crate::consts::SQRT_2PI.powi(d - 1)
+                })
+                .sum::<f64>();
+            Ok((1. / (n_samples * radius.powi(d))) * kernel_sum)
+        }
+    }
+}
+
+#[cfg(all(test, feature = "rand"))]
+mod tests {
+    use core::f64::consts::PI;
+
+    use super::*;
+    use crate::distribution::Normal;
+    use crate::function::kernel::Kernel;
+    use nalgebra::{Vector1, Vector2};
+    use rand::distr::Distribution;
+
+    #[test]
+    fn test_kde_pdf() {
+        let law = Normal::new(0., 1.).unwrap();
+        let mut rng = rand::rng();
+        let gaussian = crate::function::kernel::Gaussian;
+        let samples_1d = (0..100000)
+            .map(|_| Vector1::new(law.sample(&mut rng)))
+            .collect::<Vec<_>>();
+        let x = Vector1::new(0.);
+        let kde_density_with_bandwidth = kde_pdf(&x, &samples_1d, Some(0.05));
+        let kde_density = kde_pdf(&x, &samples_1d, None);
+        let reference_value = gaussian.evaluate(0.);
+        assert!(kde_density.is_ok());
+        assert!(kde_density_with_bandwidth.is_ok());
+        assert!((kde_density.unwrap() - reference_value).abs() < 2e-2);
+        assert!((kde_density_with_bandwidth.unwrap() - reference_value).abs() < 3e-2);
+
+        let samples_2d = (0..100000)
+            .map(|_| Vector2::new(law.sample(&mut rng), law.sample(&mut rng)))
+            .collect::<Vec<_>>();
+
+        let x = Vector2::new(0., 0.);
+        let kde_density_with_bandwidth = kde_pdf(&x, &samples_2d, Some(0.05));
+        let kde_density = kde_pdf(&x, &samples_2d, None);
+        let reference_value = 1. / (2. * PI);
+        assert!(kde_density.is_ok());
+        assert!(kde_density_with_bandwidth.is_ok());
+        assert!((kde_density.unwrap() - reference_value).abs() < 2e-2);
+        assert!((kde_density_with_bandwidth.unwrap() - reference_value).abs() < 3e-2);
+    }
+}
diff --git a/src/density/knn.rs b/src/density/knn.rs
new file mode 100644
index 00000000..28dbfe2b
--- /dev/null
+++ b/src/density/knn.rs
@@ -0,0 +1,94 @@
+use super::Container;
+use crate::{
+    density::{DensityError, nearest_neighbors},
+    function::gamma::gamma,
+};
+use core::f64::consts::PI;
+
+/// Computes the `k`-nearest neighbor density estimate at the point `x` from
+/// `samples`.
+///
+/// The estimate is `k / n_samples` divided by the volume of the ball whose
+/// radius is the distance from `x` to the last neighbor returned by the
+/// nearest-neighbor query. When `bandwidth` is `None`, `k` is chosen with
+/// Orava's formula; otherwise `bandwidth` is forwarded as the radius of a k-d
+/// tree range query evaluated with the squared Euclidean distance, and `k` is
+/// the number of samples it matches.
+///
+/// Reference: K-nearest neighbour kernel density estimation, the choice of
+/// optimal k; Jan Orava 2012.
+///
+/// # Errors
+///
+/// Returns [`DensityError::EmptySample`] when `samples` is empty,
+/// [`DensityError::EmptyNeighborhood`] when the query yields no neighbor, and
+/// [`DensityError::KdTree`] when the k-d tree reports an error.
+pub fn knn_pdf<S, X>(x: &X, samples: &S, bandwidth: Option<f64>) -> Result<f64, DensityError>
+where
+    S: AsRef<[X]> + Container,
+    X: AsRef<[f64]> + Container + PartialEq,
+{
+    let n_samples = samples.length() as f64;
+    let (neighbors, k) = nearest_neighbors(x, samples, bandwidth)?;
+    // NOTE(review): `last()` is treated as the farthest neighbor, which assumes the
+    // k-d tree returns matches ordered by increasing distance -- confirm.
+    match neighbors.last() {
+        None => Err(DensityError::EmptyNeighborhood),
+        Some(squared_radius) => {
+            let radius = squared_radius.sqrt();
+            let d = x.length() as f64;
+            // pi^(d/2) * radius^d / gamma(d/2 + 1) is the volume of a d-ball of that radius.
+            Ok((k / n_samples) * (gamma(d / 2. + 1.) / (PI.powf(d / 2.) * radius.powf(d))))
+        }
+    }
+}
+
+#[cfg(all(test, feature = "rand"))]
+mod tests {
+    use core::f64::consts::PI;
+
+    use super::*;
+    use crate::distribution::Normal;
+    use crate::function::kernel::Kernel;
+    use nalgebra::{Vector1, Vector2};
+    use rand::distr::Distribution;
+
+    #[test]
+    fn test_knn_pdf() {
+        let law = Normal::new(0., 1.).unwrap();
+        let mut rng = rand::rng();
+        let gaussian = crate::function::kernel::Gaussian;
+        let samples_1d = (0..100000)
+            .map(|_| Vector1::new(law.sample(&mut rng)))
+            .collect::<Vec<_>>();
+        let x = Vector1::new(0.);
+        let knn_density_with_bandwidth = knn_pdf(&x, &samples_1d, Some(0.05));
+        let knn_density = knn_pdf(&x, &samples_1d, None);
+        let reference_value = gaussian.evaluate(0.);
+        assert!(knn_density.is_ok());
+        assert!(knn_density_with_bandwidth.is_ok());
+        assert!((knn_density.unwrap() - reference_value).abs() < 2e-2);
+        assert!((knn_density_with_bandwidth.unwrap() - reference_value).abs() < 3e-2);
+
+        let samples_2d = (0..100000)
+            .map(|_| Vector2::new(law.sample(&mut rng), law.sample(&mut rng)))
+            .collect::<Vec<_>>();
+
+        let x = Vector2::new(0., 0.);
+        let knn_density_with_bandwidth = knn_pdf(&x, &samples_2d, Some(0.05));
+        let knn_density = knn_pdf(&x, &samples_2d, None);
+        let reference_value = 1. / (2. * PI);
+        assert!(knn_density.is_ok());
+        assert!(knn_density_with_bandwidth.is_ok());
+        assert!((knn_density.unwrap() - reference_value).abs() < 2e-2);
+        assert!((knn_density_with_bandwidth.unwrap() - reference_value).abs() < 3e-2);
+    }
+
+    #[test]
+    fn test_knn_pdf_empty_samples() {
+        let samples: Vec<[f64; 1]> = vec![];
+        let x = 3.0;
+        let result = knn_pdf(&[x], &samples, None);
+        assert!(matches!(result, Err(DensityError::EmptySample)));
+    }
+}
diff --git a/src/density/mod.rs b/src/density/mod.rs
new file mode 100644
index 00000000..6e35af38
--- /dev/null
+++ b/src/density/mod.rs
@@ -0,0 +1,124 @@
+pub mod kde;
+pub mod knn;
+use kdtree::{ErrorKind, KdTree, distance::squared_euclidean};
+use thiserror::Error;
+
+/// Errors reported by the density estimators of this module.
+///
+/// The `Display` implementation is generated by `thiserror` from the
+/// `#[error(...)]` attributes below.
+#[derive(Error, Debug)]
+pub enum DensityError {
+    /// Error when the k-d tree cannot be built or queried.
+    #[error("K-d tree error: {0}")]
+    KdTree(#[from] ErrorKind),
+    /// The sample set passed to an estimator is empty.
+    #[error("No samples provided")]
+    EmptySample,
+    /// The neighbor query around the evaluation point returned nothing.
+    #[error("No neighbors found")]
+    EmptyNeighborhood,
+}
+
+/// Returns the neighbor count used when no bandwidth is given:
+/// `0.587 * n_samples^(4/5)`, rounded and clamped to at least 1.
+fn orava_optimal_k(n_samples: f64) -> f64 {
+    // Adapted from K-nearest neighbour kernel density estimation, the choice of optimal k; Jan Orava 2012
+    (0.587 * n_samples.powf(4.0 / 5.0)).round().max(1.)
+}
+
+/// Handles variable/point types for which nearest neighbors can be computed.
+pub trait Container: Clone {
+    /// Type of the elements held by the container.
+    type Elem;
+    /// Returns the number of elements held by the container.
+    fn length(&self) -> usize;
+}
+
+macro_rules! impl_container {
+    ($($t:ty),*) => {
+        $(
+            impl<T: Clone> Container for $t {
+                type Elem = T;
+                fn length(&self) -> usize {
+                    self.len()
+                }
+            }
+        )*
+    };
+}
+impl_container!(
+    [T; 1],
+    [T; 2],
+    [T; 3],
+    Vec<T>,
+    nalgebra::Vector1<T>,
+    nalgebra::Vector2<T>,
+    nalgebra::Vector3<T>,
+    nalgebra::Vector4<T>,
+    nalgebra::Vector5<T>,
+    nalgebra::Vector6<T>
+);
+
+/// Neighbor distances (under the squared Euclidean metric), paired with the neighbor count `k`.
+pub type NearestNeighbors = (Vec<f64>, f64);
+
+/// Builds a k-d tree over `samples` and queries it around `x`.
+///
+/// With `Some(bandwidth)`, every sample matched by a `within` query of radius
+/// `bandwidth` is kept and `k` is the number of matches. With `None`, a
+/// `nearest` query is run for `k = orava_optimal_k(n_samples)` samples.
+/// Both queries use the `squared_euclidean` distance.
+///
+/// # Errors
+///
+/// Returns [`DensityError::EmptySample`] when `samples` is empty and
+/// [`DensityError::KdTree`] when the tree rejects a point or a query.
+pub(crate) fn nearest_neighbors<S, X>(
+    x: &X,
+    samples: &S,
+    bandwidth: Option<f64>,
+) -> Result<NearestNeighbors, DensityError>
+where
+    S: AsRef<[X]> + Container,
+    X: AsRef<[f64]> + Container + PartialEq,
+{
+    if samples.length() == 0 {
+        return Err(DensityError::EmptySample);
+    }
+    let n_samples = samples.length() as f64;
+    let d = x.length();
+    let mut tree = KdTree::with_capacity(d, 2usize.pow(n_samples.log2() as u32));
+    for (position, sample) in samples.as_ref().iter().enumerate() {
+        tree.add(sample.clone(), position)?;
+    }
+    if let Some(bandwidth) = bandwidth {
+        let neighbors = tree.within(x.as_ref(), bandwidth, &squared_euclidean)?;
+        let k = neighbors.len() as f64;
+        Ok((neighbors.into_iter().map(|r| r.0).collect(), k))
+    } else {
+        let k = orava_optimal_k(n_samples);
+        Ok((
+            tree.nearest(x.as_ref(), k as usize, &squared_euclidean)?
+                .into_iter()
+                .map(|r| r.0)
+                .collect(),
+            k,
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use nalgebra::Vector3;
+
+    use super::*;
+
+    #[test]
+    fn test_vec_container() {
+        let v1 = vec![1.0, 2.0, 3.0];
+        assert_eq!(v1.length(), 3);
+        let v2 = Vector3::new(1.0, 2.0, 3.0);
+        assert_eq!(v2.length(), 3);
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index e72d0536..61fa30fb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -54,6 +54,8 @@
 #![cfg_attr(not(feature = "std"), no_std)]
 
 pub mod consts;
+#[cfg(all(feature = "std", feature = "nalgebra"))]
+pub mod density;
 pub mod distribution;
 pub mod euclid;
 pub mod function;