Skip to content

Commit 8e07666

Browse files
authored
improve support for argreduction comparisons (#42)
* refactor the comparison for arg reductions * add a static method to compare indexing overators * simplify * parametrize by `xp` This requires converting the static method to an actual method. * refactor to always use the indexer equals function * move the note displaying the value of `dim`
1 parent 84ca335 commit 8e07666

File tree

2 files changed

+68
-9
lines changed

2 files changed

+68
-9
lines changed

xarray_array_testing/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,15 @@ def array_strategy_fn(*, shape, dtype):
2424
@staticmethod
2525
def assert_equal(a, b):
2626
npt.assert_equal(a, b)
27+
28+
def assert_dimension_indexers_equal(self, a, b):
29+
assert type(a) is type(b), f"types don't match: {type(a)} vs {type(b)}"
30+
31+
if isinstance(a, dict):
32+
assert a.keys() == b.keys(), f"Different dimensions: {list(a)} vs {list(b)}"
33+
34+
assert all(
35+
self.xp.all(self.xp.equal(a[k], b[k])) for k in a
36+
), "Differing indexers"
37+
else:
38+
npt.assert_equal(a, b)

xarray_array_testing/reduction.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import itertools
12
from contextlib import nullcontext
23

34
import hypothesis.strategies as st
45
import numpy as np
56
import pytest
67
import xarray.testing.strategies as xrst
7-
from hypothesis import given
8+
from hypothesis import given, note
89

910
from xarray_array_testing.base import DuckArrayTestMixin
1011

@@ -60,17 +61,63 @@ def test_variable_order_reduce(self, op, data):
6061
@given(st.data())
6162
def test_variable_order_reduce_index(self, op, data):
6263
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
64+
possible_dims = [..., list(variable.dims), *variable.dims] + list(
65+
itertools.chain.from_iterable(
66+
map(list, itertools.combinations(variable.dims, length))
67+
for length in range(1, len(variable.dims))
68+
)
69+
)
70+
dim = data.draw(st.sampled_from(possible_dims))
6371

6472
with self.expected_errors(op, variable=variable):
6573
# compute using xr.Variable.<OP>()
66-
actual = {k: v.item() for k, v in getattr(variable, op)(dim=...).items()}
67-
68-
# compute using xp.<OP>(array)
69-
index = getattr(self.xp, op)(variable.data)
70-
unraveled = np.unravel_index(index, variable.shape)
71-
expected = dict(zip(variable.dims, unraveled))
72-
73-
self.assert_equal(actual, expected)
74+
actual = getattr(variable, op)(dim=dim)
75+
if dim is ... or isinstance(dim, list):
76+
actual_ = {dim_: var.data for dim_, var in actual.items()}
77+
else:
78+
actual_ = actual.data
79+
80+
note(f"dim: {dim}")
81+
if dim is not ... and not isinstance(dim, list):
82+
# compute using xp.<OP>(array)
83+
axis = variable.get_axis_num(dim)
84+
indices = getattr(self.xp, op)(variable.data, axis=axis)
85+
86+
expected = self.xp.asarray(indices)
87+
elif dim is ... or len(dim) == len(variable.dims):
88+
# compute using xp.<OP>(array)
89+
index = getattr(self.xp, op)(variable.data)
90+
91+
unraveled = np.unravel_index(index, variable.shape)
92+
expected = {
93+
k: self.xp.asarray(v) for k, v in zip(variable.dims, unraveled)
94+
}
95+
elif len(dim) == 1:
96+
dim_ = dim[0]
97+
axis = variable.get_axis_num(dim_)
98+
index = getattr(self.xp, op)(variable.data, axis=axis)
99+
100+
expected = {dim_: self.xp.asarray(index)}
101+
else:
102+
# move the relevant dims together and flatten
103+
dim_name = object()
104+
stacked = variable.stack({dim_name: dim})
105+
106+
reduce_shape = tuple(variable.sizes[d] for d in dim)
107+
index = getattr(self.xp, op)(stacked.data, axis=-1)
108+
109+
unravelled = np.unravel_index(index, reduce_shape)
110+
111+
expected = {
112+
d: self.xp.asarray(idx)
113+
for d, idx in zip(dim, unravelled, strict=True)
114+
}
115+
116+
note(f"original: {variable}")
117+
note(f"actual: {repr(actual_)}")
118+
note(f"expected: {repr(expected)}")
119+
120+
self.assert_dimension_indexers_equal(actual_, expected)
74121

75122
@pytest.mark.parametrize(
76123
"op",

0 commit comments

Comments
 (0)