66import xarray as xr
77
88from movement .utils .dimensions import (
9+ _coord_of_dimension ,
910 collapse_extra_dimensions ,
10- coord_of_dimension ,
1111)
1212
1313
1414@pytest .fixture
1515def shape () -> tuple [int , ...]:
16- return (7 , 2 , 3 , 4 )
16+ return (7 , 2 , 4 , 3 )
1717
1818
1919@pytest .fixture
2020def da (shape : tuple [int , ...]) -> xr .DataArray :
2121 return xr .DataArray (
2222 data = np .arange (np .prod (shape )).reshape (shape ),
23- dims = ["time" , "space" , "individuals " , "keypoints " ],
23+ dims = ["time" , "space" , "keypoints " , "individuals " ],
2424 coords = {
2525 "space" : ["x" , "y" ],
26- "individuals" : ["a" , "b" , "c" ],
2726 "keypoints" : ["head" , "shoulders" , "knees" , "toes" ],
27+ "individuals" : ["a" , "b" , "c" ],
2828 },
2929 )
3030
@@ -43,7 +43,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
4343 id = "Keep space only" ,
4444 ),
4545 pytest .param (
46- {"individuals" : 1 },
46+ {"individuals" : "b" },
4747 {"individuals" : "b" , "keypoints" : "head" },
4848 id = "Request non-default slice" ,
4949 ),
@@ -54,7 +54,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
5454 ),
5555 pytest .param (
5656 {
57- "individuals" : 1 ,
57+ "individuals" : "b" ,
5858 "elephants" : "this is a non-existent dimension" ,
5959 "crabs" : 42 ,
6060 },
@@ -70,7 +70,7 @@ def da(shape: tuple[int, ...]) -> xr.DataArray:
7070)
7171def test_collapse_dimensions (
7272 da : xr .DataArray ,
73- pass_to_function : dict [str , Any ],
73+ pass_to_function : dict [str , str ],
7474 equivalent_to_sel : dict [str , int | str ],
7575) -> None :
7676 result_from_collapsing = collapse_extra_dimensions (da , ** pass_to_function )
@@ -82,6 +82,27 @@ def test_collapse_dimensions(
8282 xr .testing .assert_allclose (result_from_collapsing , expected_result )
8383
8484
85+ @pytest .mark .parametrize (
86+ ["pass_to_function" ],
87+ [
88+ pytest .param (
89+ {"keypoints" : ["head" , "toes" ]},
90+ id = "Multiple keypoints" ,
91+ ),
92+ pytest .param (
93+ {"individuals" : ["a" , "b" ]},
94+ id = "Multiple individuals" ,
95+ ),
96+ ],
97+ )
98+ def test_collapse_dimensions_value_error (
99+ da : xr .DataArray ,
100+ pass_to_function : dict [str , Any ],
101+ ) -> None :
102+ with pytest .raises (ValueError ):
103+ collapse_extra_dimensions (da , ** pass_to_function )
104+
105+
85106@pytest .mark .parametrize (
86107 ["args_to_fn" , "expected" ],
87108 [
@@ -113,10 +134,10 @@ def test_collapse_dimensions(
113134 ],
114135)
115136def test_coord_of_dimension (
116- da : xr .DataArray , args_to_fn : dict [str , Any ], expected : str | Exception
137+ da : xr .DataArray , args_to_fn : dict [str , str ], expected : str | Exception
117138) -> None :
118139 if isinstance (expected , Exception ):
119140 with pytest .raises (type (expected ), match = re .escape (str (expected ))):
120- coord_of_dimension (da , ** args_to_fn )
141+ _coord_of_dimension (da , ** args_to_fn )
121142 else :
122- assert expected == coord_of_dimension (da , ** args_to_fn )
143+ assert expected == _coord_of_dimension (da , ** args_to_fn )
0 commit comments