Skip to content

Commit 7cbbd56

Browse files
committed
refactor to always use the indexer equals function
1 parent 0f336d3 commit 7cbbd56

File tree

1 file changed

+35
-36
lines changed

1 file changed

+35
-36
lines changed

xarray_array_testing/reduction.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import hypothesis.strategies as st
55
import numpy as np
66
import pytest
7-
import xarray as xr
87
import xarray.testing.strategies as xrst
98
from hypothesis import given, note
109

@@ -73,52 +72,52 @@ def test_variable_order_reduce_index(self, op, data):
7372
with self.expected_errors(op, variable=variable):
7473
# compute using xr.Variable.<OP>()
7574
actual = getattr(variable, op)(dim=dim)
75+
if dim is ... or isinstance(dim, list):
76+
actual_ = {dim_: var.data for dim_, var in actual.items()}
77+
else:
78+
actual_ = actual.data
7679

7780
if dim is not ... and not isinstance(dim, list):
7881
# compute using xp.<OP>(array)
7982
note(dim)
8083
axis = variable.get_axis_num(dim)
81-
expected = getattr(self.xp, op)(variable.data, axis=axis)
82-
self.assert_equal(actual.data, expected)
84+
indices = getattr(self.xp, op)(variable.data, axis=axis)
85+
86+
expected = self.xp.asarray(indices)
8387
elif dim is ... or len(dim) == len(variable.dims):
8488
# compute using xp.<OP>(array)
8589
index = getattr(self.xp, op)(variable.data)
8690

8791
unraveled = np.unravel_index(index, variable.shape)
88-
expected = dict(zip(variable.dims, unraveled))
89-
90-
# all elements are 0D
91-
assert actual == expected
92+
expected = {
93+
k: self.xp.asarray(v) for k, v in zip(variable.dims, unraveled)
94+
}
95+
elif len(dim) == 1:
96+
dim_ = dim[0]
97+
axis = variable.get_axis_num(dim_)
98+
index = getattr(self.xp, op)(variable.data, axis=axis)
99+
100+
expected = {dim_: self.xp.asarray(index)}
92101
else:
93-
if len(dim) == 1:
94-
dim_ = dim[0]
95-
axis = variable.get_axis_num(dim_)
96-
index = getattr(self.xp, op)(variable.data, axis=axis)
97-
98-
result_dims = [d for d in variable.dims if d != dim_]
99-
expected = {dim_: xr.Variable(result_dims, index)}
100-
else:
101-
# move the relevant dims together and flatten
102-
dim_name = object()
103-
stacked = variable.stack({dim_name: dim})
104-
105-
result_dims = stacked.dims[:-1]
106-
reduce_shape = tuple(variable.sizes[d] for d in dim)
107-
index = getattr(self.xp, op)(stacked.data, axis=-1)
108-
109-
unravelled = np.unravel_index(index, reduce_shape)
110-
111-
expected = {
112-
d: xr.Variable(result_dims, idx)
113-
for d, idx in zip(dim, unravelled, strict=True)
114-
}
115-
116-
note(f"original: {variable}")
117-
note(f"actual: {actual}")
118-
note(f"expected: {expected}")
119-
120-
assert actual.keys() == expected.keys(), "Reduction dims are not equal"
121-
assert all(actual[k].equals(expected[k]) for k in actual)
102+
# move the relevant dims together and flatten
103+
dim_name = object()
104+
stacked = variable.stack({dim_name: dim})
105+
106+
reduce_shape = tuple(variable.sizes[d] for d in dim)
107+
index = getattr(self.xp, op)(stacked.data, axis=-1)
108+
109+
unravelled = np.unravel_index(index, reduce_shape)
110+
111+
expected = {
112+
d: self.xp.asarray(idx)
113+
for d, idx in zip(dim, unravelled, strict=True)
114+
}
115+
116+
note(f"original: {variable}")
117+
note(f"actual: {repr(actual_)}")
118+
note(f"expected: {repr(expected)}")
119+
120+
self.assert_dimension_indexers_equal(actual_, expected)
122121

123122
@pytest.mark.parametrize(
124123
"op",

0 commit comments

Comments
 (0)