Skip to content

Commit 25a84d0

Browse files
authored
Make the entire crate compatible with the array reference type (#109)
This PR is large, but is repetitive and contains all of the code necessary to make ndarray-stats compatible with ndarray 0.17. It does the following: - Bump ndarray to 0.17.1 - Bump ndarray-rand to 0.16.0 - Implement the extension traits on ArrayRef instead of ArrayBase - Remove the "storage" generics on the extension traits - Update documentation to replace references to ArrayBase
1 parent 7997e20 commit 25a84d0

File tree

18 files changed

+202
-241
lines changed

18 files changed

+202
-241
lines changed

Cargo.lock

Lines changed: 70 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ repository = "https://github.com/rust-ndarray/ndarray-stats"
1414
documentation = "https://docs.rs/ndarray-stats/"
1515
readme = "README.md"
1616

17-
description = "Statistical routines for ArrayBase, the n-dimensional array data structure provided by ndarray."
17+
description = "Statistical routines for the n-dimensional array data structures provided by ndarray."
1818

1919
keywords = ["array", "multidimensional", "statistics", "matrix", "ndarray"]
2020
categories = ["data-structures", "science"]
2121

2222
[dependencies]
23-
ndarray = "0.16.0"
23+
ndarray = "0.17.1"
2424
noisy_float = "0.2.0"
2525
num-integer = "0.1"
2626
num-traits = "0.2"
@@ -29,10 +29,10 @@ itertools = { version = "0.13", default-features = false }
2929
indexmap = "2.4"
3030

3131
[dev-dependencies]
32-
ndarray = { version = "0.16.1", features = ["approx"] }
32+
ndarray = { version = "0.17.1", features = ["approx"] }
3333
criterion = "0.5.1"
3434
quickcheck = { version = "0.9.2", default-features = false }
35-
ndarray-rand = "0.15.0"
35+
ndarray-rand = "0.16.0"
3636
approx = "0.5"
3737
quickcheck_macros = "1.0.0"
3838
num-bigint = "0.4.0"

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
[![Crate](https://img.shields.io/crates/v/ndarray-stats.svg)](https://crates.io/crates/ndarray-stats)
66
[![Documentation](https://docs.rs/ndarray-stats/badge.svg)](https://docs.rs/ndarray-stats)
77

8-
This crate provides statistical methods for [`ndarray`]'s `ArrayBase` type.
8+
This crate provides statistical methods for [`ndarray`]'s `ArrayRef` type.
99

1010
Currently available routines include:
1111
- order statistics (minimum, maximum, median, quantiles, etc.);
@@ -32,6 +32,11 @@ ndarray-stats = "0.6.0"
3232

3333
## Releases
3434

35+
* **0.7.0**
36+
37+
* Breaking changes
38+
* Updated to `ndarray:v0.17.1`
39+
3540
* **0.6.0**
3641

3742
* Breaking changes

benches/deviation.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ fn sq_l2_dist(c: &mut Criterion) {
1212
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
1313
for len in &lens {
1414
group.bench_with_input(format!("{}", len), len, |b, &len| {
15-
let data = Array::random(len, Uniform::new(0.0, 1.0));
16-
let data2 = Array::random(len, Uniform::new(0.0, 1.0));
15+
let data = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
16+
let data2 = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
1717

1818
b.iter(|| black_box(data.sq_l2_dist(&data2).unwrap()))
1919
});

benches/summary_statistics.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ fn weighted_std(c: &mut Criterion) {
1212
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
1313
for len in &lens {
1414
group.bench_with_input(format!("{}", len), len, |b, &len| {
15-
let data = Array::random(len, Uniform::new(0.0, 1.0));
16-
let mut weights = Array::random(len, Uniform::new(0.0, 1.0));
15+
let data = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
16+
let mut weights = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
1717
weights /= weights.sum();
1818
b.iter_batched(
1919
|| data.clone(),

src/correlation.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
use crate::errors::EmptyInput;
22
use ndarray::prelude::*;
3-
use ndarray::Data;
43
use num_traits::{Float, FromPrimitive};
54

6-
/// Extension trait for `ArrayBase` providing functions
5+
/// Extension trait for `ndarray` providing functions
76
/// to compute different correlation measures.
8-
pub trait CorrelationExt<A, S>
9-
where
10-
S: Data<Elem = A>,
11-
{
7+
pub trait CorrelationExt<A> {
128
/// Return the covariance matrix `C` for a 2-dimensional
139
/// array of observations `M`.
1410
///
@@ -125,10 +121,7 @@ where
125121
private_decl! {}
126122
}
127123

128-
impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
129-
where
130-
S: Data<Elem = A>,
131-
{
124+
impl<A: 'static> CorrelationExt<A> for ArrayRef2<A> {
132125
fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
133126
where
134127
A: Float + FromPrimitive,
@@ -147,7 +140,7 @@ where
147140
let mean = self.mean_axis(observation_axis);
148141
match mean {
149142
Some(mean) => {
150-
let denoised = self - &mean.insert_axis(observation_axis);
143+
let denoised = self - mean.insert_axis(observation_axis);
151144
let covariance = denoised.dot(&denoised.t());
152145
Ok(covariance.mapv_into(|x| x / dof))
153146
}
@@ -208,7 +201,7 @@ mod cov_tests {
208201
let n_observations = 4;
209202
let a = Array::random(
210203
(n_random_variables, n_observations),
211-
Uniform::new(-bound.abs(), bound.abs()),
204+
Uniform::new(-bound.abs(), bound.abs()).unwrap(),
212205
);
213206
let covariance = a.cov(1.).unwrap();
214207
abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8)
@@ -219,7 +212,10 @@ mod cov_tests {
219212
fn test_invalid_ddof() {
220213
let n_random_variables = 3;
221214
let n_observations = 4;
222-
let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.));
215+
let a = Array::random(
216+
(n_random_variables, n_observations),
217+
Uniform::new(0., 10.).unwrap(),
218+
);
223219
let invalid_ddof = (n_observations as f64) + rand::random::<f64>().abs();
224220
let _ = a.cov(invalid_ddof);
225221
}
@@ -299,7 +295,7 @@ mod pearson_correlation_tests {
299295
let n_observations = 4;
300296
let a = Array::random(
301297
(n_random_variables, n_observations),
302-
Uniform::new(-bound.abs(), bound.abs()),
298+
Uniform::new(-bound.abs(), bound.abs()).unwrap(),
303299
);
304300
let pearson_correlation = a.pearson_correlation().unwrap();
305301
abs_diff_eq!(

0 commit comments

Comments
 (0)