6
6
import xarray as xr
7
7
8
8
from movement .utils .dimensions import (
9
+ _coord_of_dimension ,
9
10
collapse_extra_dimensions ,
10
- coord_of_dimension ,
11
11
)
12
12
13
13
14
14
@pytest .fixture
15
15
def shape () -> tuple [int , ...]:
16
- return (7 , 2 , 3 , 4 )
16
+ return (7 , 2 , 4 , 3 )
17
17
18
18
19
19
@pytest .fixture
20
20
def da (shape : tuple [int , ...]) -> xr .DataArray :
21
21
return xr .DataArray (
22
22
data = np .arange (np .prod (shape )).reshape (shape ),
23
- dims = ["time" , "space" , "individuals " , "keypoints " ],
23
+ dims = ["time" , "space" , "keypoints " , "individuals " ],
24
24
coords = {
25
25
"space" : ["x" , "y" ],
26
- "individuals" : ["a" , "b" , "c" ],
27
26
"keypoints" : ["head" , "shoulders" , "knees" , "toes" ],
27
+ "individuals" : ["a" , "b" , "c" ],
28
28
},
29
29
)
30
30
@@ -43,7 +43,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
43
43
id = "Keep space only" ,
44
44
),
45
45
pytest .param (
46
- {"individuals" : 1 },
46
+ {"individuals" : "b" },
47
47
{"individuals" : "b" , "keypoints" : "head" },
48
48
id = "Request non-default slice" ,
49
49
),
@@ -54,7 +54,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
54
54
),
55
55
pytest .param (
56
56
{
57
- "individuals" : 1 ,
57
+ "individuals" : "b" ,
58
58
"elephants" : "this is a non-existent dimension" ,
59
59
"crabs" : 42 ,
60
60
},
@@ -70,7 +70,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
70
70
)
71
71
def test_collapse_dimensions (
72
72
da : xr .DataArray ,
73
- pass_to_function : dict [str , Any ],
73
+ pass_to_function : dict [str , str ],
74
74
equivalent_to_sel : dict [str , int | str ],
75
75
) -> None :
76
76
result_from_collapsing = collapse_extra_dimensions (da , ** pass_to_function )
@@ -82,6 +82,27 @@ def test_collapse_dimensions(
82
82
xr .testing .assert_allclose (result_from_collapsing , expected_result )
83
83
84
84
85
+ @pytest .mark .parametrize (
86
+ ["pass_to_function" ],
87
+ [
88
+ pytest .param (
89
+ {"keypoints" : ["head" , "toes" ]},
90
+ id = "Multiple keypoints" ,
91
+ ),
92
+ pytest .param (
93
+ {"individuals" : ["a" , "b" ]},
94
+ id = "Multiple individuals" ,
95
+ ),
96
+ ],
97
+ )
98
+ def test_collapse_dimensions_value_error (
99
+ da : xr .DataArray ,
100
+ pass_to_function : dict [str , Any ],
101
+ ) -> None :
102
+ with pytest .raises (ValueError ):
103
+ collapse_extra_dimensions (da , ** pass_to_function )
104
+
105
+
85
106
@pytest .mark .parametrize (
86
107
["args_to_fn" , "expected" ],
87
108
[
@@ -113,10 +134,10 @@ def test_collapse_dimensions(
113
134
],
114
135
)
115
136
def test_coord_of_dimension (
116
- da : xr .DataArray , args_to_fn : dict [str , Any ], expected : str | Exception
137
+ da : xr .DataArray , args_to_fn : dict [str , str ], expected : str | Exception
117
138
) -> None :
118
139
if isinstance (expected , Exception ):
119
140
with pytest .raises (type (expected ), match = re .escape (str (expected ))):
120
- coord_of_dimension (da , ** args_to_fn )
141
+ _coord_of_dimension (da , ** args_to_fn )
121
142
else :
122
- assert expected == coord_of_dimension (da , ** args_to_fn )
143
+ assert expected == _coord_of_dimension (da , ** args_to_fn )
0 commit comments