@@ -88,50 +88,115 @@ def histogram_data_with_nans(
88
88
return data_with_nans
89
89
90
90
91
- # def test_histogram_ignores_missing_dims(
92
- # input_does_not_have_dimensions: list[str],
93
- # ) -> None:
94
- # """Test that ``occupancy_histogram`` ignores non-present dimensions."""
95
- # input_data = 0
91
@pytest.fixture
def entirely_nan_data(histogram_data: xr.DataArray) -> xr.DataArray:
    """Return a copy of ``histogram_data`` in which every value is NaN.

    The returned array is built from ``histogram_data.copy`` with its
    values swapped for an all-NaN array of the same shape, so the input
    fixture itself is left unmodified.
    """
    # Any number multiplied by NaN is NaN, so this yields an array with
    # the same shape as the original values but filled with NaN.
    all_nan_values = np.multiply(histogram_data.values, np.nan)
    return histogram_data.copy(deep=True, data=all_nan_values)
96
96
97
97
98
98
@pytest .mark .parametrize (
99
- ["data" , "individual" , "keypoint" , "n_bins" ],
99
+ [
100
+ "data" ,
101
+ "remove_dims_from_data_before_starting" ,
102
+ "individual" ,
103
+ "keypoint" ,
104
+ "n_bins" ,
105
+ ],
100
106
[
101
107
pytest .param (
102
108
"histogram_data" ,
109
+ [],
103
110
"i0" ,
104
111
"k0" ,
105
112
30 ,
106
113
id = "30 bins each axis" ,
107
114
),
108
115
pytest .param (
109
116
"histogram_data" ,
117
+ [],
110
118
"i1" ,
111
119
"k0" ,
112
120
(20 , 30 ),
113
121
id = "(20, 30) bins" ,
114
122
),
115
123
pytest .param (
116
124
"histogram_data_with_nans" ,
125
+ [],
117
126
"i0" ,
118
127
"k0" ,
119
128
30 ,
120
129
id = "NaNs should be removed" ,
121
130
),
131
+ pytest .param (
132
+ "entirely_nan_data" ,
133
+ [],
134
+ "i0" ,
135
+ "k0" ,
136
+ 10 ,
137
+ id = "All NaN-data" ,
138
+ ),
139
+ pytest .param (
140
+ "histogram_data" ,
141
+ ["individuals" ],
142
+ "i0" ,
143
+ "k0" ,
144
+ 30 ,
145
+ id = "Ignores individual if not a dimension" ,
146
+ ),
147
+ pytest .param (
148
+ "histogram_data" ,
149
+ ["keypoints" ],
150
+ "i0" ,
151
+ "k1" ,
152
+ 30 ,
153
+ id = "Ignores keypoint if not a dimension" ,
154
+ ),
155
+ pytest .param (
156
+ "histogram_data" ,
157
+ ["individuals" , "keypoints" ],
158
+ "i0" ,
159
+ "k0" ,
160
+ 30 ,
161
+ id = "Can handle raw xy data" ,
162
+ ),
122
163
],
123
164
)
124
165
def test_occupancy_histogram (
125
166
data : xr .DataArray ,
167
+ remove_dims_from_data_before_starting : list [str ],
126
168
individual : int | str ,
127
169
keypoint : int | str ,
128
170
n_bins : int | tuple [int , int ],
129
171
request ,
130
172
) -> None :
131
- """Test that occupancy histograms correctly plot data."""
173
+ """Test that occupancy histograms correctly plot data.
174
+
175
+ Specifically, check that:
176
+ - The bin edges are what we expect.
177
+ - The bin counts can be manually verified and are in agreement.
178
+ - Only non-NaN values are plotted, but NaN values do not throw errors.
179
+ """
132
180
if isinstance (data , str ):
133
181
data = request .getfixturevalue (data )
134
182
183
+ # We will need to only select the xy data later in the test,
184
+ # but if we are dropping dimensions we might need to call it
185
+ # in different ways.
186
+ kwargs_to_select_xy_data = {
187
+ "individuals" : individual ,
188
+ "keypoints" : keypoint ,
189
+ }
190
+ for d in remove_dims_from_data_before_starting :
191
+ # Retain the 0th value in the corresponding dimension,
192
+ # then drop that dimension.
193
+ data = data .sel ({d : getattr (data , d )[0 ]}).squeeze ()
194
+ assert d not in data .dims
195
+
196
+ # We no longer need to filter this dimension out
197
+ # when examining the xy data later in the test.
198
+ kwargs_to_select_xy_data .pop (d , None )
199
+
135
200
_ , histogram_info = occupancy_histogram (
136
201
data , individual = individual , keypoint = keypoint , bins = n_bins
137
202
)
@@ -143,50 +208,60 @@ def test_occupancy_histogram(
143
208
assert plotted_values .shape == n_bins
144
209
145
210
# Confirm that each bin has the correct number of assignments
146
- data_time_xy = data .sel (individuals = individual , keypoints = keypoint )
211
+ data_time_xy = data .sel (** kwargs_to_select_xy_data )
147
212
data_time_xy = data_time_xy .dropna (dim = "time" , how = "any" )
148
213
plotted_x_values = data_time_xy .sel (space = "x" ).values
149
214
plotted_y_values = data_time_xy .sel (space = "y" ).values
150
215
assert plotted_x_values .shape == plotted_y_values .shape
151
216
# This many non-NaN values were plotted
152
217
n_non_nan_values = plotted_x_values .shape [0 ]
153
218
154
- reconstructed_bins_limits_x = np .linspace (
155
- plotted_x_values .min (),
156
- plotted_x_values .max (),
157
- num = n_bins [0 ] + 1 ,
158
- endpoint = True ,
159
- )
160
- assert np .allclose (reconstructed_bins_limits_x , histogram_info ["xedges" ])
161
- reconstructed_bins_limits_y = np .linspace (
162
- plotted_y_values .min (),
163
- plotted_y_values .max (),
164
- num = n_bins [1 ] + 1 ,
165
- endpoint = True ,
166
- )
167
- assert np .allclose (reconstructed_bins_limits_y , histogram_info ["yedges" ])
219
+ if n_non_nan_values > 0 :
220
+ reconstructed_bins_limits_x = np .linspace (
221
+ plotted_x_values .min (),
222
+ plotted_x_values .max (),
223
+ num = n_bins [0 ] + 1 ,
224
+ endpoint = True ,
225
+ )
226
+ assert np .allclose (
227
+ reconstructed_bins_limits_x , histogram_info ["xedges" ]
228
+ )
229
+ reconstructed_bins_limits_y = np .linspace (
230
+ plotted_y_values .min (),
231
+ plotted_y_values .max (),
232
+ num = n_bins [1 ] + 1 ,
233
+ endpoint = True ,
234
+ )
235
+ assert np .allclose (
236
+ reconstructed_bins_limits_y , histogram_info ["yedges" ]
237
+ )
168
238
169
- reconstructed_bin_counts = np .zeros (shape = n_bins , dtype = float )
170
- for i , xi in enumerate (reconstructed_bins_limits_x [:- 1 ]):
171
- xi_p1 = reconstructed_bins_limits_x [i + 1 ]
239
+ reconstructed_bin_counts = np .zeros (shape = n_bins , dtype = float )
240
+ for i , xi in enumerate (reconstructed_bins_limits_x [:- 1 ]):
241
+ xi_p1 = reconstructed_bins_limits_x [i + 1 ]
172
242
173
- x_pts_in_range = (plotted_x_values >= xi ) & (plotted_x_values <= xi_p1 )
174
- for j , yj in enumerate (reconstructed_bins_limits_y [:- 1 ]):
175
- yj_p1 = reconstructed_bins_limits_y [j + 1 ]
176
-
177
- y_pts_in_range = (plotted_y_values >= yj ) & (
178
- plotted_y_values <= yj_p1
243
+ x_pts_in_range = (plotted_x_values >= xi ) & (
244
+ plotted_x_values <= xi_p1
179
245
)
246
+ for j , yj in enumerate (reconstructed_bins_limits_y [:- 1 ]):
247
+ yj_p1 = reconstructed_bins_limits_y [j + 1 ]
248
+
249
+ y_pts_in_range = (plotted_y_values >= yj ) & (
250
+ plotted_y_values <= yj_p1
251
+ )
180
252
181
- pts_in_this_bin = (x_pts_in_range & y_pts_in_range ).sum ()
182
- reconstructed_bin_counts [i , j ] = pts_in_this_bin
253
+ pts_in_this_bin = (x_pts_in_range & y_pts_in_range ).sum ()
254
+ reconstructed_bin_counts [i , j ] = pts_in_this_bin
183
255
184
- if pts_in_this_bin != plotted_values [i , j ]:
185
- pass
256
+ if pts_in_this_bin != plotted_values [i , j ]:
257
+ pass
186
258
187
- # We agree with a manual count
188
- assert reconstructed_bin_counts .sum () == plotted_values .sum ()
189
- # All non-NaN values were plotted
190
- assert n_non_nan_values == plotted_values .sum ()
191
- # The counts were actually correct
192
- assert np .all (reconstructed_bin_counts == plotted_values )
259
+ # We agree with a manual count
260
+ assert reconstructed_bin_counts .sum () == plotted_values .sum ()
261
+ # All non-NaN values were plotted
262
+ assert n_non_nan_values == plotted_values .sum ()
263
+ # The counts were actually correct
264
+ assert np .all (reconstructed_bin_counts == plotted_values )
265
+ else :
266
+ # No non-NaN values were given
267
+ assert plotted_values .sum () == 0
0 commit comments