From f50b0e1aa4717b0de98cb55bc85313552ceacf87 Mon Sep 17 00:00:00 2001 From: oOOo-YKS <18964484242@163.com> Date: Wed, 19 Mar 2025 21:13:27 +0800 Subject: [PATCH 1/2] fix(distribution): enforce finite params in Gamma::new - before: - Gamma::new raises an error oly when both shape and rate are zero - after: - Add infinite checks for shape/rate parameters - Add test cases for invalid infinite params Closes #103 Refs: #98, #102 --- src/distribution/gamma.rs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index a403355d..88abb3de 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -67,22 +67,27 @@ impl Gamma { /// let mut result = Gamma::new(3.0, 1.0); /// assert!(result.is_ok()); /// - /// result = Gamma::new(0.0, 0.0); + /// result = Gamma::new(1.0, 0.0); + /// assert!(result.is_err()); + /// + /// result = Gamma::new(0.0, 1.0); + /// assert!(result.is_err()); + /// + /// result = Gamma::new(f64::INFINITY, 1.0); + /// assert!(result.is_err()); + /// + /// result = Gamma::new(1.0, f64::INFINITY); /// assert!(result.is_err()); /// ``` pub fn new(shape: f64, rate: f64) -> Result { - if shape.is_nan() || shape <= 0.0 { + if shape.is_nan() || shape <= 0.0 || shape.is_infinite() { return Err(GammaError::ShapeInvalid); } - if rate.is_nan() || rate <= 0.0 { + if rate.is_nan() || rate <= 0.0 || rate.is_infinite() { return Err(GammaError::RateInvalid); } - if shape.is_infinite() && rate.is_infinite() { - return Err(GammaError::ShapeAndRateInfinite); - } - Ok(Gamma { shape, rate }) } From 1a8e2fcd52f6eb390c9c32cc1c0910a9a3cca31e Mon Sep 17 00:00:00 2001 From: oOOo-YKS <18964484242@163.com> Date: Sat, 5 Apr 2025 19:36:59 +0800 Subject: [PATCH 2/2] fix(gamma): enforce finite params and cleanup validation 1. Deprecate GammaError::ShapeAndRateInfinite variant (no longer used) 2. Update error messages to explicitly mention NaN/infinite/zero/negative values: - "Shape must be finite (not NaN, infinite, zero, or negative)" - "Rate must be finite (not NaN, infinite, zero, or negative)" 3. Remove redundant infinity checks from cdf/sf methods (validation now centralized in constructor) 4. Related to #103, closes #98 --- src/distribution/gamma.rs | 38 ++++++++++++-------------------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 88abb3de..1985b90c 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -32,18 +32,20 @@ pub enum GammaError { /// The rate is NaN, zero or less than zero. RateInvalid, - - /// The shape and rate are both infinite. - ShapeAndRateInfinite, } impl std::fmt::Display for GammaError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - GammaError::ShapeInvalid => write!(f, "Shape is NaN zero, or less than zero."), - GammaError::RateInvalid => write!(f, "Rate is NaN zero, or less than zero."), - GammaError::ShapeAndRateInfinite => write!(f, "Shape and rate are infinite"), + GammaError::ShapeInvalid => write!( + f, + "Shape must be finite (not NaN, infinite, zero, or negative)." + ), + GammaError::RateInvalid => write!( + f, + "Rate must be finite (not NaN, infinite, zero, or negative)." + ), } } } @@ -69,13 +71,13 @@ impl Gamma { /// /// result = Gamma::new(1.0, 0.0); /// assert!(result.is_err()); - /// + /// /// result = Gamma::new(0.0, 1.0); /// assert!(result.is_err()); - /// + /// /// result = Gamma::new(f64::INFINITY, 1.0); /// assert!(result.is_err()); - /// + /// /// result = Gamma::new(1.0, f64::INFINITY); /// assert!(result.is_err()); /// ``` @@ -150,12 +152,6 @@ impl ContinuousCDF for Gamma { fn cdf(&self, x: f64) -> f64 { if x <= 0.0 { 0.0 - } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() { - 1.0 - } else if self.rate.is_infinite() { - 0.0 - } else if x.is_infinite() { - 1.0 } else { gamma::gamma_lr(self.shape, x * self.rate) } @@ -175,12 +171,6 @@ impl ContinuousCDF for Gamma { fn sf(&self, x: f64) -> f64 { if x <= 0.0 { 1.0 - } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() { - 0.0 - } else if self.rate.is_infinite() { - 1.0 - } else if x.is_infinite() { - 0.0 } else { gamma::gamma_ur(self.shape, x * self.rate) } @@ -470,11 +460,7 @@ mod tests { (-1.0, 1.0, GammaError::ShapeInvalid), (-1.0, -1.0, GammaError::ShapeInvalid), (-1.0, f64::NAN, GammaError::ShapeInvalid), - ( - f64::INFINITY, - f64::INFINITY, - GammaError::ShapeAndRateInfinite, - ), + (f64::INFINITY, f64::INFINITY, GammaError::ShapeInvalid), ]; for (s, r, err) in invalid { test_create_err(s, r, err);