Skip to content

Commit addf7e8

Browse files
committed
vendor the indexer strategies from xarray
1 parent 8e07666 commit addf7e8

File tree

1 file changed

+200
-0
lines changed

1 file changed

+200
-0
lines changed

xarray_array_testing/strategies.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
from collections.abc import Hashable
2+
from itertools import compress
3+
4+
import hypothesis.extras.numpy as npst
5+
import hypothesis.strategies as st
6+
import numpy as np
7+
import xarray as xr
8+
from xr.testing.strategies import unique_subset_of
9+
10+
11+
# vendored from `xarray`, should be included in `xarray>=2026.01.0`
12+
@st.composite
13+
def basic_indexers(
14+
draw,
15+
/,
16+
*,
17+
sizes: dict[Hashable, int],
18+
min_dims: int = 1,
19+
max_dims: int | None = None,
20+
) -> dict[Hashable, int | slice]:
21+
"""Generate basic indexers using ``hypothesis.extra.numpy.basic_indices``.
22+
23+
Parameters
24+
----------
25+
draw : callable
26+
sizes : dict[Hashable, int]
27+
Dictionary mapping dimension names to their sizes.
28+
min_dims : int, optional
29+
Minimum number of dimensions to index.
30+
max_dims : int or None, optional
31+
Maximum number of dimensions to index.
32+
33+
Returns
34+
-------
35+
sizes : mapping of hashable to int or slice
36+
Indexers as a dict with keys randomly selected from ``sizes.keys()``.
37+
38+
See Also
39+
--------
40+
hypothesis.strategies.slices
41+
"""
42+
selected_dims = draw(unique_subset_of(sizes, min_size=min_dims, max_size=max_dims))
43+
44+
# Generate one basic index (int or slice) per selected dimension
45+
idxr = {
46+
dim: draw(
47+
st.one_of(
48+
st.integers(min_value=-size, max_value=size - 1),
49+
st.slices(size),
50+
)
51+
)
52+
for dim, size in selected_dims.items()
53+
}
54+
return idxr
55+
56+
57+
@st.composite
58+
def outer_array_indexers(
59+
draw,
60+
/,
61+
*,
62+
sizes: dict[Hashable, int],
63+
min_dims: int = 0,
64+
max_dims: int | None = None,
65+
max_size: int = 10,
66+
) -> dict[Hashable, np.ndarray]:
67+
"""Generate outer array indexers (vectorized/orthogonal indexing).
68+
69+
Parameters
70+
----------
71+
draw : callable
72+
The Hypothesis draw function (automatically provided by @st.composite).
73+
sizes : dict[Hashable, int]
74+
Dictionary mapping dimension names to their sizes.
75+
min_dims : int, optional
76+
Minimum number of dimensions to index
77+
max_dims : int or None, optional
78+
Maximum number of dimensions to index
79+
80+
Returns
81+
-------
82+
sizes : mapping of hashable to np.ndarray
83+
Indexers as a dict with keys randomly selected from ``sizes.keys()``.
84+
Values are 1D numpy arrays of integer indices for each dimension.
85+
86+
See Also
87+
--------
88+
hypothesis.extra.numpy.arrays
89+
"""
90+
selected_dims = draw(unique_subset_of(sizes, min_size=min_dims, max_size=max_dims))
91+
idxr = {
92+
dim: draw(
93+
npst.arrays(
94+
dtype=np.int64,
95+
shape=st.integers(min_value=1, max_value=min(size, max_size)),
96+
elements=st.integers(min_value=-size, max_value=size - 1),
97+
)
98+
)
99+
for dim, size in selected_dims.items()
100+
}
101+
return idxr
102+
103+
104+
@st.composite
105+
def vectorized_indexers(
106+
draw,
107+
/,
108+
*,
109+
sizes: dict[Hashable, int],
110+
min_dims: int = 2,
111+
max_dims: int | None = None,
112+
min_ndim: int = 1,
113+
max_ndim: int = 3,
114+
min_size: int = 1,
115+
max_size: int = 5,
116+
) -> dict[Hashable, xr.DataArray]:
117+
"""Generate vectorized (fancy) indexers where all arrays are broadcastable.
118+
119+
In vectorized indexing, all array indexers must have compatible shapes
120+
that can be broadcast together, and the result shape is determined by
121+
broadcasting the indexer arrays.
122+
123+
Parameters
124+
----------
125+
draw : callable
126+
The Hypothesis draw function (automatically provided by @st.composite).
127+
sizes : dict[Hashable, int]
128+
Dictionary mapping dimension names to their sizes.
129+
min_dims : int, optional
130+
Minimum number of dimensions to index. Default is 2, so that we always have a "trajectory".
131+
Use ``outer_array_indexers`` for the ``min_dims==1`` case.
132+
max_dims : int or None, optional
133+
Maximum number of dimensions to index.
134+
min_ndim : int, optional
135+
Minimum number of dimensions for the result arrays.
136+
max_ndim : int, optional
137+
Maximum number of dimensions for the result arrays.
138+
min_size : int, optional
139+
Minimum size for each dimension in the result arrays.
140+
max_size : int, optional
141+
Maximum size for each dimension in the result arrays.
142+
143+
Returns
144+
-------
145+
sizes : mapping of hashable to DataArray or Variable
146+
Indexers as a dict with keys randomly selected from sizes.keys().
147+
Values are DataArrays of integer indices that are all broadcastable
148+
to a common shape.
149+
150+
See Also
151+
--------
152+
hypothesis.extra.numpy.arrays
153+
"""
154+
selected_dims = draw(unique_subset_of(sizes, min_size=min_dims, max_size=max_dims))
155+
156+
# Generate a common broadcast shape for all arrays
157+
# Use min_ndim to max_ndim dimensions for the result shape
158+
result_shape = draw(
159+
st.lists(
160+
st.integers(min_value=min_size, max_value=max_size),
161+
min_size=min_ndim,
162+
max_size=max_ndim,
163+
)
164+
)
165+
result_ndim = len(result_shape)
166+
167+
# Create dimension names for the vectorized result
168+
vec_dims = tuple(f"vec_{i}" for i in range(result_ndim))
169+
170+
# Generate array indexers for each selected dimension
171+
# All arrays must be broadcastable to the same result_shape
172+
idxr = {}
173+
for dim, size in selected_dims.items():
174+
array_shape = draw(
175+
npst.broadcastable_shapes(
176+
shape=tuple(result_shape),
177+
min_dims=min_ndim,
178+
max_dims=result_ndim,
179+
)
180+
)
181+
182+
# For xarray broadcasting, drop dimensions where size differs from result_shape
183+
# (numpy broadcasts size-1, but xarray requires matching sizes or missing dims)
184+
# Right-align array_shape with result_shape for comparison
185+
aligned_dims = vec_dims[-len(array_shape) :] if array_shape else ()
186+
aligned_result = result_shape[-len(array_shape) :] if array_shape else []
187+
keep_mask = [s == r for s, r in zip(array_shape, aligned_result, strict=True)]
188+
filtered_shape = tuple(compress(array_shape, keep_mask))
189+
filtered_dims = tuple(compress(aligned_dims, keep_mask))
190+
191+
# Generate array of valid indices for this dimension
192+
indices = draw(
193+
npst.arrays(
194+
dtype=np.int64,
195+
shape=filtered_shape,
196+
elements=st.integers(min_value=-size, max_value=size - 1),
197+
)
198+
)
199+
idxr[dim] = xr.Variable(data=indices, dims=filtered_dims)
200+
return idxr

0 commit comments

Comments
 (0)