Skip to content

Commit c7cfd0e

Browse files
committed
Write test coverage
1 parent e1cb1ff commit c7cfd0e

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import re
2+
from typing import Any
3+
4+
import numpy as np
5+
import pytest
6+
import xarray as xr
7+
8+
from movement.utils.dimensions import (
9+
collapse_extra_dimensions,
10+
coord_of_dimension,
11+
)
12+
13+
14+
@pytest.fixture
15+
def shape() -> tuple[int, ...]:
16+
return (7, 2, 3, 4)
17+
18+
19+
@pytest.fixture
20+
def da(shape: tuple[int, ...]) -> xr.DataArray:
21+
return xr.DataArray(
22+
data=np.arange(np.prod(shape)).reshape(shape),
23+
dims=["time", "space", "individuals", "keypoints"],
24+
coords={
25+
"space": ["x", "y"],
26+
"individuals": ["a", "b", "c"],
27+
"keypoints": ["head", "shoulders", "knees", "toes"],
28+
},
29+
)
30+
31+
32+
@pytest.mark.parametrize(
33+
["pass_to_function", "equivalent_to_sel"],
34+
[
35+
pytest.param(
36+
{},
37+
{"individuals": "a", "keypoints": "head"},
38+
id="Default preserve time-space",
39+
),
40+
pytest.param(
41+
{"preserve_dims": ["space"]},
42+
{"time": 0, "individuals": "a", "keypoints": "head"},
43+
id="Keep space only",
44+
),
45+
pytest.param(
46+
{"individuals": 1},
47+
{"individuals": "b", "keypoints": "head"},
48+
id="Request non-default slice",
49+
),
50+
pytest.param(
51+
{"individuals": "c"},
52+
{"individuals": "c", "keypoints": "head"},
53+
id="Request by coordinate",
54+
),
55+
pytest.param(
56+
{
57+
"individuals": 1,
58+
"elephants": "this is a non-existent dimension",
59+
"crabs": 42,
60+
},
61+
{"individuals": "b", "keypoints": "head"},
62+
id="Selection ignores dimensions that don't exist",
63+
),
64+
pytest.param(
65+
{"preserve_dims": []},
66+
{"time": 0, "space": "x", "individuals": "a", "keypoints": "head"},
67+
id="Collapse everything",
68+
),
69+
],
70+
)
71+
def test_collapse_dimensions(
72+
da: xr.DataArray,
73+
pass_to_function: dict[str, Any],
74+
equivalent_to_sel: dict[str, int | str],
75+
) -> None:
76+
result_from_collapsing = collapse_extra_dimensions(da, **pass_to_function)
77+
78+
# We should be equivalent to this method
79+
expected_result = da.sel(**equivalent_to_sel)
80+
81+
assert result_from_collapsing.shape == expected_result.values.shape
82+
xr.testing.assert_allclose(result_from_collapsing, expected_result)
83+
84+
85+
@pytest.mark.parametrize(
86+
["args_to_fn", "expected"],
87+
[
88+
pytest.param(
89+
{"dimension": "individuals", "coord_index": 1},
90+
"b",
91+
id="Fetch coord from index",
92+
),
93+
pytest.param(
94+
{"dimension": "time", "coord_index": 6},
95+
6,
96+
id="Dimension with no coordinates",
97+
),
98+
pytest.param(
99+
{"dimension": "space", "coord_index": "x"},
100+
"x",
101+
id="Fetch coord from name",
102+
),
103+
pytest.param(
104+
{"dimension": "keypoints", "coord_index": 10},
105+
IndexError("index 10 is out of bounds for axis 0 with size 4"),
106+
id="Out of bounds index",
107+
),
108+
pytest.param(
109+
{"dimension": "keypoints", "coord_index": "arms"},
110+
KeyError("arms"),
111+
id="Non existent coord name",
112+
),
113+
],
114+
)
115+
def test_coord_of_dimension(
116+
da: xr.DataArray, args_to_fn: dict[str, Any], expected: str | Exception
117+
) -> None:
118+
if isinstance(expected, Exception):
119+
with pytest.raises(type(expected), match=re.escape(str(expected))):
120+
coord_of_dimension(da, **args_to_fn)
121+
else:
122+
assert expected == coord_of_dimension(da, **args_to_fn)

0 commit comments

Comments
 (0)