Skip to content

Commit 4fa1514

Browse files
committed
Test missing dims and entirely NaN values
1 parent a9b9a0f commit 4fa1514

File tree

1 file changed

+116
-41
lines changed

1 file changed

+116
-41
lines changed

tests/test_unit/test_plot.py

Lines changed: 116 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -88,50 +88,115 @@ def histogram_data_with_nans(
8888
return data_with_nans
8989

9090

91-
# def test_histogram_ignores_missing_dims(
92-
# input_does_not_have_dimensions: list[str],
93-
# ) -> None:
94-
# """Test that ``occupancy_histogram`` ignores non-present dimensions."""
95-
# input_data = 0
91+
@pytest.fixture
92+
def entirely_nan_data(histogram_data: xr.DataArray) -> xr.DataArray:
93+
return histogram_data.copy(
94+
deep=True, data=histogram_data.values * float("nan")
95+
)
9696

9797

9898
@pytest.mark.parametrize(
99-
["data", "individual", "keypoint", "n_bins"],
99+
[
100+
"data",
101+
"remove_dims_from_data_before_starting",
102+
"individual",
103+
"keypoint",
104+
"n_bins",
105+
],
100106
[
101107
pytest.param(
102108
"histogram_data",
109+
[],
103110
"i0",
104111
"k0",
105112
30,
106113
id="30 bins each axis",
107114
),
108115
pytest.param(
109116
"histogram_data",
117+
[],
110118
"i1",
111119
"k0",
112120
(20, 30),
113121
id="(20, 30) bins",
114122
),
115123
pytest.param(
116124
"histogram_data_with_nans",
125+
[],
117126
"i0",
118127
"k0",
119128
30,
120129
id="NaNs should be removed",
121130
),
131+
pytest.param(
132+
"entirely_nan_data",
133+
[],
134+
"i0",
135+
"k0",
136+
10,
137+
id="All NaN-data",
138+
),
139+
pytest.param(
140+
"histogram_data",
141+
["individuals"],
142+
"i0",
143+
"k0",
144+
30,
145+
id="Ignores individual if not a dimension",
146+
),
147+
pytest.param(
148+
"histogram_data",
149+
["keypoints"],
150+
"i0",
151+
"k1",
152+
30,
153+
id="Ignores keypoint if not a dimension",
154+
),
155+
pytest.param(
156+
"histogram_data",
157+
["individuals", "keypoints"],
158+
"i0",
159+
"k0",
160+
30,
161+
id="Can handle raw xy data",
162+
),
122163
],
123164
)
124165
def test_occupancy_histogram(
125166
data: xr.DataArray,
167+
remove_dims_from_data_before_starting: list[str],
126168
individual: int | str,
127169
keypoint: int | str,
128170
n_bins: int | tuple[int, int],
129171
request,
130172
) -> None:
131-
"""Test that occupancy histograms correctly plot data."""
173+
"""Test that occupancy histograms correctly plot data.
174+
175+
Specifically, check that:
176+
- The bin edges are what we expect.
177+
- The bin counts can be manually verified and are in agreement.
178+
- Only non-NaN values are plotted, but NaN values do not throw errors.
179+
"""
132180
if isinstance(data, str):
133181
data = request.getfixturevalue(data)
134182

183+
# We will need to select only the xy data later in the test,
184+
# but if we are dropping dimensions, the ``sel`` call that does so
185+
# needs different keyword arguments.
186+
kwargs_to_select_xy_data = {
187+
"individuals": individual,
188+
"keypoints": keypoint,
189+
}
190+
for d in remove_dims_from_data_before_starting:
191+
# Retain the 0th value in the corresponding dimension,
192+
# then drop that dimension.
193+
data = data.sel({d: getattr(data, d)[0]}).squeeze()
194+
assert d not in data.dims
195+
196+
# We no longer need to filter this dimension out
197+
# when examining the xy data later in the test.
198+
kwargs_to_select_xy_data.pop(d, None)
199+
135200
_, histogram_info = occupancy_histogram(
136201
data, individual=individual, keypoint=keypoint, bins=n_bins
137202
)
@@ -143,50 +208,60 @@ def test_occupancy_histogram(
143208
assert plotted_values.shape == n_bins
144209

145210
# Confirm that each bin has the correct number of assignments
146-
data_time_xy = data.sel(individuals=individual, keypoints=keypoint)
211+
data_time_xy = data.sel(**kwargs_to_select_xy_data)
147212
data_time_xy = data_time_xy.dropna(dim="time", how="any")
148213
plotted_x_values = data_time_xy.sel(space="x").values
149214
plotted_y_values = data_time_xy.sel(space="y").values
150215
assert plotted_x_values.shape == plotted_y_values.shape
151216
# This many non-NaN values were plotted
152217
n_non_nan_values = plotted_x_values.shape[0]
153218

154-
reconstructed_bins_limits_x = np.linspace(
155-
plotted_x_values.min(),
156-
plotted_x_values.max(),
157-
num=n_bins[0] + 1,
158-
endpoint=True,
159-
)
160-
assert np.allclose(reconstructed_bins_limits_x, histogram_info["xedges"])
161-
reconstructed_bins_limits_y = np.linspace(
162-
plotted_y_values.min(),
163-
plotted_y_values.max(),
164-
num=n_bins[1] + 1,
165-
endpoint=True,
166-
)
167-
assert np.allclose(reconstructed_bins_limits_y, histogram_info["yedges"])
219+
if n_non_nan_values > 0:
220+
reconstructed_bins_limits_x = np.linspace(
221+
plotted_x_values.min(),
222+
plotted_x_values.max(),
223+
num=n_bins[0] + 1,
224+
endpoint=True,
225+
)
226+
assert np.allclose(
227+
reconstructed_bins_limits_x, histogram_info["xedges"]
228+
)
229+
reconstructed_bins_limits_y = np.linspace(
230+
plotted_y_values.min(),
231+
plotted_y_values.max(),
232+
num=n_bins[1] + 1,
233+
endpoint=True,
234+
)
235+
assert np.allclose(
236+
reconstructed_bins_limits_y, histogram_info["yedges"]
237+
)
168238

169-
reconstructed_bin_counts = np.zeros(shape=n_bins, dtype=float)
170-
for i, xi in enumerate(reconstructed_bins_limits_x[:-1]):
171-
xi_p1 = reconstructed_bins_limits_x[i + 1]
239+
reconstructed_bin_counts = np.zeros(shape=n_bins, dtype=float)
240+
for i, xi in enumerate(reconstructed_bins_limits_x[:-1]):
241+
xi_p1 = reconstructed_bins_limits_x[i + 1]
172242

173-
x_pts_in_range = (plotted_x_values >= xi) & (plotted_x_values <= xi_p1)
174-
for j, yj in enumerate(reconstructed_bins_limits_y[:-1]):
175-
yj_p1 = reconstructed_bins_limits_y[j + 1]
176-
177-
y_pts_in_range = (plotted_y_values >= yj) & (
178-
plotted_y_values <= yj_p1
243+
x_pts_in_range = (plotted_x_values >= xi) & (
244+
plotted_x_values <= xi_p1
179245
)
246+
for j, yj in enumerate(reconstructed_bins_limits_y[:-1]):
247+
yj_p1 = reconstructed_bins_limits_y[j + 1]
248+
249+
y_pts_in_range = (plotted_y_values >= yj) & (
250+
plotted_y_values <= yj_p1
251+
)
180252

181-
pts_in_this_bin = (x_pts_in_range & y_pts_in_range).sum()
182-
reconstructed_bin_counts[i, j] = pts_in_this_bin
253+
pts_in_this_bin = (x_pts_in_range & y_pts_in_range).sum()
254+
reconstructed_bin_counts[i, j] = pts_in_this_bin
183255

184-
if pts_in_this_bin != plotted_values[i, j]:
185-
pass
256+
if pts_in_this_bin != plotted_values[i, j]:
257+
pass
186258

187-
# We agree with a manual count
188-
assert reconstructed_bin_counts.sum() == plotted_values.sum()
189-
# All non-NaN values were plotted
190-
assert n_non_nan_values == plotted_values.sum()
191-
# The counts were actually correct
192-
assert np.all(reconstructed_bin_counts == plotted_values)
259+
# We agree with a manual count
260+
assert reconstructed_bin_counts.sum() == plotted_values.sum()
261+
# All non-NaN values were plotted
262+
assert n_non_nan_values == plotted_values.sum()
263+
# The counts were actually correct
264+
assert np.all(reconstructed_bin_counts == plotted_values)
265+
else:
266+
# No non-NaN values were given
267+
assert plotted_values.sum() == 0

0 commit comments

Comments
 (0)