|
| 1 | +import itertools |
1 | 2 | from contextlib import nullcontext |
2 | 3 |
|
3 | 4 | import hypothesis.strategies as st |
4 | 5 | import numpy as np |
5 | 6 | import pytest |
6 | 7 | import xarray.testing.strategies as xrst |
7 | | -from hypothesis import given |
| 8 | +from hypothesis import given, note |
8 | 9 |
|
9 | 10 | from xarray_array_testing.base import DuckArrayTestMixin |
10 | 11 |
|
@@ -60,17 +61,63 @@ def test_variable_order_reduce(self, op, data): |
60 | 61 | @given(st.data()) |
61 | 62 | def test_variable_order_reduce_index(self, op, data): |
62 | 63 | variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn)) |
| 64 | + possible_dims = [..., list(variable.dims), *variable.dims] + list( |
| 65 | + itertools.chain.from_iterable( |
| 66 | + map(list, itertools.combinations(variable.dims, length)) |
| 67 | + for length in range(1, len(variable.dims)) |
| 68 | + ) |
| 69 | + ) |
| 70 | + dim = data.draw(st.sampled_from(possible_dims)) |
63 | 71 |
|
64 | 72 | with self.expected_errors(op, variable=variable): |
65 | 73 | # compute using xr.Variable.<OP>() |
66 | | - actual = {k: v.item() for k, v in getattr(variable, op)(dim=...).items()} |
67 | | - |
68 | | - # compute using xp.<OP>(array) |
69 | | - index = getattr(self.xp, op)(variable.data) |
70 | | - unraveled = np.unravel_index(index, variable.shape) |
71 | | - expected = dict(zip(variable.dims, unraveled)) |
72 | | - |
73 | | - self.assert_equal(actual, expected) |
| 74 | + actual = getattr(variable, op)(dim=dim) |
| 75 | + if dim is ... or isinstance(dim, list): |
| 76 | + actual_ = {dim_: var.data for dim_, var in actual.items()} |
| 77 | + else: |
| 78 | + actual_ = actual.data |
| 79 | + |
| 80 | + note(f"dim: {dim}") |
| 81 | + if dim is not ... and not isinstance(dim, list): |
| 82 | + # compute using xp.<OP>(array) |
| 83 | + axis = variable.get_axis_num(dim) |
| 84 | + indices = getattr(self.xp, op)(variable.data, axis=axis) |
| 85 | + |
| 86 | + expected = self.xp.asarray(indices) |
| 87 | + elif dim is ... or len(dim) == len(variable.dims): |
| 88 | + # compute using xp.<OP>(array) |
| 89 | + index = getattr(self.xp, op)(variable.data) |
| 90 | + |
| 91 | + unraveled = np.unravel_index(index, variable.shape) |
| 92 | + expected = { |
| 93 | + k: self.xp.asarray(v) for k, v in zip(variable.dims, unraveled) |
| 94 | + } |
| 95 | + elif len(dim) == 1: |
| 96 | + dim_ = dim[0] |
| 97 | + axis = variable.get_axis_num(dim_) |
| 98 | + index = getattr(self.xp, op)(variable.data, axis=axis) |
| 99 | + |
| 100 | + expected = {dim_: self.xp.asarray(index)} |
| 101 | + else: |
| 102 | + # move the relevant dims together and flatten |
| 103 | + dim_name = object() |
| 104 | + stacked = variable.stack({dim_name: dim}) |
| 105 | + |
| 106 | + reduce_shape = tuple(variable.sizes[d] for d in dim) |
| 107 | + index = getattr(self.xp, op)(stacked.data, axis=-1) |
| 108 | + |
| 109 | + unravelled = np.unravel_index(index, reduce_shape) |
| 110 | + |
| 111 | + expected = { |
| 112 | + d: self.xp.asarray(idx) |
| 113 | + for d, idx in zip(dim, unravelled, strict=True) |
| 114 | + } |
| 115 | + |
| 116 | + note(f"original: {variable}") |
| 117 | + note(f"actual: {repr(actual_)}") |
| 118 | + note(f"expected: {repr(expected)}") |
| 119 | + |
| 120 | + self.assert_dimension_indexers_equal(actual_, expected) |
74 | 121 |
|
75 | 122 | @pytest.mark.parametrize( |
76 | 123 | "op", |
|
0 commit comments