11from contextlib import nullcontext
22
3- import hypothesis .extra .numpy as npst
43import hypothesis .strategies as st
5- import xarray as xr
64import xarray .testing .strategies as xrst
75from hypothesis import given
86
97from xarray_array_testing .base import DuckArrayTestMixin
8+ from xarray_array_testing .strategies import orthogonal_indexers , vectorized_indexers
109
1110
12- def scalar_indexer (size ):
13- return st .integers (min_value = - size , max_value = size - 1 )
11+ def broadcast_orthogonal_indexers (indexers , sizes , * , xp ):
12+ def _broadcasting_shape (index , total ):
13+ return tuple (1 if i != index else - 1 for i in range (total ))
1414
15+ def _as_array (indexer , size ):
16+ if isinstance (indexer , slice ):
17+ return xp .asarray (range (* indexer .indices (size )), dtype = "int64" )
18+ elif isinstance (indexer , int ):
19+ return xp .asarray (indexer , dtype = "int64" )
20+ else :
21+ return indexer
1522
16- def integer_array_indexer (size ):
17- dtypes = npst .integer_dtypes ()
18-
19- return npst .arrays (
20- dtypes , size , elements = {"min_value" : - size , "max_value" : size - 1 }
21- )
22-
23-
24- def indexers (size , indexer_types ):
25- indexer_strategy_fns = {
26- "scalars" : scalar_indexer ,
27- "slices" : st .slices ,
28- "integer_arrays" : integer_array_indexer ,
23+ indexer_arrays = {
24+ dim : _as_array (indexer , sizes [dim ]) for dim , indexer in indexers .items ()
2925 }
30-
31- bad_types = set (indexer_types ) - indexer_strategy_fns .keys ()
32- if bad_types :
33- raise ValueError (f"unknown indexer strategies: { sorted (bad_types )} " )
34-
35- # use the order of definition to prefer simpler strategies over more complex
36- # ones
37- indexer_strategies = [
38- strategy_fn (size )
39- for name , strategy_fn in indexer_strategy_fns .items ()
40- if name in indexer_types
41- ]
42- return st .one_of (* indexer_strategies )
43-
44-
45- @st .composite
46- def orthogonal_indexers (draw , sizes , indexer_types ):
47- # TODO: make use of `flatmap` and `builds` instead of `composite`
48- possible_indexers = {
49- dim : indexers (size , indexer_types ) for dim , size in sizes .items ()
50- }
51- concrete_indexers = draw (xrst .unique_subset_of (possible_indexers ))
52- return {dim : draw (indexer ) for dim , indexer in concrete_indexers .items ()}
53-
54-
55- @st .composite
56- def vectorized_indexers (draw , sizes ):
57- max_size = max (sizes .values ())
58- shape = draw (st .integers (min_value = 1 , max_value = max_size ))
59- dtypes = npst .integer_dtypes ()
60-
61- indexers = {
62- dim : npst .arrays (
63- dtypes , shape , elements = {"min_value" : - size , "max_value" : size - 1 }
26+ broadcasted = xp .broadcast_arrays (
27+ * (
28+ xp .reshape (indexer , _broadcasting_shape (index , total = len (indexers )))
29+ for index , indexer in enumerate (indexer_arrays .values ())
6430 )
65- for dim , size in sizes .items ()
66- }
31+ )
6732
68- return {
69- dim : xr .Variable ("points" , draw (indexer )) for dim , indexer in indexers .items ()
70- }
33+ return dict (zip (indexer_arrays .keys (), broadcasted ))
7134
7235
7336class IndexingTests (DuckArrayTestMixin ):
@@ -81,19 +44,24 @@ def expected_errors(op, **parameters):
8144
8245 @given (st .data ())
8346 def test_variable_isel_orthogonal (self , data ):
84- indexer_types = data .draw (
85- st .lists (self .orthogonal_indexer_types , min_size = 1 , unique = True )
86- )
8747 variable = data .draw (xrst .variables (array_strategy_fn = self .array_strategy_fn ))
88- idx = data .draw (orthogonal_indexers (variable .sizes , indexer_types ))
48+ idx = data .draw (
49+ orthogonal_indexers (sizes = variable .sizes , min_dims = len (variable .dims ))
50+ )
8951
90- with self .expected_errors (
91- "isel_orthogonal" , variable = variable , indexer_types = indexer_types
92- ):
52+ with self .expected_errors ("isel_orthogonal" , variable = variable , indexers = idx ):
9353 actual = variable .isel (idx ).data
9454
95- raw_indexers = {dim : idx .get (dim , slice (None )) for dim in variable .dims }
96- expected = variable .data [* raw_indexers .values ()]
55+ sorted_dims = sorted (idx .keys (), key = variable .dims .index , reverse = True )
56+ expected = variable .data
57+ for dim in sorted_dims :
58+ indexer = idx [dim ]
59+ axis = variable .get_axis_num (dim )
60+ if isinstance (indexer , slice ):
61+ indexer = self .xp .asarray (
62+ range (* indexer .indices (variable .sizes [dim ])), dtype = "int64"
63+ )
64+ expected = self .xp .take (expected , indexer , axis = axis )
9765
9866 assert isinstance (
9967 actual , self .array_type ("orthogonal_indexing" )
0 commit comments