1
1
from contextlib import nullcontext
2
2
3
+ import hypothesis .extra .numpy as npst
3
4
import hypothesis .strategies as st
4
5
import xarray .testing .strategies as xrst
5
6
from hypothesis import given
@@ -11,39 +12,65 @@ def scalar_indexer(size):
11
12
return st .integers (min_value = - size , max_value = size - 1 )
12
13
13
14
15
+ def integer_array_indexer (size ):
16
+ dtypes = npst .integer_dtypes ()
17
+
18
+ return npst .arrays (
19
+ dtypes , size , elements = {"min_value" : - size , "max_value" : size - 1 }
20
+ )
21
+
22
+
23
+ def indexers (size , indexer_types ):
24
+ indexer_strategy_fns = {
25
+ "scalars" : scalar_indexer ,
26
+ "slices" : st .slices ,
27
+ "integer_arrays" : integer_array_indexer ,
28
+ }
29
+
30
+ bad_types = set (indexer_types ) - indexer_strategy_fns .keys ()
31
+ if bad_types :
32
+ raise ValueError (f"unknown indexer strategies: { sorted (bad_types )} " )
33
+
34
+ # use the order of definition to prefer simpler strategies over more complex
35
+ # ones
36
+ indexer_strategies = [
37
+ strategy_fn (size )
38
+ for name , strategy_fn in indexer_strategy_fns .items ()
39
+ if name in indexer_types
40
+ ]
41
+ return st .one_of (* indexer_strategies )
42
+
43
+
14
44
@st .composite
15
- def indexers (draw , sizes , indexer_strategy_fn ):
45
+ def orthogonal_indexers (draw , sizes , indexer_types ):
16
46
# TODO: make use of `flatmap` and `builds` instead of `composite`
17
- possible_indexers = {dim : indexer_strategy_fn (size ) for dim , size in sizes .items ()}
18
- indexers = draw (xrst .unique_subset_of (possible_indexers ))
19
- return {dim : draw (indexer ) for dim , indexer in indexers .items ()}
47
+ possible_indexers = {
48
+ dim : indexers (size , indexer_types ) for dim , size in sizes .items ()
49
+ }
50
+ concrete_indexers = draw (xrst .unique_subset_of (possible_indexers ))
51
+ return {dim : draw (indexer ) for dim , indexer in concrete_indexers .items ()}
20
52
21
53
22
54
class IndexingTests (DuckArrayTestMixin ):
55
+ @property
56
+ def orthogonal_indexer_types (self ):
57
+ return st .sampled_from (["scalars" , "slices" ])
58
+
23
59
@staticmethod
24
60
def expected_errors (op , ** parameters ):
25
61
return nullcontext ()
26
62
27
63
@given (st .data ())
28
- def test_variable_isel_scalars (self , data ):
29
- variable = data .draw (xrst .variables (array_strategy_fn = self .array_strategy_fn ))
30
- idx = data .draw (indexers (variable .sizes , scalar_indexer ))
31
-
32
- with self .expected_errors ("isel_scalars" , variable = variable ):
33
- actual = variable .isel (idx ).data
34
-
35
- raw_indexers = {dim : idx .get (dim , slice (None )) for dim in variable .dims }
36
- expected = variable .data [* raw_indexers .values ()]
37
-
38
- assert isinstance (actual , self .array_type ), f"wrong type: { type (actual )} "
39
- self .assert_equal (actual , expected )
40
-
41
- @given (st .data ())
42
- def test_variable_isel_slices (self , data ):
64
+ def test_variable_isel_orthogonal (self , data ):
65
+ indexer_types = data .draw (
66
+ st .lists (self .orthogonal_indexer_types , min_size = 1 , unique = True )
67
+ )
43
68
variable = data .draw (xrst .variables (array_strategy_fn = self .array_strategy_fn ))
44
- idx = data .draw (indexers (variable .sizes , st . slices ))
69
+ idx = data .draw (orthogonal_indexers (variable .sizes , indexer_types ))
45
70
46
- with self .expected_errors ("isel_slices" , variable = variable ):
71
+ with self .expected_errors (
72
+ "isel_orthogonal" , variable = variable , indexer_types = indexer_types
73
+ ):
47
74
actual = variable .isel (idx ).data
48
75
49
76
raw_indexers = {dim : idx .get (dim , slice (None )) for dim in variable .dims }
0 commit comments