Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ use std::ops::{Add, Div, Mul, Sub};

use crate::imp_prelude::*;
use crate::numeric_util;
use crate::ScalarOperand;
use crate::Slice;
use crate::Zip;

/// # Numerical Methods for Arrays
impl<A, S, D> ArrayBase<S, D>
Expand Down Expand Up @@ -99,6 +101,76 @@ where
sum
}

/// Return the cumulative product of elements along a given axis.
///
/// If `axis` is None, the array is flattened before taking the cumulative product.
///
/// ```
/// use ndarray::{arr2, Axis};
///
/// let a = arr2(&[[1., 2., 3.],
/// [4., 5., 6.]]);
///
/// // Cumulative product along rows (axis 0)
/// assert_eq!(
/// a.cumprod(Some(Axis(0))),
/// arr2(&[[1., 2., 3.],
/// [4., 10., 18.]])
/// );
///
/// // Cumulative product along columns (axis 1)
/// assert_eq!(
/// a.cumprod(Some(Axis(1))),
/// arr2(&[[1., 2., 6.],
/// [4., 20., 120.]])
/// );
/// ```
///
/// **Panics** if `axis` is out of bounds.
#[track_caller]
pub fn cumprod(&self, axis: Option<Axis>) -> Array<A, D>
where
A: Clone + One + Mul<Output = A> + ScalarOperand,
D: Dimension + RemoveAxis,
{
// First check dimensionality
if self.ndim() > 1 && axis.is_none() {
panic!("axis parameter is required for arrays with more than one dimension");
}

let mut res = Array::ones(self.raw_dim());

match axis {
None => {
// For 1D arrays, use simple iteration
let mut acc = A::one();
Zip::from(&mut res).and(self).for_each(|r, x| {
acc = acc.clone() * x.clone();
*r = acc.clone();
});
res
}
Some(axis) => {
// Check if axis is valid before any array operations
if axis.0 >= self.ndim() {
panic!("axis is out of bounds for array of dimension");
}

// For nD arrays, use fold_axis approach
// Create accumulator array with one less dimension
let mut acc = Array::ones(self.raw_dim().remove_axis(axis));

for i in 0..self.len_of(axis) {
// Get view of current slice along axis, and update accumulator element-wise multiplication
let view = self.index_axis(axis, i);
acc = acc * &view;
res.index_axis_mut(axis, i).assign(&acc);
}
res
}
}
}

/// Return variance of elements in the array.
///
/// The variance is computed using the [Welford one-pass
Expand Down
90 changes: 90 additions & 0 deletions tests/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,96 @@ fn sum_mean_prod_empty()
assert_eq!(a, None);
}

#[test]
fn test_cumprod_1d()
{
let a = array![1, 2, 3, 4];
// For 1D arrays, both None and Some(Axis(0)) should work
let result_none = a.cumprod(None);
let result_axis = a.cumprod(Some(Axis(0)));
assert_eq!(result_none, array![1, 2, 6, 24]);
assert_eq!(result_axis, array![1, 2, 6, 24]);
}

#[test]
fn test_cumprod_2d()
{
let a = array![[1, 2], [3, 4]];

// For 2D arrays, we must specify an axis
let result_axis0 = a.cumprod(Some(Axis(0)));
assert_eq!(result_axis0, array![[1, 2], [3, 8]]);

let result_axis1 = a.cumprod(Some(Axis(1)));
assert_eq!(result_axis1, array![[1, 2], [3, 12]]);
}

#[test]
fn test_cumprod_3d()
{
let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]];

// For 3D arrays, we must specify an axis
let result_axis0 = a.cumprod(Some(Axis(0)));
assert_eq!(result_axis0, array![[[1, 2], [3, 4]], [[5, 12], [21, 32]]]);

let result_axis1 = a.cumprod(Some(Axis(1)));
assert_eq!(result_axis1, array![[[1, 2], [3, 8]], [[5, 6], [35, 48]]]);

let result_axis2 = a.cumprod(Some(Axis(2)));
assert_eq!(result_axis2, array![[[1, 2], [3, 12]], [[5, 30], [7, 56]]]);
}

#[test]
fn test_cumprod_empty()
{
// For 1D empty array
let a: Array1<i32> = array![];
let result = a.cumprod(None);
assert_eq!(result, array![]);

// For 2D empty array, must specify axis
let b: Array2<i32> = Array2::zeros((0, 0));
let result_axis0 = b.cumprod(Some(Axis(0)));
assert_eq!(result_axis0, Array2::zeros((0, 0)));
let result_axis1 = b.cumprod(Some(Axis(1)));
assert_eq!(result_axis1, Array2::zeros((0, 0)));
}

#[test]
fn test_cumprod_1_element()
{
// For 1D array with one element
let a = array![5];
let result_none = a.cumprod(None);
let result_axis = a.cumprod(Some(Axis(0)));
assert_eq!(result_none, array![5]);
assert_eq!(result_axis, array![5]);

// For 2D array with one element, must specify axis
let b = array![[5]];
let result_axis0 = b.cumprod(Some(Axis(0)));
let result_axis1 = b.cumprod(Some(Axis(1)));
assert_eq!(result_axis0, array![[5]]);
assert_eq!(result_axis1, array![[5]]);
}

#[test]
#[should_panic(expected = "axis parameter is required for arrays with more than one dimension")]
fn test_cumprod_nd_none_axis()
{
let a = array![[1, 2], [3, 4]];
let _result = a.cumprod(None);
}

#[test]
#[should_panic(expected = "axis is out of bounds for array of dimension")]
fn test_cumprod_axis_out_of_bounds()
{
let a = array![[1, 2], [3, 4]];
let _result = a.cumprod(Some(Axis(2)));
}

#[test]
#[cfg(feature = "std")]
fn var()
Expand Down
Loading