Skip to content

Commit a5ae259

Browse files
committed
Check that new / existing axes are respected
1 parent 4fa1514 commit a5ae259

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

movement/plot.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ def occupancy_histogram(
1616
keypoint: int | str = 0,
1717
individual: int | str = 0,
1818
title: str | None = None,
19+
ax: plt.Axes | None = None,
1920
**kwargs: Any,
20-
) -> tuple[plt.Figure, dict[HistInfoKeys, np.ndarray]]:
21+
) -> tuple[plt.Figure, plt.Axes, dict[HistInfoKeys, np.ndarray]]:
2122
"""Create a 2D histogram of the occupancy data given.
2223
2324
Time-points whose corresponding spatial coordinates have NaN values
@@ -35,13 +36,20 @@ def occupancy_histogram(
3536
title : str, optional
3637
Title to give to the plot. Default will be generated from the
3738
``keypoint`` and ``individual``
39+
ax : matplotlib.axes.Axes, optional
40+
Axes object on which to draw the histogram. If not provided, a new
41+
figure and axes are created and returned.
3842
kwargs : Any
3943
Keyword arguments passed to ``matplotlib.pyplot.hist2d``
4044
4145
Returns
4246
-------
4347
matplotlib.pyplot.Figure
44-
Plot handle containing the rendered 2D histogram.
48+
Plot handle containing the rendered 2D histogram. If ``ax`` is
49+
supplied, this will be the figure that ``ax`` belongs to.
50+
matplotlib.axes.Axes
51+
Axes on which the histogram was drawn. If ``ax`` was supplied,
52+
the input will be directly modified and returned in this value.
4553
dict[str, numpy.ndarray]
4654
Information about the created histogram (see Notes).
4755
@@ -99,7 +107,10 @@ def occupancy_histogram(
99107
if key not in kwargs:
100108
kwargs[key] = value
101109
# Now it should just be a case of creating the histogram
102-
fig, ax = plt.subplots()
110+
if ax is not None:
111+
fig = ax.get_figure()
112+
else:
113+
fig, ax = plt.subplots()
103114
counts, xedges, yedges, hist_image = ax.hist2d(
104115
data.sel(space=x_coord), data.sel(space=y_coord), **kwargs
105116
)
@@ -114,4 +125,4 @@ def occupancy_histogram(
114125
ax.set_xlabel(x_coord)
115126
ax.set_ylabel(y_coord)
116127

117-
return fig, {"counts": counts, "xedges": xedges, "yedges": yedges}
128+
return fig, ax, {"counts": counts, "xedges": xedges, "yedges": yedges}

tests/test_unit/test_plot.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import matplotlib.pyplot as plt
12
import numpy as np
23
import pytest
34
import xarray as xr
5+
from matplotlib.collections import QuadMesh
46
from numpy.random import RandomState
57

68
from movement.plot import occupancy_histogram
@@ -197,7 +199,7 @@ def test_occupancy_histogram(
197199
# when examining the xy data later in the test.
198200
kwargs_to_select_xy_data.pop(d, None)
199201

200-
_, histogram_info = occupancy_histogram(
202+
_, _, histogram_info = occupancy_histogram(
201203
data, individual=individual, keypoint=keypoint, bins=n_bins
202204
)
203205
plotted_values = histogram_info["counts"]
@@ -265,3 +267,35 @@ def test_occupancy_histogram(
265267
else:
266268
# No non-nan values were given
267269
assert plotted_values.sum() == 0
270+
271+
272+
def test_respects_axes(histogram_data: xr.DataArray) -> None:
273+
"""Check that existing axes objects are respected if passed."""
274+
# Plotting on existing axes
275+
existing_fig, existing_ax = plt.subplots(1, 2)
276+
277+
existing_ax[0].plot(
278+
np.linspace(0.0, 10.0, num=100), np.linspace(0.0, 10.0, num=100)
279+
)
280+
281+
_, _, hist_info_existing = occupancy_histogram(
282+
histogram_data, ax=existing_ax[1]
283+
)
284+
hist_plots_added = [
285+
qm for qm in existing_ax[1].get_children() if isinstance(qm, QuadMesh)
286+
]
287+
assert len(hist_plots_added) == 1
288+
289+
# Plot on new axis and create a new figure
290+
new_fig, new_ax, hist_info_new = occupancy_histogram(histogram_data)
291+
hist_plots_created = [
292+
qm for qm in new_ax.get_children() if isinstance(qm, QuadMesh)
293+
]
294+
assert len(hist_plots_created) == 1
295+
296+
# Check that the same plot was made for each
297+
assert set(hist_info_new.keys()) == set(hist_info_existing.keys())
298+
for key, new_ax_value in hist_info_new.items():
299+
existing_ax_value = hist_info_existing[key]
300+
301+
assert np.allclose(new_ax_value, existing_ax_value)

0 commit comments

Comments
 (0)