@@ -111,6 +111,86 @@ def metadata(self) -> dict:
111
111
"missing_prediction_labels" : self .missing_prediction_labels ,
112
112
}
113
113
114
+ def create_filter (
115
+ self ,
116
+ datum_uids : list [str ] | NDArray [np .int32 ] | None = None ,
117
+ labels : list [str ] | NDArray [np .int32 ] | None = None ,
118
+ ) -> Filter :
119
+ """
120
+ Creates a boolean mask that can be passed to an evaluation.
121
+
122
+ Parameters
123
+ ----------
124
+ datum_uids : list[str] | NDArray[np.int32], optional
125
+ An optional list of string uids or a numpy array of uid indices.
126
+ labels : list[str] | NDArray[np.int32], optional
127
+ An optional list of labels or a numpy array of label indices.
128
+
129
+ Returns
130
+ -------
131
+ Filter
132
+ A filter object that can be passed to the `evaluate` method.
133
+ """
134
+ n_rows = self ._detailed_pairs .shape [0 ]
135
+
136
+ n_datums = self ._label_metadata_per_datum .shape [1 ]
137
+ n_labels = self ._label_metadata_per_datum .shape [2 ]
138
+
139
+ mask_pairs = np .ones ((n_rows , 1 ), dtype = np .bool_ )
140
+ mask_datums = np .ones (n_datums , dtype = np .bool_ )
141
+ mask_labels = np .ones (n_labels , dtype = np .bool_ )
142
+
143
+ if datum_uids is not None :
144
+ if isinstance (datum_uids , list ):
145
+ datum_uids = np .array (
146
+ [self .uid_to_index [uid ] for uid in datum_uids ],
147
+ dtype = np .int32 ,
148
+ )
149
+ mask = np .zeros_like (mask_pairs , dtype = np .bool_ )
150
+ mask [
151
+ np .isin (self ._detailed_pairs [:, 0 ].astype (int ), datum_uids )
152
+ ] = True
153
+ mask_pairs &= mask
154
+
155
+ mask = np .zeros_like (mask_datums , dtype = np .bool_ )
156
+ mask [datum_uids ] = True
157
+ mask_datums &= mask
158
+
159
+ if labels is not None :
160
+ if isinstance (labels , list ):
161
+ labels = np .array (
162
+ [self .label_to_index [label ] for label in labels ]
163
+ )
164
+ mask = np .zeros_like (mask_pairs , dtype = np .bool_ )
165
+ mask [
166
+ np .isin (self ._detailed_pairs [:, 1 ].astype (int ), labels )
167
+ ] = True
168
+ mask_pairs &= mask
169
+
170
+ mask = np .zeros_like (mask_labels , dtype = np .bool_ )
171
+ mask [labels ] = True
172
+ mask_labels &= mask
173
+
174
+ mask = mask_datums [:, np .newaxis ] & mask_labels [np .newaxis , :]
175
+ label_metadata_per_datum = self ._label_metadata_per_datum .copy ()
176
+ label_metadata_per_datum [:, ~ mask ] = 0
177
+
178
+ label_metadata = np .zeros_like (self ._label_metadata , dtype = np .int32 )
179
+ label_metadata = np .transpose (
180
+ np .sum (
181
+ label_metadata_per_datum ,
182
+ axis = 1 ,
183
+ )
184
+ )
185
+
186
+ n_datums = int (np .sum (label_metadata [:, 0 ]))
187
+
188
+ return Filter (
189
+ indices = np .where (mask_pairs )[0 ],
190
+ label_metadata = label_metadata ,
191
+ n_datums = n_datums ,
192
+ )
193
+
114
194
def _unpack_confusion_matrix (
115
195
self ,
116
196
confusion_matrix : NDArray [np .float64 ],
@@ -218,86 +298,6 @@ def _unpack_missing_predictions(
218
298
for gt_label_idx in range (number_of_labels )
219
299
}
220
300
221
- def create_filter (
222
- self ,
223
- datum_uids : list [str ] | NDArray [np .int32 ] | None = None ,
224
- labels : list [str ] | NDArray [np .int32 ] | None = None ,
225
- ) -> Filter :
226
- """
227
- Creates a boolean mask that can be passed to an evaluation.
228
-
229
- Parameters
230
- ----------
231
- datum_uids : list[str] | NDArray[np.int32], optional
232
- An optional list of string uids or a numpy array of uid indices.
233
- labels : list[str] | NDArray[np.int32], optional
234
- An optional list of labels or a numpy array of label indices.
235
-
236
- Returns
237
- -------
238
- Filter
239
- A filter object that can be passed to the `evaluate` method.
240
- """
241
- n_rows = self ._detailed_pairs .shape [0 ]
242
-
243
- n_datums = self ._label_metadata_per_datum .shape [1 ]
244
- n_labels = self ._label_metadata_per_datum .shape [2 ]
245
-
246
- mask_pairs = np .ones ((n_rows , 1 ), dtype = np .bool_ )
247
- mask_datums = np .ones (n_datums , dtype = np .bool_ )
248
- mask_labels = np .ones (n_labels , dtype = np .bool_ )
249
-
250
- if datum_uids is not None :
251
- if isinstance (datum_uids , list ):
252
- datum_uids = np .array (
253
- [self .uid_to_index [uid ] for uid in datum_uids ],
254
- dtype = np .int32 ,
255
- )
256
- mask = np .zeros_like (mask_pairs , dtype = np .bool_ )
257
- mask [
258
- np .isin (self ._detailed_pairs [:, 0 ].astype (int ), datum_uids )
259
- ] = True
260
- mask_pairs &= mask
261
-
262
- mask = np .zeros_like (mask_datums , dtype = np .bool_ )
263
- mask [datum_uids ] = True
264
- mask_datums &= mask
265
-
266
- if labels is not None :
267
- if isinstance (labels , list ):
268
- labels = np .array (
269
- [self .label_to_index [label ] for label in labels ]
270
- )
271
- mask = np .zeros_like (mask_pairs , dtype = np .bool_ )
272
- mask [
273
- np .isin (self ._detailed_pairs [:, 1 ].astype (int ), labels )
274
- ] = True
275
- mask_pairs &= mask
276
-
277
- mask = np .zeros_like (mask_labels , dtype = np .bool_ )
278
- mask [labels ] = True
279
- mask_labels &= mask
280
-
281
- mask = mask_datums [:, np .newaxis ] & mask_labels [np .newaxis , :]
282
- label_metadata_per_datum = self ._label_metadata_per_datum .copy ()
283
- label_metadata_per_datum [:, ~ mask ] = 0
284
-
285
- label_metadata = np .zeros_like (self ._label_metadata , dtype = np .int32 )
286
- label_metadata = np .transpose (
287
- np .sum (
288
- label_metadata_per_datum ,
289
- axis = 1 ,
290
- )
291
- )
292
-
293
- n_datums = int (np .sum (label_metadata [:, 0 ]))
294
-
295
- return Filter (
296
- indices = np .where (mask_pairs )[0 ],
297
- label_metadata = label_metadata ,
298
- n_datums = n_datums ,
299
- )
300
-
301
301
def compute_precision_recall (
302
302
self ,
303
303
score_thresholds : list [float ] = [0.0 ],
@@ -354,7 +354,7 @@ def compute_precision_recall(
354
354
355
355
metrics [MetricType .ROCAUC ] = [
356
356
ROCAUC (
357
- value = rocauc [label_idx ],
357
+ value = float ( rocauc [label_idx ]) ,
358
358
label = self .index_to_label [label_idx ],
359
359
)
360
360
for label_idx in range (label_metadata .shape [0 ])
@@ -363,7 +363,7 @@ def compute_precision_recall(
363
363
364
364
metrics [MetricType .mROCAUC ] = [
365
365
mROCAUC (
366
- value = mean_rocauc ,
366
+ value = float ( mean_rocauc ) ,
367
367
)
368
368
]
369
369
@@ -377,10 +377,10 @@ def compute_precision_recall(
377
377
row = counts [:, label_idx ]
378
378
metrics [MetricType .Counts ].append (
379
379
Counts (
380
- tp = row [:, 0 ].tolist (),
381
- fp = row [:, 1 ].tolist (),
382
- fn = row [:, 2 ].tolist (),
383
- tn = row [:, 3 ].tolist (),
380
+ tp = row [:, 0 ].astype ( int ). tolist (),
381
+ fp = row [:, 1 ].astype ( int ). tolist (),
382
+ fn = row [:, 2 ].astype ( int ). tolist (),
383
+ tn = row [:, 3 ].astype ( int ). tolist (),
384
384
** kwargs ,
385
385
)
386
386
)
@@ -391,25 +391,25 @@ def compute_precision_recall(
391
391
392
392
metrics [MetricType .Precision ].append (
393
393
Precision (
394
- value = precision [:, label_idx ].tolist (),
394
+ value = precision [:, label_idx ].astype ( float ). tolist (),
395
395
** kwargs ,
396
396
)
397
397
)
398
398
metrics [MetricType .Recall ].append (
399
399
Recall (
400
- value = recall [:, label_idx ].tolist (),
400
+ value = recall [:, label_idx ].astype ( float ). tolist (),
401
401
** kwargs ,
402
402
)
403
403
)
404
404
metrics [MetricType .Accuracy ].append (
405
405
Accuracy (
406
- value = accuracy [:, label_idx ].tolist (),
406
+ value = accuracy [:, label_idx ].astype ( float ). tolist (),
407
407
** kwargs ,
408
408
)
409
409
)
410
410
metrics [MetricType .F1 ].append (
411
411
F1 (
412
- value = f1_score [:, label_idx ].tolist (),
412
+ value = f1_score [:, label_idx ].astype ( float ). tolist (),
413
413
** kwargs ,
414
414
)
415
415
)
0 commit comments