forked from xarray-contrib/xarray-array-testing
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathindexing.py
More file actions
117 lines (89 loc) · 3.68 KB
/
indexing.py
File metadata and controls
117 lines (89 loc) · 3.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from contextlib import nullcontext
import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import xarray as xr
import xarray.testing.strategies as xrst
from hypothesis import given
from xarray_array_testing.base import DuckArrayTestMixin
def scalar_indexer(size):
    """Strategy for one integer position along an axis of length ``size``.

    Drawn values lie in ``[-size, size - 1]``, so negative positions (counted
    from the end of the axis) are generated as well as non-negative ones.
    """
    lowest, highest = -size, size - 1
    return st.integers(min_value=lowest, max_value=highest)
def integer_array_indexer(size):
    """Strategy for a 1-D integer array of positions along an axis of length ``size``.

    The array has ``size`` elements, each in ``[-size, size - 1]`` (negative
    values count from the end of the axis). Its dtype is drawn from the signed
    integer dtypes that are wide enough to represent that whole range.
    """
    # Offering a dtype that cannot hold -size (e.g. int8 for a 200-long axis)
    # makes hypothesis raise InvalidArgument instead of generating data.
    widths = tuple(bits for bits in (8, 16, 32, 64) if 2 ** (bits - 1) >= size)
    dtypes = npst.integer_dtypes(sizes=widths)
    return npst.arrays(
        dtypes, size, elements={"min_value": -size, "max_value": size - 1}
    )
def indexers(size, indexer_types):
    """Strategy for a single indexer along an axis of length ``size``.

    Parameters
    ----------
    size : int
        Length of the axis being indexed.
    indexer_types : collection of str
        Kinds of indexer to draw from; any of ``"scalars"``, ``"slices"`` and
        ``"integer_arrays"``.

    Raises
    ------
    ValueError
        If ``indexer_types`` names a kind that is not listed above.
    """
    strategy_factories = {
        "scalars": scalar_indexer,
        "slices": st.slices,
        "integer_arrays": integer_array_indexer,
    }

    unknown = set(indexer_types).difference(strategy_factories)
    if unknown:
        raise ValueError(f"unknown indexer strategies: {sorted(unknown)}")

    # Walk the table in its definition order (simplest kind first) rather than
    # the caller's order, so the combined strategy prefers simpler indexers.
    selected = []
    for kind, factory in strategy_factories.items():
        if kind in indexer_types:
            selected.append(factory(size))
    return st.one_of(*selected)
def orthogonal_indexers(sizes, indexer_types):
    """Strategy for orthogonal ``isel`` indexers over a subset of dimensions.

    Parameters
    ----------
    sizes : mapping of dim name -> int
        Dimension sizes, e.g. ``variable.sizes``.
    indexer_types : collection of str
        Indexer kinds accepted by ``indexers`` (``"scalars"``, ``"slices"``,
        ``"integer_arrays"``).

    Returns
    -------
    SearchStrategy of dict
        Maps each chosen dimension to one concrete indexer. Dimensions that
        were not chosen are absent and therefore left unindexed.

    Raises
    ------
    ValueError
        If ``indexer_types`` contains an unknown kind (raised by ``indexers``).
    """
    per_dim = {dim: indexers(size, indexer_types) for dim, size in sizes.items()}
    # First pick which dimensions to index, then draw one indexer for each of
    # them; this replaces the former ``@st.composite`` implementation.
    return xrst.unique_subset_of(per_dim).flatmap(st.fixed_dictionaries)
@st.composite
def vectorized_indexers(draw, sizes):
    """Draw a vectorized (pointwise) indexer covering every dimension in ``sizes``.

    Every dimension gets a 1-D ``xr.Variable`` along a shared ``"points"``
    dimension. All of them have the same length, drawn between 1 and the
    largest dimension size. The values for a dimension of size ``n`` lie in
    ``[-n, n - 1]``.

    Parameters
    ----------
    draw : callable
        Supplied by ``st.composite``; callers do not pass it.
    sizes : mapping of dim name -> int
        Dimension sizes, e.g. ``variable.sizes``.

    Returns
    -------
    dict
        ``{dim: xr.Variable("points", positions)}``; empty when ``sizes`` is
        empty (there is nothing to index on a 0-d variable).
    """
    if not sizes:
        # max() below would raise on an empty mapping
        return {}

    n_points = draw(st.integers(min_value=1, max_value=max(sizes.values())))

    points = {}
    for dim, size in sizes.items():
        # Only dtypes that can represent -size; narrower ones make hypothesis
        # raise InvalidArgument for long dimensions.
        widths = tuple(bits for bits in (8, 16, 32, 64) if 2 ** (bits - 1) >= size)
        positions = npst.arrays(
            npst.integer_dtypes(sizes=widths),
            n_points,
            elements={"min_value": -size, "max_value": size - 1},
        )
        points[dim] = xr.Variable("points", draw(positions))
    return points
class IndexingTests(DuckArrayTestMixin):
    """Property-based tests comparing ``xr.Variable.isel`` with raw duck-array indexing.

    The tests use ``self.array_strategy_fn``, ``self.array_type(...)`` and
    ``self.assert_equal`` — presumably provided by ``DuckArrayTestMixin`` or a
    concrete subclass (not visible here; confirm). Subclasses may override
    ``orthogonal_indexer_types`` and ``expected_errors``.
    """

    @property
    def orthogonal_indexer_types(self):
        """Strategy drawing names of indexer kinds for the orthogonal test.

        Only ``"scalars"`` and ``"slices"`` by default; override to also draw
        ``"integer_arrays"``.
        """
        return st.sampled_from(["scalars", "slices"])

    @staticmethod
    def expected_errors(op, **parameters):
        """Return the context manager wrapping the operation under test.

        The default is a no-op context (no error expected). Subclasses can
        override it to expect failures for a given ``op`` and drawn
        ``parameters``.
        """
        return nullcontext()

    @staticmethod
    def _orthogonal_reference(data, dims, idx):
        """Index ``data`` with ``idx`` one dimension at a time (outer semantics).

        A single joint subscript would broadcast multiple integer arrays
        against each other (vectorized semantics); applying each indexer to
        its own axis in turn gives the orthogonal result that ``isel`` uses.
        """
        result = data
        # Position of the current dim's axis in ``result``; earlier scalar
        # indexers remove their axis, so this can lag behind the dim's index.
        axis = 0
        for dim in dims:
            key = idx.get(dim, slice(None))
            result = result[(slice(None),) * axis + (key,)]
            # slices and 1-D arrays keep the axis, scalars drop it
            axis += 1 if isinstance(key, slice) else getattr(key, "ndim", 0)
        return result

    @given(st.data())
    def test_variable_isel_orthogonal(self, data):
        indexer_types = data.draw(
            st.lists(self.orthogonal_indexer_types, min_size=1, unique=True)
        )
        variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
        idx = data.draw(orthogonal_indexers(variable.sizes, indexer_types))

        with self.expected_errors(
            "isel_orthogonal", variable=variable, indexer_types=indexer_types
        ):
            actual = variable.isel(idx).data
            expected = self._orthogonal_reference(variable.data, variable.dims, idx)

            assert isinstance(
                actual, self.array_type("orthogonal_indexing")
            ), f"wrong type: {type(actual)}"
            self.assert_equal(actual, expected)

    @given(st.data())
    def test_variable_isel_vectorized(self, data):
        variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
        idx = data.draw(vectorized_indexers(variable.sizes))

        with self.expected_errors("isel_vectorized", variable=variable):
            actual = variable.isel(idx).data

            # Hand the duck array the raw index arrays, not the xr.Variable
            # wrappers: the wrapped array type should not need to understand
            # xarray objects.
            raw_indexers = tuple(
                idx[dim].data if dim in idx else slice(None) for dim in variable.dims
            )
            expected = variable.data[raw_indexers]

            assert isinstance(
                actual, self.array_type("vectorized_indexing")
            ), f"wrong type: {type(actual)}"
            self.assert_equal(actual, expected)