|
4 | 4 | import hypothesis.strategies as st |
5 | 5 | import numpy as np |
6 | 6 | import pytest |
7 | | -import xarray as xr |
8 | 7 | import xarray.testing.strategies as xrst |
9 | 8 | from hypothesis import given, note |
10 | 9 |
|
@@ -73,52 +72,52 @@ def test_variable_order_reduce_index(self, op, data): |
73 | 72 | with self.expected_errors(op, variable=variable): |
74 | 73 | # compute using xr.Variable.<OP>() |
75 | 74 | actual = getattr(variable, op)(dim=dim) |
| 75 | + if dim is ... or isinstance(dim, list): |
| 76 | + actual_ = {dim_: var.data for dim_, var in actual.items()} |
| 77 | + else: |
| 78 | + actual_ = actual.data |
76 | 79 |
|
77 | 80 | if dim is not ... and not isinstance(dim, list): |
78 | 81 | # compute using xp.<OP>(array) |
79 | 82 | note(dim) |
80 | 83 | axis = variable.get_axis_num(dim) |
81 | | - expected = getattr(self.xp, op)(variable.data, axis=axis) |
82 | | - self.assert_equal(actual.data, expected) |
| 84 | + indices = getattr(self.xp, op)(variable.data, axis=axis) |
| 85 | + |
| 86 | + expected = self.xp.asarray(indices) |
83 | 87 | elif dim is ... or len(dim) == len(variable.dims): |
84 | 88 | # compute using xp.<OP>(array) |
85 | 89 | index = getattr(self.xp, op)(variable.data) |
86 | 90 |
|
87 | 91 | unraveled = np.unravel_index(index, variable.shape) |
88 | | - expected = dict(zip(variable.dims, unraveled)) |
89 | | - |
90 | | - # all elements are 0D |
91 | | - assert actual == expected |
| 92 | + expected = { |
| 93 | + k: self.xp.asarray(v) for k, v in zip(variable.dims, unraveled) |
| 94 | + } |
| 95 | + elif len(dim) == 1: |
| 96 | + dim_ = dim[0] |
| 97 | + axis = variable.get_axis_num(dim_) |
| 98 | + index = getattr(self.xp, op)(variable.data, axis=axis) |
| 99 | + |
| 100 | + expected = {dim_: self.xp.asarray(index)} |
92 | 101 | else: |
93 | | - if len(dim) == 1: |
94 | | - dim_ = dim[0] |
95 | | - axis = variable.get_axis_num(dim_) |
96 | | - index = getattr(self.xp, op)(variable.data, axis=axis) |
97 | | - |
98 | | - result_dims = [d for d in variable.dims if d != dim_] |
99 | | - expected = {dim_: xr.Variable(result_dims, index)} |
100 | | - else: |
101 | | - # move the relevant dims together and flatten |
102 | | - dim_name = object() |
103 | | - stacked = variable.stack({dim_name: dim}) |
104 | | - |
105 | | - result_dims = stacked.dims[:-1] |
106 | | - reduce_shape = tuple(variable.sizes[d] for d in dim) |
107 | | - index = getattr(self.xp, op)(stacked.data, axis=-1) |
108 | | - |
109 | | - unravelled = np.unravel_index(index, reduce_shape) |
110 | | - |
111 | | - expected = { |
112 | | - d: xr.Variable(result_dims, idx) |
113 | | - for d, idx in zip(dim, unravelled, strict=True) |
114 | | - } |
115 | | - |
116 | | - note(f"original: {variable}") |
117 | | - note(f"actual: {actual}") |
118 | | - note(f"expected: {expected}") |
119 | | - |
120 | | - assert actual.keys() == expected.keys(), "Reduction dims are not equal" |
121 | | - assert all(actual[k].equals(expected[k]) for k in actual) |
| 102 | + # move the relevant dims together and flatten |
| 103 | + dim_name = object() |
| 104 | + stacked = variable.stack({dim_name: dim}) |
| 105 | + |
| 106 | + reduce_shape = tuple(variable.sizes[d] for d in dim) |
| 107 | + index = getattr(self.xp, op)(stacked.data, axis=-1) |
| 108 | + |
| 109 | + unravelled = np.unravel_index(index, reduce_shape) |
| 110 | + |
| 111 | + expected = { |
| 112 | + d: self.xp.asarray(idx) |
| 113 | + for d, idx in zip(dim, unravelled, strict=True) |
| 114 | + } |
| 115 | + |
| 116 | + note(f"original: {variable}") |
| 117 | + note(f"actual: {repr(actual_)}") |
| 118 | + note(f"expected: {repr(expected)}") |
| 119 | + |
| 120 | + self.assert_dimension_indexers_equal(actual_, expected) |
122 | 121 |
|
123 | 122 | @pytest.mark.parametrize( |
124 | 123 | "op", |
|
0 commit comments