Skip to content

Commit f3a4d98

Browse files
committed
refactor the comparison for arg reductions
1 parent 5c9d25b commit f3a4d98

File tree

1 file changed

+57
-9
lines changed

1 file changed

+57
-9
lines changed

xarray_array_testing/reduction.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import itertools
12
from contextlib import nullcontext
23

34
import hypothesis.strategies as st
45
import numpy as np
56
import pytest
7+
import xarray as xr
68
import xarray.testing.strategies as xrst
7-
from hypothesis import given
9+
from hypothesis import given, note
810

911
from xarray_array_testing.base import DuckArrayTestMixin
1012

@@ -60,17 +62,63 @@ def test_variable_order_reduce(self, op, data):
6062
@given(st.data())
6163
def test_variable_order_reduce_index(self, op, data):
6264
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
65+
possible_dims = [..., list(variable.dims), *variable.dims] + list(
66+
itertools.chain.from_iterable(
67+
map(list, itertools.combinations(variable.dims, length))
68+
for length in range(1, len(variable.dims))
69+
)
70+
)
71+
dim = data.draw(st.sampled_from(possible_dims))
6372

6473
with self.expected_errors(op, variable=variable):
6574
# compute using xr.Variable.<OP>()
66-
actual = {k: v.item() for k, v in getattr(variable, op)(dim=...).items()}
67-
68-
# compute using xp.<OP>(array)
69-
index = getattr(self.xp, op)(variable.data)
70-
unraveled = np.unravel_index(index, variable.shape)
71-
expected = dict(zip(variable.dims, unraveled))
72-
73-
self.assert_equal(actual, expected)
75+
actual = getattr(variable, op)(dim=dim)
76+
77+
if dim is not ... and not isinstance(dim, list):
78+
# compute using xp.<OP>(array)
79+
note(dim)
80+
axis = variable.get_axis_num(dim)
81+
expected = getattr(self.xp, op)(variable.data, axis=axis)
82+
self.assert_equal(actual.data, expected)
83+
elif dim is ... or len(dim) == len(variable.dims):
84+
# compute using xp.<OP>(array)
85+
index = getattr(self.xp, op)(variable.data)
86+
87+
unraveled = np.unravel_index(index, variable.shape)
88+
expected = dict(zip(variable.dims, unraveled))
89+
90+
# all elements are 0D
91+
assert actual == expected
92+
else:
93+
if len(dim) == 1:
94+
dim_ = dim[0]
95+
axis = variable.get_axis_num(dim_)
96+
index = getattr(self.xp, op)(variable.data, axis=axis)
97+
98+
result_dims = [d for d in variable.dims if d != dim_]
99+
expected = {dim_: xr.Variable(result_dims, index)}
100+
else:
101+
# move the relevant dims together and flatten
102+
dim_name = object()
103+
stacked = variable.stack({dim_name: dim})
104+
105+
result_dims = stacked.dims[:-1]
106+
reduce_shape = tuple(variable.sizes[d] for d in dim)
107+
index = getattr(self.xp, op)(stacked.data, axis=-1)
108+
109+
unravelled = np.unravel_index(index, reduce_shape)
110+
111+
expected = {
112+
d: xr.Variable(result_dims, idx)
113+
for d, idx in zip(dim, unravelled, strict=True)
114+
}
115+
116+
note(f"original: {variable}")
117+
note(f"actual: {actual}")
118+
note(f"expected: {expected}")
119+
120+
assert actual.keys() == expected.keys(), "Reduction dims are not equal"
121+
assert all(actual[k].equals(expected[k]) for k in actual)
74122

75123
@pytest.mark.parametrize(
76124
"op",

0 commit comments

Comments
 (0)