Commit f1546ee

Lite Shouldnt Return Numpy Types in metrics (#794)
1 parent 67776e9 commit f1546ee

File tree

6 files changed: +219 −108 lines changed


lite/tests/classification/test_evaluator.py

Lines changed: 38 additions & 0 deletions

@@ -1,3 +1,4 @@
+import numpy as np
 from valor_lite.classification import Classification, DataLoader


@@ -23,3 +24,40 @@ def test_metadata_using_classification_example(
         "ignored_prediction_labels": [],
         "missing_prediction_labels": [],
     }
+
+
+def _flatten_metrics(m) -> list:
+    if isinstance(m, dict):
+        keys = list(m.keys())
+        values = [
+            inner_value
+            for value in m.values()
+            for inner_value in _flatten_metrics(value)
+        ]
+        return keys + values
+    elif isinstance(m, list):
+        return [
+            inner_value
+            for value in m
+            for inner_value in _flatten_metrics(value)
+        ]
+    else:
+        return [m]
+
+
+def test_output_types_dont_contain_numpy(
+    basic_classifications: list[Classification],
+):
+    manager = DataLoader()
+    manager.add_data(basic_classifications)
+    evaluator = manager.finalize()
+
+    metrics = evaluator.evaluate(
+        score_thresholds=[0.25, 0.75],
+        as_dict=True,
+    )
+
+    values = _flatten_metrics(metrics)
+    for value in values:
+        if isinstance(value, (np.generic, np.ndarray)):
+            raise TypeError(value)
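For intuition, here is a minimal sketch of what `_flatten_metrics` produces; the payload below is a hypothetical stand-in for a real `as_dict=True` result, not actual evaluator output:

    import numpy as np

    # Hypothetical payload shaped like an `as_dict=True` result.
    metrics = {
        "Precision": [
            {"value": [np.float64(0.5)], "label": "dog"},
        ],
    }

    # _flatten_metrics recurses through dicts and lists, returning every
    # dict key and every leaf value in one flat list:
    #     ["Precision", "value", "label", np.float64(0.5), "dog"]
    # so the test only has to scan leaves for np.generic / np.ndarray.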

lite/tests/object_detection/test_evaluator.py

Lines changed: 36 additions & 0 deletions

@@ -1,3 +1,4 @@
+import numpy as np
 from valor_lite.object_detection import DataLoader, Detection, MetricType


@@ -86,3 +87,38 @@ def test_no_predictions(detections_no_predictions):
         assert m in expected_metrics
     for m in expected_metrics:
         assert m in actual_metrics
+
+
+def _flatten_metrics(m) -> list:
+    if isinstance(m, dict):
+        keys = list(m.keys())
+        values = [
+            inner_value
+            for value in m.values()
+            for inner_value in _flatten_metrics(value)
+        ]
+        return keys + values
+    elif isinstance(m, list):
+        return [
+            inner_value
+            for value in m
+            for inner_value in _flatten_metrics(value)
+        ]
+    else:
+        return [m]
+
+
+def test_output_types_dont_contain_numpy(basic_detections: list[Detection]):
+    manager = DataLoader()
+    manager.add_bounding_boxes(basic_detections)
+    evaluator = manager.finalize()
+
+    metrics = evaluator.evaluate(
+        score_thresholds=[0.25, 0.75],
+        as_dict=True,
+    )
+
+    values = _flatten_metrics(metrics)
+    for value in values:
+        if isinstance(value, (np.generic, np.ndarray)):
+            raise TypeError
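The guard matters because leaked numpy scalars break downstream serialization; `json.dumps`, for instance, rejects numpy integer scalars outright. A quick illustration:

    import json

    import numpy as np

    json.dumps({"tp": 3})            # fine: '{"tp": 3}'
    json.dumps({"tp": np.int64(3)})  # raises TypeError: Object of type
                                     # int64 is not JSON serializable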

lite/tests/semantic_segmentation/test_evaluator.py

Lines changed: 37 additions & 0 deletions

@@ -1,3 +1,4 @@
+import numpy as np
 from valor_lite.semantic_segmentation import DataLoader, Segmentation


@@ -27,3 +28,39 @@ def test_metadata_using_large_random_segmentations(
         "number_of_groundtruth_pixels": 36000000,
         "number_of_prediction_pixels": 36000000,
     }
+
+
+def _flatten_metrics(m) -> list:
+    if isinstance(m, dict):
+        keys = list(m.keys())
+        values = [
+            inner_value
+            for value in m.values()
+            for inner_value in _flatten_metrics(value)
+        ]
+        return keys + values
+    elif isinstance(m, list):
+        return [
+            inner_value
+            for value in m
+            for inner_value in _flatten_metrics(value)
+        ]
+    else:
+        return [m]
+
+
+def test_output_types_dont_contain_numpy(
+    segmentations_from_boxes: list[Segmentation],
+):
+    manager = DataLoader()
+    manager.add_data(segmentations_from_boxes)
+    evaluator = manager.finalize()
+
+    metrics = evaluator.evaluate(
+        as_dict=True,
+    )
+
+    values = _flatten_metrics(metrics)
+    for value in values:
+        if isinstance(value, (np.generic, np.ndarray)):
+            raise TypeError(value)
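A subtlety the check handles correctly: `np.float64` actually subclasses Python's `float`, so a naive `isinstance(value, float)` test would not catch it; checking against `np.generic` does. For example:

    import numpy as np

    x = np.float64(0.5)
    isinstance(x, float)       # True -- np.float64 subclasses Python float
    isinstance(x, np.generic)  # True -- which is what the tests reject
    type(float(x))             # <class 'float'> -- the plain built-in type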

lite/valor_lite/classification/manager.py

Lines changed: 90 additions & 90 deletions

@@ -111,6 +111,86 @@ def metadata(self) -> dict:
             "missing_prediction_labels": self.missing_prediction_labels,
         }

+    def create_filter(
+        self,
+        datum_uids: list[str] | NDArray[np.int32] | None = None,
+        labels: list[str] | NDArray[np.int32] | None = None,
+    ) -> Filter:
+        """
+        Creates a boolean mask that can be passed to an evaluation.
+
+        Parameters
+        ----------
+        datum_uids : list[str] | NDArray[np.int32], optional
+            An optional list of string uids or a numpy array of uid indices.
+        labels : list[str] | NDArray[np.int32], optional
+            An optional list of labels or a numpy array of label indices.
+
+        Returns
+        -------
+        Filter
+            A filter object that can be passed to the `evaluate` method.
+        """
+        n_rows = self._detailed_pairs.shape[0]
+
+        n_datums = self._label_metadata_per_datum.shape[1]
+        n_labels = self._label_metadata_per_datum.shape[2]
+
+        mask_pairs = np.ones((n_rows, 1), dtype=np.bool_)
+        mask_datums = np.ones(n_datums, dtype=np.bool_)
+        mask_labels = np.ones(n_labels, dtype=np.bool_)
+
+        if datum_uids is not None:
+            if isinstance(datum_uids, list):
+                datum_uids = np.array(
+                    [self.uid_to_index[uid] for uid in datum_uids],
+                    dtype=np.int32,
+                )
+            mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+            mask[
+                np.isin(self._detailed_pairs[:, 0].astype(int), datum_uids)
+            ] = True
+            mask_pairs &= mask
+
+            mask = np.zeros_like(mask_datums, dtype=np.bool_)
+            mask[datum_uids] = True
+            mask_datums &= mask
+
+        if labels is not None:
+            if isinstance(labels, list):
+                labels = np.array(
+                    [self.label_to_index[label] for label in labels]
+                )
+            mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+            mask[
+                np.isin(self._detailed_pairs[:, 1].astype(int), labels)
+            ] = True
+            mask_pairs &= mask
+
+            mask = np.zeros_like(mask_labels, dtype=np.bool_)
+            mask[labels] = True
+            mask_labels &= mask
+
+        mask = mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
+        label_metadata_per_datum = self._label_metadata_per_datum.copy()
+        label_metadata_per_datum[:, ~mask] = 0
+
+        label_metadata = np.zeros_like(self._label_metadata, dtype=np.int32)
+        label_metadata = np.transpose(
+            np.sum(
+                label_metadata_per_datum,
+                axis=1,
+            )
+        )
+
+        n_datums = int(np.sum(label_metadata[:, 0]))
+
+        return Filter(
+            indices=np.where(mask_pairs)[0],
+            label_metadata=label_metadata,
+            n_datums=n_datums,
+        )
+
     def _unpack_confusion_matrix(
         self,
         confusion_matrix: NDArray[np.float64],
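The body of `create_filter` is unchanged by this commit; it only moves above the private helpers (the matching removal follows below). A minimal usage sketch, assuming `evaluate` accepts the returned object through a `filter_` keyword (the parameter name is not shown in this diff):

    # Hypothetical usage of create_filter; the `filter_` keyword is assumed.
    subset = evaluator.create_filter(datum_uids=["uid0", "uid1"])
    metrics = evaluator.evaluate(filter_=subset, as_dict=True)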
@@ -218,86 +298,6 @@ def _unpack_missing_predictions(
                 for gt_label_idx in range(number_of_labels)
             }

-    def create_filter(
-        self,
-        datum_uids: list[str] | NDArray[np.int32] | None = None,
-        labels: list[str] | NDArray[np.int32] | None = None,
-    ) -> Filter:
-        """
-        Creates a boolean mask that can be passed to an evaluation.
-
-        Parameters
-        ----------
-        datum_uids : list[str] | NDArray[np.int32], optional
-            An optional list of string uids or a numpy array of uid indices.
-        labels : list[str] | NDArray[np.int32], optional
-            An optional list of labels or a numpy array of label indices.
-
-        Returns
-        -------
-        Filter
-            A filter object that can be passed to the `evaluate` method.
-        """
-        n_rows = self._detailed_pairs.shape[0]
-
-        n_datums = self._label_metadata_per_datum.shape[1]
-        n_labels = self._label_metadata_per_datum.shape[2]
-
-        mask_pairs = np.ones((n_rows, 1), dtype=np.bool_)
-        mask_datums = np.ones(n_datums, dtype=np.bool_)
-        mask_labels = np.ones(n_labels, dtype=np.bool_)
-
-        if datum_uids is not None:
-            if isinstance(datum_uids, list):
-                datum_uids = np.array(
-                    [self.uid_to_index[uid] for uid in datum_uids],
-                    dtype=np.int32,
-                )
-            mask = np.zeros_like(mask_pairs, dtype=np.bool_)
-            mask[
-                np.isin(self._detailed_pairs[:, 0].astype(int), datum_uids)
-            ] = True
-            mask_pairs &= mask
-
-            mask = np.zeros_like(mask_datums, dtype=np.bool_)
-            mask[datum_uids] = True
-            mask_datums &= mask
-
-        if labels is not None:
-            if isinstance(labels, list):
-                labels = np.array(
-                    [self.label_to_index[label] for label in labels]
-                )
-            mask = np.zeros_like(mask_pairs, dtype=np.bool_)
-            mask[
-                np.isin(self._detailed_pairs[:, 1].astype(int), labels)
-            ] = True
-            mask_pairs &= mask
-
-            mask = np.zeros_like(mask_labels, dtype=np.bool_)
-            mask[labels] = True
-            mask_labels &= mask
-
-        mask = mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
-        label_metadata_per_datum = self._label_metadata_per_datum.copy()
-        label_metadata_per_datum[:, ~mask] = 0
-
-        label_metadata = np.zeros_like(self._label_metadata, dtype=np.int32)
-        label_metadata = np.transpose(
-            np.sum(
-                label_metadata_per_datum,
-                axis=1,
-            )
-        )
-
-        n_datums = int(np.sum(label_metadata[:, 0]))
-
-        return Filter(
-            indices=np.where(mask_pairs)[0],
-            label_metadata=label_metadata,
-            n_datums=n_datums,
-        )
-
     def compute_precision_recall(
         self,
         score_thresholds: list[float] = [0.0],
@@ -354,7 +354,7 @@ def compute_precision_recall(

         metrics[MetricType.ROCAUC] = [
             ROCAUC(
-                value=rocauc[label_idx],
+                value=float(rocauc[label_idx]),
                 label=self.index_to_label[label_idx],
             )
             for label_idx in range(label_metadata.shape[0])

@@ -363,7 +363,7 @@ def compute_precision_recall(

         metrics[MetricType.mROCAUC] = [
             mROCAUC(
-                value=mean_rocauc,
+                value=float(mean_rocauc),
             )
         ]
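These `float(...)` casts are the core of the fix: indexing a numpy array returns a numpy scalar, not a built-in number, so without the cast the metric objects would carry `np.float64` values. A small illustration:

    import numpy as np

    rocauc = np.array([0.9, 0.8])
    type(rocauc[0])         # <class 'numpy.float64'> -- a np.generic subclass
    type(float(rocauc[0]))  # <class 'float'> -- the plain Python type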

@@ -377,10 +377,10 @@ def compute_precision_recall(
             row = counts[:, label_idx]
             metrics[MetricType.Counts].append(
                 Counts(
-                    tp=row[:, 0].tolist(),
-                    fp=row[:, 1].tolist(),
-                    fn=row[:, 2].tolist(),
-                    tn=row[:, 3].tolist(),
+                    tp=row[:, 0].astype(int).tolist(),
+                    fp=row[:, 1].astype(int).tolist(),
+                    fn=row[:, 2].astype(int).tolist(),
+                    tn=row[:, 3].astype(int).tolist(),
                     **kwargs,
                 )
            )
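Note that `.tolist()` already converts numpy scalars to built-in types, but it preserves the array's kind: on a float-typed array it yields Python floats like `1.0`. The added `.astype(int)` makes the TP/FP/FN/TN counts come out as Python ints, presumably because `counts` is held in a float array. For example:

    import numpy as np

    row = np.array([[3.0, 1.0, 0.0, 5.0]])  # counts held in a float array
    row[:, 0].tolist()                      # [3.0] -- Python floats
    row[:, 0].astype(int).tolist()          # [3]   -- Python ints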
@@ -391,25 +391,25 @@ def compute_precision_recall(

            metrics[MetricType.Precision].append(
                Precision(
-                    value=precision[:, label_idx].tolist(),
+                    value=precision[:, label_idx].astype(float).tolist(),
                    **kwargs,
                )
            )
            metrics[MetricType.Recall].append(
                Recall(
-                    value=recall[:, label_idx].tolist(),
+                    value=recall[:, label_idx].astype(float).tolist(),
                    **kwargs,
                )
            )
            metrics[MetricType.Accuracy].append(
                Accuracy(
-                    value=accuracy[:, label_idx].tolist(),
+                    value=accuracy[:, label_idx].astype(float).tolist(),
                    **kwargs,
                )
            )
            metrics[MetricType.F1].append(
                F1(
-                    value=f1_score[:, label_idx].tolist(),
+                    value=f1_score[:, label_idx].astype(float).tolist(),
                    **kwargs,
                )
            )
