Skip to content

Commit fc5b434

Browse files
committed
replace the orthogonal indexing test with the new strategy
1 parent 0f5993e commit fc5b434

File tree

2 files changed

+35
-67
lines changed

2 files changed

+35
-67
lines changed

xarray_array_testing/indexing.py

Lines changed: 33 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,36 @@
11
from contextlib import nullcontext
22

3-
import hypothesis.extra.numpy as npst
43
import hypothesis.strategies as st
5-
import xarray as xr
64
import xarray.testing.strategies as xrst
75
from hypothesis import given
86

97
from xarray_array_testing.base import DuckArrayTestMixin
8+
from xarray_array_testing.strategies import orthogonal_indexers, vectorized_indexers
109

1110

12-
def scalar_indexer(size):
13-
return st.integers(min_value=-size, max_value=size - 1)
11+
def broadcast_orthogonal_indexers(indexers, sizes, *, xp):
12+
def _broadcasting_shape(index, total):
13+
return tuple(1 if i != index else -1 for i in range(total))
1414

15+
def _as_array(indexer, size):
16+
if isinstance(indexer, slice):
17+
return xp.asarray(range(*indexer.indices(size)), dtype="int64")
18+
elif isinstance(indexer, int):
19+
return xp.asarray(indexer, dtype="int64")
20+
else:
21+
return indexer
1522

16-
def integer_array_indexer(size):
17-
dtypes = npst.integer_dtypes()
18-
19-
return npst.arrays(
20-
dtypes, size, elements={"min_value": -size, "max_value": size - 1}
21-
)
22-
23-
24-
def indexers(size, indexer_types):
25-
indexer_strategy_fns = {
26-
"scalars": scalar_indexer,
27-
"slices": st.slices,
28-
"integer_arrays": integer_array_indexer,
23+
indexer_arrays = {
24+
dim: _as_array(indexer, sizes[dim]) for dim, indexer in indexers.items()
2925
}
30-
31-
bad_types = set(indexer_types) - indexer_strategy_fns.keys()
32-
if bad_types:
33-
raise ValueError(f"unknown indexer strategies: {sorted(bad_types)}")
34-
35-
# use the order of definition to prefer simpler strategies over more complex
36-
# ones
37-
indexer_strategies = [
38-
strategy_fn(size)
39-
for name, strategy_fn in indexer_strategy_fns.items()
40-
if name in indexer_types
41-
]
42-
return st.one_of(*indexer_strategies)
43-
44-
45-
@st.composite
46-
def orthogonal_indexers(draw, sizes, indexer_types):
47-
# TODO: make use of `flatmap` and `builds` instead of `composite`
48-
possible_indexers = {
49-
dim: indexers(size, indexer_types) for dim, size in sizes.items()
50-
}
51-
concrete_indexers = draw(xrst.unique_subset_of(possible_indexers))
52-
return {dim: draw(indexer) for dim, indexer in concrete_indexers.items()}
53-
54-
55-
@st.composite
56-
def vectorized_indexers(draw, sizes):
57-
max_size = max(sizes.values())
58-
shape = draw(st.integers(min_value=1, max_value=max_size))
59-
dtypes = npst.integer_dtypes()
60-
61-
indexers = {
62-
dim: npst.arrays(
63-
dtypes, shape, elements={"min_value": -size, "max_value": size - 1}
26+
broadcasted = xp.broadcast_arrays(
27+
*(
28+
xp.reshape(indexer, _broadcasting_shape(index, total=len(indexers)))
29+
for index, indexer in enumerate(indexer_arrays.values())
6430
)
65-
for dim, size in sizes.items()
66-
}
31+
)
6732

68-
return {
69-
dim: xr.Variable("points", draw(indexer)) for dim, indexer in indexers.items()
70-
}
33+
return dict(zip(indexer_arrays.keys(), broadcasted))
7134

7235

7336
class IndexingTests(DuckArrayTestMixin):
@@ -81,19 +44,24 @@ def expected_errors(op, **parameters):
8144

8245
@given(st.data())
8346
def test_variable_isel_orthogonal(self, data):
84-
indexer_types = data.draw(
85-
st.lists(self.orthogonal_indexer_types, min_size=1, unique=True)
86-
)
8747
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
88-
idx = data.draw(orthogonal_indexers(variable.sizes, indexer_types))
48+
idx = data.draw(
49+
orthogonal_indexers(sizes=variable.sizes, min_dims=len(variable.dims))
50+
)
8951

90-
with self.expected_errors(
91-
"isel_orthogonal", variable=variable, indexer_types=indexer_types
92-
):
52+
with self.expected_errors("isel_orthogonal", variable=variable, indexers=idx):
9353
actual = variable.isel(idx).data
9454

95-
raw_indexers = {dim: idx.get(dim, slice(None)) for dim in variable.dims}
96-
expected = variable.data[*raw_indexers.values()]
55+
sorted_dims = sorted(idx.keys(), key=variable.dims.index, reverse=True)
56+
expected = variable.data
57+
for dim in sorted_dims:
58+
indexer = idx[dim]
59+
axis = variable.get_axis_num(dim)
60+
if isinstance(indexer, slice):
61+
indexer = self.xp.asarray(
62+
range(*indexer.indices(variable.sizes[dim])), dtype="int64"
63+
)
64+
expected = self.xp.take(expected, indexer, axis=axis)
9765

9866
assert isinstance(
9967
actual, self.array_type("orthogonal_indexing")

xarray_array_testing/strategies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from collections.abc import Hashable
22
from itertools import compress
33

4-
import hypothesis.extras.numpy as npst
4+
import hypothesis.extra.numpy as npst
55
import hypothesis.strategies as st
66
import numpy as np
77
import xarray as xr
8-
from xr.testing.strategies import unique_subset_of
8+
from xarray.testing.strategies import unique_subset_of
99

1010

1111
def _basic_indexers(size):

0 commit comments

Comments
 (0)