diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 94c35623..3122e5f0 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -33,18 +33,20 @@ pub enum GammaError { /// The rate is NaN, zero or less than zero. RateInvalid, - - /// The shape and rate are both infinite. - ShapeAndRateInfinite, } impl core::fmt::Display for GammaError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match self { - GammaError::ShapeInvalid => write!(f, "Shape is NaN zero, or less than zero."), - GammaError::RateInvalid => write!(f, "Rate is NaN zero, or less than zero."), - GammaError::ShapeAndRateInfinite => write!(f, "Shape and rate are infinite"), + GammaError::ShapeInvalid => write!( + f, + "Shape must be finite (not NaN, infinite, zero, or negative)." + ), + GammaError::RateInvalid => write!( + f, + "Rate must be finite (not NaN, infinite, zero, or negative)." + ), } } } @@ -69,22 +71,27 @@ impl Gamma { /// let mut result = Gamma::new(3.0, 1.0); /// assert!(result.is_ok()); /// - /// result = Gamma::new(0.0, 0.0); + /// result = Gamma::new(1.0, 0.0); + /// assert!(result.is_err()); + /// + /// result = Gamma::new(0.0, 1.0); + /// assert!(result.is_err()); + /// + /// result = Gamma::new(f64::INFINITY, 1.0); + /// assert!(result.is_err()); + /// + /// result = Gamma::new(1.0, f64::INFINITY); /// assert!(result.is_err()); /// ``` pub fn new(shape: f64, rate: f64) -> Result { - if shape.is_nan() || shape <= 0.0 { + if shape.is_nan() || shape <= 0.0 || shape.is_infinite() { return Err(GammaError::ShapeInvalid); } - if rate.is_nan() || rate <= 0.0 { + if rate.is_nan() || rate <= 0.0 || rate.is_infinite() { return Err(GammaError::RateInvalid); } - if shape.is_infinite() && rate.is_infinite() { - return Err(GammaError::ShapeAndRateInfinite); - } - Ok(Gamma { shape, rate }) } @@ -147,12 +154,6 @@ impl ContinuousCDF for Gamma { fn cdf(&self, x: f64) -> f64 { if x <= 0.0 { 0.0 - } else if prec::ulps_eq!(x, self.shape) && self.rate.is_infinite() { - 1.0 - } else if self.rate.is_infinite() { - 0.0 - } else if x.is_infinite() { - 1.0 } else { gamma::gamma_lr(self.shape, x * self.rate) } @@ -172,12 +173,6 @@ impl ContinuousCDF for Gamma { fn sf(&self, x: f64) -> f64 { if x <= 0.0 { 1.0 - } else if prec::ulps_eq!(x, self.shape) && self.rate.is_infinite() { - 0.0 - } else if self.rate.is_infinite() { - 1.0 - } else if x.is_infinite() { - 0.0 } else { gamma::gamma_ur(self.shape, x * self.rate) } @@ -468,11 +463,7 @@ mod tests { (-1.0, 1.0, GammaError::ShapeInvalid), (-1.0, -1.0, GammaError::ShapeInvalid), (-1.0, f64::NAN, GammaError::ShapeInvalid), - ( - f64::INFINITY, - f64::INFINITY, - GammaError::ShapeAndRateInfinite, - ), + (f64::INFINITY, f64::INFINITY, GammaError::ShapeInvalid), ]; for (s, r, err) in invalid { test_create_err(s, r, err);