Skip to content

Commit a9b9a0f

Browse files
committed
Additional return values to help extract histogram information
1 parent 9a4614c commit a9b9a0f

File tree

2 files changed

+107
-50
lines changed

2 files changed

+107
-50
lines changed

movement/plot.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Wrappers to plot movement data."""
22

3-
from typing import Any
3+
from typing import Any, Literal, TypeAlias
44

55
import matplotlib.pyplot as plt
6+
import numpy as np
67
import xarray as xr
78

9+
HistInfoKeys: TypeAlias = Literal["counts", "xedges", "yedges"]
10+
811
DEFAULT_HIST_ARGS = {"alpha": 1.0, "bins": 30, "cmap": "viridis"}
912

1013

@@ -14,9 +17,13 @@ def occupancy_histogram(
1417
individual: int | str = 0,
1518
title: str | None = None,
1619
**kwargs: Any,
17-
) -> plt.Figure:
20+
) -> tuple[plt.Figure, dict[HistInfoKeys, np.ndarray]]:
1821
"""Create a 2D histogram of the occupancy data given.
1922
23+
Time-points whose corresponding spatial coordinates have NaN values
24+
are ignored. Histogram information is returned as the second output
25+
value (see Notes).
26+
2027
Parameters
2128
----------
2229
da : xarray.DataArray
@@ -35,6 +42,35 @@ def occupancy_histogram(
3542
-------
3643
matplotlib.pyplot.Figure
3744
Plot handle containing the rendered 2D histogram.
45+
dict[str, numpy.ndarray]
46+
Information about the created histogram (see Notes).
47+
48+
Notes
49+
-----
50+
In instances where the counts or information about the histogram bins is
51+
desired, the ``return_hist_info`` argument should be provided as ``True``.
52+
This will force the function to return a second output value, which is a
53+
dictionary containing the bin edges and bin counts that were used to create
54+
the histogram.
55+
56+
For data with ``N`` time-points, the dictionary output has key-value pairs;
57+
- ``xedges``, an ``(N+1,)`` ``numpy`` array specifying the bin edges in the
58+
first spatial dimension.
59+
- ``yedges``, same as ``xedges`` but for the second spatial dimension.
60+
- ``counts``, an ``(N, N)`` ``numpy`` array with the count for each bin.
61+
62+
``counts[x, y]`` is the number of datapoints in the
63+
``(xedges[x], xedges[x+1]), (yedges[y], yedges[y+1])`` bin. These values
64+
are those returned from ``matplotlib.pyplot.Axes.hist2d``.
65+
66+
Note that the ``counts`` values do not necessarily match the mappable
67+
values that one gets from extracting the data from the
68+
``matplotlib.collections.QuadMesh`` object (that represents the rendered
69+
histogram) via its ``get_array()`` attribute.
70+
71+
See Also
72+
--------
73+
matplotlib.pyplot.Axes.hist2d : The underlying plotting function.
3874
3975
"""
4076
data = da.position if isinstance(da, xr.Dataset) else da
@@ -64,9 +100,9 @@ def occupancy_histogram(
64100
kwargs[key] = value
65101
# Now it should just be a case of creating the histogram
66102
fig, ax = plt.subplots()
67-
_, _, _, hist_image = ax.hist2d(
103+
counts, xedges, yedges, hist_image = ax.hist2d(
68104
data.sel(space=x_coord), data.sel(space=y_coord), **kwargs
69-
) # counts, xedges, yedges, image
105+
)
70106
colourbar = fig.colorbar(hist_image, ax=ax)
71107
colourbar.solids.set(alpha=1.0)
72108

@@ -78,4 +114,4 @@ def occupancy_histogram(
78114
ax.set_xlabel(x_coord)
79115
ax.set_ylabel(y_coord)
80116

81-
return fig
117+
return fig, {"counts": counts, "xedges": xedges, "yedges": yedges}

tests/test_unit/test_plot.py

Lines changed: 66 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,11 @@
1-
from itertools import product
2-
3-
import matplotlib.pyplot as plt
41
import numpy as np
52
import pytest
63
import xarray as xr
7-
from matplotlib.collections import QuadMesh
84
from numpy.random import RandomState
95

106
from movement.plot import occupancy_histogram
117

128

13-
def get_histogram_binning_data(fig: plt.Figure) -> list[QuadMesh]:
14-
"""Fetch 2D array data from a histogram plot."""
15-
return [
16-
qm for qm in fig.axes[0].get_children() if isinstance(qm, QuadMesh)
17-
]
18-
19-
209
@pytest.fixture
2110
def seed() -> int:
2211
return 0
@@ -72,19 +61,30 @@ def histogram_data_with_nans(
7261
) -> xr.DataArray:
7362
"""DataArray whose data is the ``normal_dist_2d`` points.
7463
75-
Each datapoint has a chance of being turned into a NaN value.
76-
7764
Axes 2 and 3 are the individuals and keypoints axes, respectively.
7865
These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for
7966
the purposes of indexing.
67+
68+
For individual i0, keypoint k0, the following (time, space) values are
69+
converted into NaNs:
70+
- (100, "x")
71+
- (200, "y")
72+
- (150, "x")
73+
- (150, "y")
74+
8075
"""
76+
individual_0 = "i0"
77+
keypoint_0 = "k0"
8178
data_with_nans = histogram_data.copy(deep=True)
82-
data_shape = data_with_nans.shape
83-
nan_chance = 1.0 / 25.0
84-
index_ranges = [range(dim_length) for dim_length in data_shape]
85-
for multiindex in product(*index_ranges):
86-
if rng.uniform() < nan_chance:
87-
data_with_nans[*multiindex] = float("nan")
79+
for time_index, space_coord in [
80+
(100, "x"),
81+
(200, "y"),
82+
(150, "x"),
83+
(150, "y"),
84+
]:
85+
data_with_nans.loc[
86+
time_index, space_coord, individual_0, keypoint_0
87+
] = float("nan")
8888
return data_with_nans
8989

9090

@@ -97,7 +97,29 @@ def histogram_data_with_nans(
9797

9898
@pytest.mark.parametrize(
9999
["data", "individual", "keypoint", "n_bins"],
100-
[pytest.param("histogram_data", "i0", "k0", 30, id="30 bins each axis")],
100+
[
101+
pytest.param(
102+
"histogram_data",
103+
"i0",
104+
"k0",
105+
30,
106+
id="30 bins each axis",
107+
),
108+
pytest.param(
109+
"histogram_data",
110+
"i1",
111+
"k0",
112+
(20, 30),
113+
id="(20, 30) bins",
114+
),
115+
pytest.param(
116+
"histogram_data_with_nans",
117+
"i0",
118+
"k0",
119+
30,
120+
id="NaNs should be removed",
121+
),
122+
],
101123
)
102124
def test_occupancy_histogram(
103125
data: xr.DataArray,
@@ -110,62 +132,61 @@ def test_occupancy_histogram(
110132
if isinstance(data, str):
111133
data = request.getfixturevalue(data)
112134

113-
plotted_hist = occupancy_histogram(
135+
_, histogram_info = occupancy_histogram(
114136
data, individual=individual, keypoint=keypoint, bins=n_bins
115137
)
116-
117-
# Confirm that a histogram was made
118-
plotted_data = get_histogram_binning_data(plotted_hist)
119-
assert len(plotted_data) == 1
120-
plotted_data = plotted_data[0]
121-
plotting_coords = plotted_data.get_coordinates()
122-
plotted_values = plotted_data.get_array()
138+
plotted_values = histogram_info["counts"]
123139

124140
# Confirm the binned array has the correct size
125141
if not isinstance(n_bins, tuple):
126142
n_bins = (n_bins, n_bins)
127-
assert plotted_data.get_array().shape == n_bins
143+
assert plotted_values.shape == n_bins
128144

129145
# Confirm that each bin has the correct number of assignments
130146
data_time_xy = data.sel(individuals=individual, keypoints=keypoint)
131-
x_values = data_time_xy.sel(space="x").values
132-
y_values = data_time_xy.sel(space="y").values
147+
data_time_xy = data_time_xy.dropna(dim="time", how="any")
148+
plotted_x_values = data_time_xy.sel(space="x").values
149+
plotted_y_values = data_time_xy.sel(space="y").values
150+
assert plotted_x_values.shape == plotted_y_values.shape
151+
# This many non-NaN values were plotted
152+
n_non_nan_values = plotted_x_values.shape[0]
153+
133154
reconstructed_bins_limits_x = np.linspace(
134-
x_values.min(),
135-
x_values.max(),
155+
plotted_x_values.min(),
156+
plotted_x_values.max(),
136157
num=n_bins[0] + 1,
137158
endpoint=True,
138159
)
139-
assert all(
140-
np.allclose(reconstructed_bins_limits_x, plotting_coords[i, :, 0])
141-
for i in range(n_bins[0])
142-
)
160+
assert np.allclose(reconstructed_bins_limits_x, histogram_info["xedges"])
143161
reconstructed_bins_limits_y = np.linspace(
144-
y_values.min(),
145-
y_values.max(),
162+
plotted_y_values.min(),
163+
plotted_y_values.max(),
146164
num=n_bins[1] + 1,
147165
endpoint=True,
148166
)
149-
assert all(
150-
np.allclose(reconstructed_bins_limits_y, plotting_coords[:, j, 1])
151-
for j in range(n_bins[1])
152-
)
167+
assert np.allclose(reconstructed_bins_limits_y, histogram_info["yedges"])
153168

154169
reconstructed_bin_counts = np.zeros(shape=n_bins, dtype=float)
155170
for i, xi in enumerate(reconstructed_bins_limits_x[:-1]):
156171
xi_p1 = reconstructed_bins_limits_x[i + 1]
157172

158-
x_pts_in_range = (x_values >= xi) & (x_values <= xi_p1)
173+
x_pts_in_range = (plotted_x_values >= xi) & (plotted_x_values <= xi_p1)
159174
for j, yj in enumerate(reconstructed_bins_limits_y[:-1]):
160175
yj_p1 = reconstructed_bins_limits_y[j + 1]
161176

162-
y_pts_in_range = (y_values >= yj) & (y_values <= yj_p1)
177+
y_pts_in_range = (plotted_y_values >= yj) & (
178+
plotted_y_values <= yj_p1
179+
)
163180

164181
pts_in_this_bin = (x_pts_in_range & y_pts_in_range).sum()
165182
reconstructed_bin_counts[i, j] = pts_in_this_bin
166183

167184
if pts_in_this_bin != plotted_values[i, j]:
168185
pass
169186

187+
# We agree with a manual count
170188
assert reconstructed_bin_counts.sum() == plotted_values.sum()
189+
# All non-NaN values were plotted
190+
assert n_non_nan_values == plotted_values.sum()
191+
# The counts were actually correct
171192
assert np.all(reconstructed_bin_counts == plotted_values)

0 commit comments

Comments
 (0)