Skip to content

Commit 0a3b95f

Browse files
Remove Option<Axis> handling since axis is always required
1 parent a60bec2 commit 0a3b95f

File tree

2 files changed

+31
-72
lines changed

2 files changed

+31
-72
lines changed

src/numeric/impl_numeric.rs

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,6 @@ where
103103

104104
/// Return the cumulative product of elements along a given axis.
105105
///
106-
/// If `axis` is None, the array is flattened before taking the cumulative product.
107-
///
108106
/// ```
109107
/// use ndarray::{arr2, Axis};
110108
///
@@ -113,62 +111,43 @@ where
113111
///
114112
/// // Cumulative product along rows (axis 0)
115113
/// assert_eq!(
116-
/// a.cumprod(Some(Axis(0))),
114+
/// a.cumprod(Axis(0)),
117115
/// arr2(&[[1., 2., 3.],
118116
/// [4., 10., 18.]])
119117
/// );
120118
///
121119
/// // Cumulative product along columns (axis 1)
122120
/// assert_eq!(
123-
/// a.cumprod(Some(Axis(1))),
121+
/// a.cumprod(Axis(1)),
124122
/// arr2(&[[1., 2., 6.],
125123
/// [4., 20., 120.]])
126124
/// );
127125
/// ```
128126
///
129127
/// **Panics** if `axis` is out of bounds.
130128
#[track_caller]
131-
pub fn cumprod(&self, axis: Option<Axis>) -> Array<A, D>
129+
pub fn cumprod(&self, axis: Axis) -> Array<A, D>
132130
where
133131
A: Clone + One + Mul<Output = A> + ScalarOperand,
134132
D: Dimension + RemoveAxis,
135133
{
136-
// First check dimensionality
137-
if self.ndim() > 1 && axis.is_none() {
138-
panic!("axis parameter is required for arrays with more than one dimension");
134+
// Check if axis is valid before any array operations
135+
if axis.0 >= self.ndim() {
136+
panic!("axis is out of bounds for array of dimension");
139137
}
140138

141139
let mut res = Array::ones(self.raw_dim());
140+
let mut acc = Array::ones(self.raw_dim().remove_axis(axis));
142141

143-
match axis {
144-
None => {
145-
// For 1D arrays, use simple iteration
146-
let mut acc = A::one();
147-
Zip::from(&mut res).and(self).for_each(|r, x| {
148-
acc = acc.clone() * x.clone();
149-
*r = acc.clone();
150-
});
151-
res
152-
}
153-
Some(axis) => {
154-
// Check if axis is valid before any array operations
155-
if axis.0 >= self.ndim() {
156-
panic!("axis is out of bounds for array of dimension");
157-
}
158-
159-
// For nD arrays, use fold_axis approach
160-
// Create accumulator array with one less dimension
161-
let mut acc = Array::ones(self.raw_dim().remove_axis(axis));
162-
163-
for i in 0..self.len_of(axis) {
164-
// Get view of current slice along axis, and update accumulator element-wise multiplication
165-
let view = self.index_axis(axis, i);
166-
acc = acc * &view;
167-
res.index_axis_mut(axis, i).assign(&acc);
168-
}
169-
res
170-
}
142+
// Use fold_axis approach
143+
for i in 0..self.len_of(axis) {
144+
// Get view of current slice along axis, and update accumulator element-wise multiplication
145+
let view = self.index_axis(axis, i);
146+
acc = acc * &view;
147+
res.index_axis_mut(axis, i).assign(&acc);
171148
}
149+
150+
res
172151
}
173152

174153
/// Return variance of elements in the array.

tests/numeric.rs

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -79,23 +79,19 @@ fn sum_mean_prod_empty()
7979
fn test_cumprod_1d()
8080
{
8181
let a = array![1, 2, 3, 4];
82-
// For 1D arrays, both None and Some(Axis(0)) should work
83-
let result_none = a.cumprod(None);
84-
let result_axis = a.cumprod(Some(Axis(0)));
85-
assert_eq!(result_none, array![1, 2, 6, 24]);
86-
assert_eq!(result_axis, array![1, 2, 6, 24]);
82+
let result = a.cumprod(Axis(0));
83+
assert_eq!(result, array![1, 2, 6, 24]);
8784
}
8885

8986
#[test]
9087
fn test_cumprod_2d()
9188
{
9289
let a = array![[1, 2], [3, 4]];
9390

94-
// For 2D arrays, we must specify an axis
95-
let result_axis0 = a.cumprod(Some(Axis(0)));
91+
let result_axis0 = a.cumprod(Axis(0));
9692
assert_eq!(result_axis0, array![[1, 2], [3, 8]]);
9793

98-
let result_axis1 = a.cumprod(Some(Axis(1)));
94+
let result_axis1 = a.cumprod(Axis(1));
9995
assert_eq!(result_axis1, array![[1, 2], [3, 12]]);
10096
}
10197

@@ -104,30 +100,24 @@ fn test_cumprod_3d()
104100
{
105101
let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]];
106102

107-
// For 3D arrays, we must specify an axis
108-
let result_axis0 = a.cumprod(Some(Axis(0)));
103+
let result_axis0 = a.cumprod(Axis(0));
109104
assert_eq!(result_axis0, array![[[1, 2], [3, 4]], [[5, 12], [21, 32]]]);
110105

111-
let result_axis1 = a.cumprod(Some(Axis(1)));
106+
let result_axis1 = a.cumprod(Axis(1));
112107
assert_eq!(result_axis1, array![[[1, 2], [3, 8]], [[5, 6], [35, 48]]]);
113108

114-
let result_axis2 = a.cumprod(Some(Axis(2)));
109+
let result_axis2 = a.cumprod(Axis(2));
115110
assert_eq!(result_axis2, array![[[1, 2], [3, 12]], [[5, 30], [7, 56]]]);
116111
}
117112

118113
#[test]
119114
fn test_cumprod_empty()
120115
{
121-
// For 1D empty array
122-
let a: Array1<i32> = array![];
123-
let result = a.cumprod(None);
124-
assert_eq!(result, array![]);
125-
126-
// For 2D empty array, must specify axis
116+
// For 2D empty array
127117
let b: Array2<i32> = Array2::zeros((0, 0));
128-
let result_axis0 = b.cumprod(Some(Axis(0)));
118+
let result_axis0 = b.cumprod(Axis(0));
129119
assert_eq!(result_axis0, Array2::zeros((0, 0)));
130-
let result_axis1 = b.cumprod(Some(Axis(1)));
120+
let result_axis1 = b.cumprod(Axis(1));
131121
assert_eq!(result_axis1, Array2::zeros((0, 0)));
132122
}
133123

@@ -136,33 +126,23 @@ fn test_cumprod_1_element()
136126
{
137127
// For 1D array with one element
138128
let a = array![5];
139-
let result_none = a.cumprod(None);
140-
let result_axis = a.cumprod(Some(Axis(0)));
141-
assert_eq!(result_none, array![5]);
142-
assert_eq!(result_axis, array![5]);
129+
let result = a.cumprod(Axis(0));
130+
assert_eq!(result, array![5]);
143131

144-
// For 2D array with one element, must specify axis
132+
// For 2D array with one element
145133
let b = array![[5]];
146-
let result_axis0 = b.cumprod(Some(Axis(0)));
147-
let result_axis1 = b.cumprod(Some(Axis(1)));
134+
let result_axis0 = b.cumprod(Axis(0));
135+
let result_axis1 = b.cumprod(Axis(1));
148136
assert_eq!(result_axis0, array![[5]]);
149137
assert_eq!(result_axis1, array![[5]]);
150138
}
151139

152-
#[test]
153-
#[should_panic(expected = "axis parameter is required for arrays with more than one dimension")]
154-
fn test_cumprod_nd_none_axis()
155-
{
156-
let a = array![[1, 2], [3, 4]];
157-
let _result = a.cumprod(None);
158-
}
159-
160140
#[test]
161141
#[should_panic(expected = "axis is out of bounds for array of dimension")]
162142
fn test_cumprod_axis_out_of_bounds()
163143
{
164144
let a = array![[1, 2], [3, 4]];
165-
let _result = a.cumprod(Some(Axis(2)));
145+
let _result = a.cumprod(Axis(2));
166146
}
167147

168148
#[test]

0 commit comments

Comments
 (0)