1
- from itertools import product
2
-
3
- import matplotlib .pyplot as plt
4
1
import numpy as np
5
2
import pytest
6
3
import xarray as xr
7
- from matplotlib .collections import QuadMesh
8
4
from numpy .random import RandomState
9
5
10
6
from movement .plot import occupancy_histogram
11
7
12
8
13
- def get_histogram_binning_data (fig : plt .Figure ) -> list [QuadMesh ]:
14
- """Fetch 2D array data from a histogram plot."""
15
- return [
16
- qm for qm in fig .axes [0 ].get_children () if isinstance (qm , QuadMesh )
17
- ]
18
-
19
-
20
9
@pytest .fixture
21
10
def seed () -> int :
22
11
return 0
@@ -72,19 +61,30 @@ def histogram_data_with_nans(
72
61
) -> xr .DataArray :
73
62
"""DataArray whose data is the ``normal_dist_2d`` points.
74
63
75
- Each datapoint has a chance of being turned into a NaN value.
76
-
77
64
Axes 2 and 3 are the individuals and keypoints axes, respectively.
78
65
These dimensions are given coordinates {i,k}{0,1,2,3,4,5,...} for
79
66
the purposes of indexing.
67
+
68
+ For individual i0, keypoint k0, the following (time, space) values are
69
+ converted into NaNs:
70
+ - (100, "x")
71
+ - (200, "y")
72
+ - (150, "x")
73
+ - (150, "y")
74
+
80
75
"""
76
+ individual_0 = "i0"
77
+ keypoint_0 = "k0"
81
78
data_with_nans = histogram_data .copy (deep = True )
82
- data_shape = data_with_nans .shape
83
- nan_chance = 1.0 / 25.0
84
- index_ranges = [range (dim_length ) for dim_length in data_shape ]
85
- for multiindex in product (* index_ranges ):
86
- if rng .uniform () < nan_chance :
87
- data_with_nans [* multiindex ] = float ("nan" )
79
+ for time_index , space_coord in [
80
+ (100 , "x" ),
81
+ (200 , "y" ),
82
+ (150 , "x" ),
83
+ (150 , "y" ),
84
+ ]:
85
+ data_with_nans .loc [
86
+ time_index , space_coord , individual_0 , keypoint_0
87
+ ] = float ("nan" )
88
88
return data_with_nans
89
89
90
90
@@ -97,7 +97,29 @@ def histogram_data_with_nans(
97
97
98
98
@pytest .mark .parametrize (
99
99
["data" , "individual" , "keypoint" , "n_bins" ],
100
- [pytest .param ("histogram_data" , "i0" , "k0" , 30 , id = "30 bins each axis" )],
100
+ [
101
+ pytest .param (
102
+ "histogram_data" ,
103
+ "i0" ,
104
+ "k0" ,
105
+ 30 ,
106
+ id = "30 bins each axis" ,
107
+ ),
108
+ pytest .param (
109
+ "histogram_data" ,
110
+ "i1" ,
111
+ "k0" ,
112
+ (20 , 30 ),
113
+ id = "(20, 30) bins" ,
114
+ ),
115
+ pytest .param (
116
+ "histogram_data_with_nans" ,
117
+ "i0" ,
118
+ "k0" ,
119
+ 30 ,
120
+ id = "NaNs should be removed" ,
121
+ ),
122
+ ],
101
123
)
102
124
def test_occupancy_histogram (
103
125
data : xr .DataArray ,
@@ -110,62 +132,61 @@ def test_occupancy_histogram(
110
132
if isinstance (data , str ):
111
133
data = request .getfixturevalue (data )
112
134
113
- plotted_hist = occupancy_histogram (
135
+ _ , histogram_info = occupancy_histogram (
114
136
data , individual = individual , keypoint = keypoint , bins = n_bins
115
137
)
116
-
117
- # Confirm that a histogram was made
118
- plotted_data = get_histogram_binning_data (plotted_hist )
119
- assert len (plotted_data ) == 1
120
- plotted_data = plotted_data [0 ]
121
- plotting_coords = plotted_data .get_coordinates ()
122
- plotted_values = plotted_data .get_array ()
138
+ plotted_values = histogram_info ["counts" ]
123
139
124
140
# Confirm the binned array has the correct size
125
141
if not isinstance (n_bins , tuple ):
126
142
n_bins = (n_bins , n_bins )
127
- assert plotted_data . get_array () .shape == n_bins
143
+ assert plotted_values .shape == n_bins
128
144
129
145
# Confirm that each bin has the correct number of assignments
130
146
data_time_xy = data .sel (individuals = individual , keypoints = keypoint )
131
- x_values = data_time_xy .sel (space = "x" ).values
132
- y_values = data_time_xy .sel (space = "y" ).values
147
+ data_time_xy = data_time_xy .dropna (dim = "time" , how = "any" )
148
+ plotted_x_values = data_time_xy .sel (space = "x" ).values
149
+ plotted_y_values = data_time_xy .sel (space = "y" ).values
150
+ assert plotted_x_values .shape == plotted_y_values .shape
151
+ # This many non-NaN values were plotted
152
+ n_non_nan_values = plotted_x_values .shape [0 ]
153
+
133
154
reconstructed_bins_limits_x = np .linspace (
134
- x_values .min (),
135
- x_values .max (),
155
+ plotted_x_values .min (),
156
+ plotted_x_values .max (),
136
157
num = n_bins [0 ] + 1 ,
137
158
endpoint = True ,
138
159
)
139
- assert all (
140
- np .allclose (reconstructed_bins_limits_x , plotting_coords [i , :, 0 ])
141
- for i in range (n_bins [0 ])
142
- )
160
+ assert np .allclose (reconstructed_bins_limits_x , histogram_info ["xedges" ])
143
161
reconstructed_bins_limits_y = np .linspace (
144
- y_values .min (),
145
- y_values .max (),
162
+ plotted_y_values .min (),
163
+ plotted_y_values .max (),
146
164
num = n_bins [1 ] + 1 ,
147
165
endpoint = True ,
148
166
)
149
- assert all (
150
- np .allclose (reconstructed_bins_limits_y , plotting_coords [:, j , 1 ])
151
- for j in range (n_bins [1 ])
152
- )
167
+ assert np .allclose (reconstructed_bins_limits_y , histogram_info ["yedges" ])
153
168
154
169
reconstructed_bin_counts = np .zeros (shape = n_bins , dtype = float )
155
170
for i , xi in enumerate (reconstructed_bins_limits_x [:- 1 ]):
156
171
xi_p1 = reconstructed_bins_limits_x [i + 1 ]
157
172
158
- x_pts_in_range = (x_values >= xi ) & (x_values <= xi_p1 )
173
+ x_pts_in_range = (plotted_x_values >= xi ) & (plotted_x_values <= xi_p1 )
159
174
for j , yj in enumerate (reconstructed_bins_limits_y [:- 1 ]):
160
175
yj_p1 = reconstructed_bins_limits_y [j + 1 ]
161
176
162
- y_pts_in_range = (y_values >= yj ) & (y_values <= yj_p1 )
177
+ y_pts_in_range = (plotted_y_values >= yj ) & (
178
+ plotted_y_values <= yj_p1
179
+ )
163
180
164
181
pts_in_this_bin = (x_pts_in_range & y_pts_in_range ).sum ()
165
182
reconstructed_bin_counts [i , j ] = pts_in_this_bin
166
183
167
184
if pts_in_this_bin != plotted_values [i , j ]:
168
185
pass
169
186
187
+ # We agree with a manual count
170
188
assert reconstructed_bin_counts .sum () == plotted_values .sum ()
189
+ # All non-NaN values were plotted
190
+ assert n_non_nan_values == plotted_values .sum ()
191
+ # The counts were actually correct
171
192
assert np .all (reconstructed_bin_counts == plotted_values )
0 commit comments