Skip to content

Commit cfb3505

Browse files
authored
Merge pull request RustPython#3256 from aDotInTheVoid/stats
Add `_statistics` module containing `_normal_dist_inv_cd`
2 parents 047bab9 + 639a4fe commit cfb3505

File tree

2 files changed

+30
-28
lines changed

2 files changed

+30
-28
lines changed

stdlib/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ mod platform;
2020
mod pyexpat;
2121
mod pystruct;
2222
mod random;
23+
mod statistics;
2324
// TODO: maybe make this an extension module, if we ever get those
2425
// mod re;
2526
#[cfg(not(target_arch = "wasm32"))]
2627
pub mod socket;
27-
mod statistics;
2828
#[cfg(unix)]
2929
mod syslog;
3030
mod unicodedata;
@@ -93,6 +93,7 @@ pub fn get_module_inits() -> impl Iterator<Item = (Cow<'static, str>, StdlibInit
9393
"_struct" => pystruct::make_module,
9494
"unicodedata" => unicodedata::make_module,
9595
"zlib" => zlib::make_module,
96+
"_statistics" => statistics::make_module,
9697
// crate::vm::sysmodule::sysconfigdata_name() => sysconfigdata::make_module,
9798
}
9899
// parser related modules:

stdlib/src/statistics.rs

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,21 @@
1-
pub(crate) use statistics::make_module;
1+
pub(crate) use _statistics::make_module;
22

3-
#[pymodule(name = "_statistics")]
4-
mod statistics {
5-
use rustpython_vm::{PyResult, VirtualMachine};
3+
#[pymodule]
4+
mod _statistics {
5+
use crate::vm::{function::ArgIntoFloat, PyResult, VirtualMachine};
66

7-
/*
8-
* There is no closed-form solution to the inverse CDF for the normal
9-
* distribution, so we use a rational approximation instead:
10-
* Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
11-
* Normal Distribution". Applied Statistics. Blackwell Publishing. 37
12-
* (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.
13-
*/
14-
15-
#[pyfunction(name = "_normal_dist_inv_cdf")]
16-
fn normal_dist_inv_cdf(p: f64, mu: f64, sigma: f64, vm: &VirtualMachine) -> PyResult<f64> {
7+
// See https://github.com/python/cpython/blob/6846d6712a0894f8e1a91716c11dd79f42864216/Modules/_statisticsmodule.c#L28-L120
8+
#[allow(clippy::excessive_precision)]
9+
fn normal_dist_inv_cdf(p: f64, mu: f64, sigma: f64) -> Option<f64> {
1710
if p <= 0.0 || p >= 1.0 || sigma <= 0.0 {
18-
return Err(vm.new_value_error("inv_cdf undefined for these parameters".to_string()));
11+
return None;
1912
}
2013

2114
let q = p - 0.5;
22-
let num: f64;
23-
let den: f64;
24-
#[allow(clippy::excessive_precision)]
2515
if q.abs() <= 0.425 {
2616
let r = 0.180625 - q * q;
2717
// Hash sum-55.8831928806149014439
28-
num = (((((((2.5090809287301226727e+3 * r + 3.3430575583588128105e+4) * r
18+
let num = (((((((2.5090809287301226727e+3 * r + 3.3430575583588128105e+4) * r
2919
+ 6.7265770927008700853e+4)
3020
* r
3121
+ 4.5921953931549871457e+4)
@@ -38,7 +28,7 @@ mod statistics {
3828
* r
3929
+ 3.3871328727963666080e+0)
4030
* q;
41-
den = ((((((5.2264952788528545610e+3 * r + 2.8729085735721942674e+4) * r
31+
let den = ((((((5.2264952788528545610e+3 * r + 2.8729085735721942674e+4) * r
4232
+ 3.9307895800092710610e+4)
4333
* r
4434
+ 2.1213794301586595867e+4)
@@ -51,18 +41,18 @@ mod statistics {
5141
* r
5242
+ 1.0;
5343
if den == 0.0 {
54-
return Err(
55-
vm.new_value_error("inv_cdf undefined for these parameters".to_string())
56-
);
44+
return None;
5745
}
5846
let x = num / den;
59-
return Ok(mu + (x * sigma));
47+
return Some(mu + (x * sigma));
6048
}
6149
let r = if q <= 0.0 { p } else { 1.0 - p };
6250
if r <= 0.0 || r >= 1.0 {
63-
return Err(vm.new_value_error("inv_cdf undefined for these parameters".to_string()));
51+
return None;
6452
}
6553
let r = (-(r.ln())).sqrt();
54+
let num;
55+
let den;
6656
#[allow(clippy::excessive_precision)]
6757
if r <= 5.0 {
6858
let r = r - 1.6;
@@ -120,12 +110,23 @@ mod statistics {
120110
+ 1.0;
121111
}
122112
if den == 0.0 {
123-
return Err(vm.new_value_error("inv_cdf undefined for these parameters".to_string()));
113+
return None;
124114
}
125115
let mut x = num / den;
126116
if q < 0.0 {
127117
x = -x;
128118
}
129-
Ok(mu + (x * sigma))
119+
Some(mu + (x * sigma))
120+
}
121+
122+
#[pyfunction]
123+
fn _normal_dist_inv_cdf(
124+
p: ArgIntoFloat,
125+
mu: ArgIntoFloat,
126+
sigma: ArgIntoFloat,
127+
vm: &VirtualMachine,
128+
) -> PyResult<f64> {
129+
normal_dist_inv_cdf(p.to_f64(), mu.to_f64(), sigma.to_f64())
130+
.ok_or_else(|| vm.new_value_error("inv_cdf undefined for these parameters".to_owned()))
130131
}
131132
}

0 commit comments

Comments
 (0)