Skip to content

Commit 0307411

Browse files
implement cumprod, add tests
1 parent 1866e91 commit 0307411

File tree

2 files changed

+159
-1
lines changed

2 files changed

+159
-1
lines changed

src/numeric/impl_numeric.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use std::ops::{Add, Div, Mul, Sub};
1414

1515
use crate::imp_prelude::*;
1616
use crate::numeric_util;
17+
use crate::ScalarOperand;
1718
use crate::Slice;
1819

1920
/// # Numerical Methods for Arrays
@@ -99,6 +100,73 @@ where
99100
sum
100101
}
101102

103+
/// Return the cumulative product of elements along a given axis.
104+
///
105+
/// If `axis` is None, the array is flattened before taking the cumulative product.
106+
///
107+
/// ```
108+
/// use ndarray::{arr2, Axis};
109+
///
110+
/// let a = arr2(&[[1., 2., 3.],
111+
/// [4., 5., 6.]]);
112+
///
113+
/// // Cumulative product along rows (axis 0)
114+
/// assert_eq!(
115+
/// a.cumprod(Some(Axis(0))),
116+
/// arr2(&[[1., 2., 3.],
117+
/// [4., 10., 18.]])
118+
/// );
119+
///
120+
/// // Cumulative product along columns (axis 1)
121+
/// assert_eq!(
122+
/// a.cumprod(Some(Axis(1))),
123+
/// arr2(&[[1., 2., 6.],
124+
/// [4., 20., 120.]])
125+
/// );
126+
/// ```
127+
///
128+
/// **Panics** if `axis` is out of bounds.
129+
#[track_caller]
130+
pub fn cumprod(&self, axis: Option<Axis>) -> Array<A, D>
131+
where
132+
A: Clone + One + Mul<Output = A> + ScalarOperand,
133+
D: Dimension + RemoveAxis,
134+
{
135+
// First check dimensionality
136+
if self.ndim() > 1 && axis.is_none() {
137+
panic!("axis parameter is required for arrays with more than one dimension");
138+
}
139+
140+
match axis {
141+
None => {
142+
// This case now only happens for 1D arrays
143+
let mut res = Array::ones(self.raw_dim());
144+
let mut acc = A::one();
145+
146+
for (r, x) in res.iter_mut().zip(self.iter()) {
147+
acc = acc * x.clone();
148+
*r = acc.clone();
149+
}
150+
151+
res
152+
}
153+
Some(axis) => {
154+
let mut res: Array<A, D> = Array::ones(self.raw_dim());
155+
156+
// Process each lane independently
157+
for (mut out_lane, in_lane) in res.lanes_mut(axis).into_iter().zip(self.lanes(axis)) {
158+
let mut acc = A::one();
159+
for (r, x) in out_lane.iter_mut().zip(in_lane.iter()) {
160+
acc = acc * x.clone();
161+
*r = acc.clone();
162+
}
163+
}
164+
165+
res
166+
}
167+
}
168+
}
169+
102170
/// Return variance of elements in the array.
103171
///
104172
/// The variance is computed using the [Welford one-pass

tests/numeric.rs

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
)]
55

66
use approx::assert_abs_diff_eq;
7-
use ndarray::{arr0, arr1, arr2, array, aview1, Array, Array1, Array2, Array3, Axis};
7+
use ndarray::{arr0, arr1, arr2, arr3, array, aview1, Array, Array1, Array2, Array3, Axis};
88
use std::f64;
99

1010
#[test]
@@ -75,6 +75,96 @@ fn sum_mean_prod_empty()
7575
assert_eq!(a, None);
7676
}
7777

78+
#[test]
79+
fn test_cumprod_1d()
80+
{
81+
let a = array![1, 2, 3, 4];
82+
// For 1D arrays, both None and Some(Axis(0)) should work
83+
let result_none = a.cumprod(None);
84+
let result_axis = a.cumprod(Some(Axis(0)));
85+
assert_eq!(result_none, array![1, 2, 6, 24]);
86+
assert_eq!(result_axis, array![1, 2, 6, 24]);
87+
}
88+
89+
#[test]
90+
fn test_cumprod_2d()
91+
{
92+
let a = array![[1, 2], [3, 4]];
93+
94+
// For 2D arrays, we must specify an axis
95+
let result_axis0 = a.cumprod(Some(Axis(0)));
96+
assert_eq!(result_axis0, array![[1, 2], [3, 8]]);
97+
98+
let result_axis1 = a.cumprod(Some(Axis(1)));
99+
assert_eq!(result_axis1, array![[1, 2], [3, 12]]);
100+
}
101+
102+
#[test]
103+
fn test_cumprod_3d()
104+
{
105+
let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]];
106+
107+
// For 3D arrays, we must specify an axis
108+
let result_axis0 = a.cumprod(Some(Axis(0)));
109+
assert_eq!(result_axis0, array![[[1, 2], [3, 4]], [[5, 12], [21, 32]]]);
110+
111+
let result_axis1 = a.cumprod(Some(Axis(1)));
112+
assert_eq!(result_axis1, array![[[1, 2], [3, 8]], [[5, 6], [35, 48]]]);
113+
114+
let result_axis2 = a.cumprod(Some(Axis(2)));
115+
assert_eq!(result_axis2, array![[[1, 2], [3, 12]], [[5, 30], [7, 56]]]);
116+
}
117+
118+
#[test]
119+
fn test_cumprod_empty()
120+
{
121+
// For 1D empty array
122+
let a: Array1<i32> = array![];
123+
let result = a.cumprod(None);
124+
assert_eq!(result, array![]);
125+
126+
// For 2D empty array, must specify axis
127+
let b: Array2<i32> = Array2::zeros((0, 0));
128+
let result_axis0 = b.cumprod(Some(Axis(0)));
129+
assert_eq!(result_axis0, Array2::zeros((0, 0)));
130+
let result_axis1 = b.cumprod(Some(Axis(1)));
131+
assert_eq!(result_axis1, Array2::zeros((0, 0)));
132+
}
133+
134+
#[test]
135+
fn test_cumprod_1_element()
136+
{
137+
// For 1D array with one element
138+
let a = array![5];
139+
let result_none = a.cumprod(None);
140+
let result_axis = a.cumprod(Some(Axis(0)));
141+
assert_eq!(result_none, array![5]);
142+
assert_eq!(result_axis, array![5]);
143+
144+
// For 2D array with one element, must specify axis
145+
let b = array![[5]];
146+
let result_axis0 = b.cumprod(Some(Axis(0)));
147+
let result_axis1 = b.cumprod(Some(Axis(1)));
148+
assert_eq!(result_axis0, array![[5]]);
149+
assert_eq!(result_axis1, array![[5]]);
150+
}
151+
152+
#[test]
153+
#[should_panic(expected = "axis parameter is required for arrays with more than one dimension")]
154+
fn test_cumprod_nd_none_axis()
155+
{
156+
let a = array![[1, 2], [3, 4]];
157+
let _result = a.cumprod(None);
158+
}
159+
160+
#[test]
161+
#[should_panic(expected = "index out of bounds")]
162+
fn test_cumprod_axis_out_of_bounds()
163+
{
164+
let a = array![[1, 2], [3, 4]];
165+
let _result = a.cumprod(Some(Axis(2)));
166+
}
167+
78168
#[test]
79169
#[cfg(feature = "std")]
80170
fn var()

0 commit comments

Comments
 (0)