|
| 1 | +from itertools import product |
| 2 | + |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import numpy as np |
| 5 | +import pytest |
| 6 | +import xarray as xr |
| 7 | +from matplotlib.collections import QuadMesh |
| 8 | +from numpy.random import RandomState |
| 9 | + |
| 10 | +from movement.plot import occupancy_histogram |
| 11 | + |
| 12 | + |
| 13 | +def get_histogram_binning_data(fig: plt.Figure) -> list[QuadMesh]: |
| 14 | + """Fetch 2D array data from a histogram plot.""" |
| 15 | + return [ |
| 16 | + qm for qm in fig.axes[0].get_children() if isinstance(qm, QuadMesh) |
| 17 | + ] |
| 18 | + |
| 19 | + |
| 20 | +@pytest.fixture |
| 21 | +def seed() -> int: |
| 22 | + return 0 |
| 23 | + |
| 24 | + |
| 25 | +@pytest.fixture(scope="function") |
| 26 | +def rng(seed: int) -> RandomState: |
| 27 | + """Create a RandomState to use in testing. |
| 28 | +
|
| 29 | + This ensures the repeatability of histogram tests, that require large |
| 30 | + datasets that would be tedious to create manually. |
| 31 | + """ |
| 32 | + return RandomState(seed) |
| 33 | + |
| 34 | + |
| 35 | +@pytest.fixture |
| 36 | +def normal_dist_2d(rng: RandomState) -> np.ndarray: |
| 37 | + """Points distributed by the standard multivariate normal. |
| 38 | +
|
| 39 | + The standard multivariate normal is just two independent N(0, 1) |
| 40 | + distributions, one in each dimension. |
| 41 | + """ |
| 42 | + samples = rng.multivariate_normal( |
| 43 | + (0.0, 0.0), [[1.0, 0.0], [0.0, 1.0]], (250, 3, 4) |
| 44 | + ) |
| 45 | + return np.moveaxis( |
| 46 | + samples, 3, 1 |
| 47 | + ) # Move generated space coords to correct axis position |
| 48 | + |
| 49 | + |
| 50 | +@pytest.fixture |
| 51 | +def histogram_data(normal_dist_2d: np.ndarray) -> xr.DataArray: |
| 52 | + """DataArray whose data is the ``normal_dist_2d`` points. |
| 53 | +
|
| 54 | + Axes 2 and 3 are the individuals and keypoints axes, respectively. |
| 55 | + These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for |
| 56 | + the purposes of indexing. |
| 57 | + """ |
| 58 | + return xr.DataArray( |
| 59 | + data=normal_dist_2d, |
| 60 | + dims=["time", "space", "individuals", "keypoints"], |
| 61 | + coords={ |
| 62 | + "space": ["x", "y"], |
| 63 | + "individuals": [f"i{i}" for i in range(normal_dist_2d.shape[2])], |
| 64 | + "keypoints": [f"k{i}" for i in range(normal_dist_2d.shape[3])], |
| 65 | + }, |
| 66 | + ) |
| 67 | + |
| 68 | + |
| 69 | +@pytest.fixture |
| 70 | +def histogram_data_with_nans( |
| 71 | + histogram_data: xr.DataArray, rng: RandomState |
| 72 | +) -> xr.DataArray: |
| 73 | + """DataArray whose data is the ``normal_dist_2d`` points. |
| 74 | +
|
| 75 | + Each datapoint has a chance of being turned into a NaN value. |
| 76 | +
|
| 77 | + Axes 2 and 3 are the individuals and keypoints axes, respectively. |
| 78 | + These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for |
| 79 | + the purposes of indexing. |
| 80 | + """ |
| 81 | + data_with_nans = histogram_data.copy(deep=True) |
| 82 | + data_shape = data_with_nans.shape |
| 83 | + nan_chance = 1.0 / 25.0 |
| 84 | + index_ranges = [range(dim_length) for dim_length in data_shape] |
| 85 | + for multiindex in product(*index_ranges): |
| 86 | + if rng.uniform() < nan_chance: |
| 87 | + data_with_nans[*multiindex] = float("nan") |
| 88 | + return data_with_nans |
| 89 | + |
| 90 | + |
| 91 | +# def test_histogram_ignores_missing_dims( |
| 92 | +# input_does_not_have_dimensions: list[str], |
| 93 | +# ) -> None: |
| 94 | +# """Test that ``occupancy_histogram`` ignores non-present dimensions.""" |
| 95 | +# input_data = 0 |
| 96 | + |
| 97 | + |
| 98 | +@pytest.mark.parametrize( |
| 99 | + ["data", "individual", "keypoint", "n_bins"], |
| 100 | + [pytest.param("histogram_data", "i0", "k0", 30, id="30 bins each axis")], |
| 101 | +) |
| 102 | +def test_occupancy_histogram( |
| 103 | + data: xr.DataArray, |
| 104 | + individual: int | str, |
| 105 | + keypoint: int | str, |
| 106 | + n_bins: int | tuple[int, int], |
| 107 | + request, |
| 108 | +) -> None: |
| 109 | + """Test that occupancy histograms correctly plot data.""" |
| 110 | + if isinstance(data, str): |
| 111 | + data = request.getfixturevalue(data) |
| 112 | + |
| 113 | + plotted_hist = occupancy_histogram( |
| 114 | + data, individual=individual, keypoint=keypoint, bins=n_bins |
| 115 | + ) |
| 116 | + |
| 117 | + # Confirm that a histogram was made |
| 118 | + plotted_data = get_histogram_binning_data(plotted_hist) |
| 119 | + assert len(plotted_data) == 1 |
| 120 | + plotted_data = plotted_data[0] |
| 121 | + plotting_coords = plotted_data.get_coordinates() |
| 122 | + plotted_values = plotted_data.get_array() |
| 123 | + |
| 124 | + # Confirm the binned array has the correct size |
| 125 | + if not isinstance(n_bins, tuple): |
| 126 | + n_bins = (n_bins, n_bins) |
| 127 | + assert plotted_data.get_array().shape == n_bins |
| 128 | + |
| 129 | + # Confirm that each bin has the correct number of assignments |
| 130 | + data_time_xy = data.sel(individuals=individual, keypoints=keypoint) |
| 131 | + x_values = data_time_xy.sel(space="x").values |
| 132 | + y_values = data_time_xy.sel(space="y").values |
| 133 | + reconstructed_bins_limits_x = np.linspace( |
| 134 | + x_values.min(), |
| 135 | + x_values.max(), |
| 136 | + num=n_bins[0] + 1, |
| 137 | + endpoint=True, |
| 138 | + ) |
| 139 | + assert all( |
| 140 | + np.allclose(reconstructed_bins_limits_x, plotting_coords[i, :, 0]) |
| 141 | + for i in range(n_bins[0]) |
| 142 | + ) |
| 143 | + reconstructed_bins_limits_y = np.linspace( |
| 144 | + y_values.min(), |
| 145 | + y_values.max(), |
| 146 | + num=n_bins[1] + 1, |
| 147 | + endpoint=True, |
| 148 | + ) |
| 149 | + assert all( |
| 150 | + np.allclose(reconstructed_bins_limits_y, plotting_coords[:, j, 1]) |
| 151 | + for j in range(n_bins[1]) |
| 152 | + ) |
| 153 | + |
| 154 | + reconstructed_bin_counts = np.zeros(shape=n_bins, dtype=float) |
| 155 | + for i, xi in enumerate(reconstructed_bins_limits_x[:-1]): |
| 156 | + xi_p1 = reconstructed_bins_limits_x[i + 1] |
| 157 | + |
| 158 | + x_pts_in_range = (x_values >= xi) & (x_values <= xi_p1) |
| 159 | + for j, yj in enumerate(reconstructed_bins_limits_y[:-1]): |
| 160 | + yj_p1 = reconstructed_bins_limits_y[j + 1] |
| 161 | + |
| 162 | + y_pts_in_range = (y_values >= yj) & (y_values <= yj_p1) |
| 163 | + |
| 164 | + pts_in_this_bin = (x_pts_in_range & y_pts_in_range).sum() |
| 165 | + reconstructed_bin_counts[i, j] = pts_in_this_bin |
| 166 | + |
| 167 | + if pts_in_this_bin != plotted_values[i, j]: |
| 168 | + pass |
| 169 | + |
| 170 | + assert reconstructed_bin_counts.sum() == plotted_values.sum() |
| 171 | + assert np.all(reconstructed_bin_counts == plotted_values) |
0 commit comments