Skip to content

Commit 9a4614c

Browse files
committed
Write test, but it fails. But can't figure out why it fails...
1 parent 6e655ca commit 9a4614c

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed

tests/test_unit/test_plot.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from itertools import product
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
import pytest
6+
import xarray as xr
7+
from matplotlib.collections import QuadMesh
8+
from numpy.random import RandomState
9+
10+
from movement.plot import occupancy_histogram
11+
12+
13+
def get_histogram_binning_data(fig: plt.Figure) -> list[QuadMesh]:
    """Return every ``QuadMesh`` drawn on the first axes of ``fig``.

    Only ``fig.axes[0]`` is searched. The binned values and bin edges of a
    2D histogram can then be read from the returned mesh objects.
    """
    first_axes = fig.axes[0]
    meshes = []
    for artist in first_axes.get_children():
        if isinstance(artist, QuadMesh):
            meshes.append(artist)
    return meshes
18+
19+
20+
@pytest.fixture
def seed() -> int:
    """Fixed seed value consumed by the ``rng`` fixture."""
    return 0
23+
24+
25+
@pytest.fixture(scope="function")
def rng(seed: int) -> RandomState:
    """Provide a ``RandomState`` initialised from the ``seed`` fixture.

    Histogram tests need datasets that are too large to write out by hand;
    seeding the generator keeps that random data repeatable between runs.
    """
    seeded_state = RandomState(seed)
    return seeded_state
33+
34+
35+
@pytest.fixture
def normal_dist_2d(rng: RandomState) -> np.ndarray:
    """Sample points from the standard bivariate normal distribution.

    With zero mean and identity covariance, the two space coordinates are
    independent N(0, 1) variables. The returned array has shape
    ``(250, 2, 3, 4)``, with the (x, y) pair on axis 1.
    """
    mean = (0.0, 0.0)
    covariance = [[1.0, 0.0], [0.0, 1.0]]
    leading_shape = (250, 3, 4)
    # Each draw stores its (x, y) pair on the trailing axis: (250, 3, 4, 2).
    draws = rng.multivariate_normal(mean, covariance, leading_shape)
    # Relocate the space coordinates from the last axis to axis 1.
    return np.moveaxis(draws, source=3, destination=1)
48+
49+
50+
@pytest.fixture
def histogram_data(normal_dist_2d: np.ndarray) -> xr.DataArray:
    """Wrap the ``normal_dist_2d`` points in a labelled ``DataArray``.

    The dimensions are, in order, "time", "space", "individuals" and
    "keypoints". The space axis is labelled "x", "y"; individuals are
    labelled "i0", "i1", ... and keypoints "k0", "k1", ... so that tests
    can select entries by name.
    """
    n_individuals = normal_dist_2d.shape[2]
    n_keypoints = normal_dist_2d.shape[3]
    coordinate_labels = {
        "space": ["x", "y"],
        "individuals": [f"i{index}" for index in range(n_individuals)],
        "keypoints": [f"k{index}" for index in range(n_keypoints)],
    }
    return xr.DataArray(
        data=normal_dist_2d,
        dims=["time", "space", "individuals", "keypoints"],
        coords=coordinate_labels,
    )
67+
68+
69+
@pytest.fixture
def histogram_data_with_nans(
    histogram_data: xr.DataArray, rng: RandomState
) -> xr.DataArray:
    """Copy of ``histogram_data`` with values randomly replaced by NaN.

    Every scalar value (each x and each y entry separately) is independently
    turned into NaN with probability 1/25. The dimensions and coordinates
    are those of ``histogram_data``, which is itself left unmodified.
    """
    nan_chance = 1.0 / 25.0
    # One uniform draw per element, made in a single vectorised call. This
    # avoids a Python-level loop over every index and the ``arr[*index]``
    # subscript syntax, which is a SyntaxError before Python 3.11.
    nan_mask = rng.uniform(size=histogram_data.shape) < nan_chance
    values_with_nans = np.where(nan_mask, np.nan, histogram_data.values)
    # ``copy(data=...)`` keeps dims/coords/name and swaps in the new values.
    return histogram_data.copy(data=values_with_nans)
89+
90+
91+
# def test_histogram_ignores_missing_dims(
92+
# input_does_not_have_dimensions: list[str],
93+
# ) -> None:
94+
# """Test that ``occupancy_histogram`` ignores non-present dimensions."""
95+
# input_data = 0
96+
97+
98+
def _points_in_bin(
    values: np.ndarray, edges: np.ndarray, index: int
) -> np.ndarray:
    """Boolean mask of the ``values`` that fall inside bin ``index``.

    Bins are half-open, ``[lower, upper)``, except for the final bin which
    also contains its upper edge. This is numpy's histogram convention and
    guarantees each point is counted in exactly one bin.
    NOTE(review): assumes ``occupancy_histogram`` bins the same way - confirm.
    """
    lower, upper = edges[index], edges[index + 1]
    is_last_bin = index == len(edges) - 2
    below_upper = values <= upper if is_last_bin else values < upper
    return (values >= lower) & below_upper


def _count_points_per_bin(
    x_values: np.ndarray,
    y_values: np.ndarray,
    x_edges: np.ndarray,
    y_edges: np.ndarray,
) -> np.ndarray:
    """Count the (x, y) points in each 2D bin; result is ``[x_bin, y_bin]``."""
    n_x_bins, n_y_bins = len(x_edges) - 1, len(y_edges) - 1
    counts = np.zeros(shape=(n_x_bins, n_y_bins), dtype=float)
    for i in range(n_x_bins):
        in_x_bin = _points_in_bin(x_values, x_edges, i)
        for j in range(n_y_bins):
            in_y_bin = _points_in_bin(y_values, y_edges, j)
            counts[i, j] = (in_x_bin & in_y_bin).sum()
    return counts


@pytest.mark.parametrize(
    ["data", "individual", "keypoint", "n_bins"],
    [pytest.param("histogram_data", "i0", "k0", 30, id="30 bins each axis")],
)
def test_occupancy_histogram(
    data: xr.DataArray,
    individual: int | str,
    keypoint: int | str,
    n_bins: int | tuple[int, int],
    request,
) -> None:
    """Test that occupancy histograms correctly plot data.

    ``data`` may be given as the name of a fixture, in which case it is
    resolved through ``request``. The bin edges and per-bin counts are
    rebuilt independently from the raw points and compared with what the
    plotted mesh holds.
    """
    if isinstance(data, str):
        data = request.getfixturevalue(data)

    plotted_hist = occupancy_histogram(
        data, individual=individual, keypoint=keypoint, bins=n_bins
    )

    # Confirm that a histogram was made
    meshes = get_histogram_binning_data(plotted_hist)
    assert len(meshes) == 1
    mesh = meshes[0]
    plotting_coords = mesh.get_coordinates()
    plotted_values = mesh.get_array()

    # Confirm the binned array has the correct size. The mesh stores one row
    # per y-bin and one column per x-bin, i.e. it is indexed [y_bin, x_bin].
    if not isinstance(n_bins, tuple):
        n_bins = (n_bins, n_bins)
    n_x_bins, n_y_bins = n_bins
    assert plotted_values.shape == (n_y_bins, n_x_bins)
    assert plotting_coords.shape == (n_y_bins + 1, n_x_bins + 1, 2)

    # Confirm that the bin edges span the data range in equal steps
    data_time_xy = data.sel(individuals=individual, keypoints=keypoint)
    x_values = data_time_xy.sel(space="x").values
    y_values = data_time_xy.sel(space="y").values
    # nanmin/nanmax equal min/max for NaN-free data.
    reconstructed_bins_limits_x = np.linspace(
        np.nanmin(x_values),
        np.nanmax(x_values),
        num=n_x_bins + 1,
        endpoint=True,
    )
    reconstructed_bins_limits_y = np.linspace(
        np.nanmin(y_values),
        np.nanmax(y_values),
        num=n_y_bins + 1,
        endpoint=True,
    )
    # x edges vary along axis 1 and repeat on every row; y edges vary along
    # axis 0 and repeat in every column.
    assert np.allclose(
        plotting_coords[:, :, 0], reconstructed_bins_limits_x[np.newaxis, :]
    )
    assert np.allclose(
        plotting_coords[:, :, 1], reconstructed_bins_limits_y[:, np.newaxis]
    )

    # Confirm that each bin has the correct number of assignments
    reconstructed_bin_counts = _count_points_per_bin(
        x_values,
        y_values,
        reconstructed_bins_limits_x,
        reconstructed_bins_limits_y,
    )
    assert reconstructed_bin_counts.sum() == plotted_values.sum()
    # The reconstruction is indexed [x_bin, y_bin] whereas the mesh is
    # [y_bin, x_bin], so transpose before the element-wise comparison.
    np.testing.assert_array_equal(reconstructed_bin_counts.T, plotted_values)

0 commit comments

Comments
 (0)