Skip to content

Commit 2c55f20

Browse files
committed
add more elaborate tests + corresponding logic in from_array
1 parent 315c7d8 commit 2c55f20

File tree

2 files changed

+68
-19
lines changed

2 files changed

+68
-19
lines changed

Diff for: src/cellmap_schemas/annotation.py

+28-16
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class SemanticSegmentation(BaseModel, extra="forbid"):
181181
182182
type: Literal["semantic_segmentation"]
183183
Must be the string 'semantic_segmentation'.
184-
encoding: dict[Union[Possibility, Literal["present"]], int]
184+
encoding: dict[SemanticPossibility, int]
185185
This dict represents the mapping from possibilities to numeric values. The keys
186186
must be strings in the set `{'unknown', 'absent', 'present'}`, and the values
187187
must be numeric values contained in the array described by this metadata.
@@ -280,24 +280,36 @@ def from_array(
280280
if complement_counts == "auto":
281281
complement_counts_parsed: dict[SemanticPossibility, int] | dict[
282282
InstancePossibility, int
283-
]
284-
num_unknown = (array == annotation_type.encoding["unknown"]).sum()
285-
num_absent = (array == annotation_type.encoding["absent"]).sum()
286-
num_present = array.size - (num_unknown + num_absent)
283+
] = {}
284+
285+
num_unknown: int = 0
286+
num_absent: int = 0
287+
num_present: int = 0
288+
289+
if "unknown" in annotation_type.encoding:
290+
num_unknown = (array == annotation_type.encoding["unknown"]).sum()
291+
292+
if "absent" in annotation_type.encoding:
293+
num_absent = (array == annotation_type.encoding["absent"]).sum()
294+
295+
if "present" in annotation_type.encoding:
296+
num_present = array.size - (num_unknown + num_absent)
287297

288298
if isinstance(annotation_type, SemanticSegmentation):
289-
complement_counts_parsed: dict[SemanticPossibility, int] = {
290-
"unknown": num_unknown,
291-
"absent": num_absent,
292-
"present": num_present,
293-
}
299+
if "unknown" in annotation_type.encoding:
300+
complement_counts_parsed["unknown"] = num_unknown
301+
if "absent" in annotation_type.encoding:
302+
complement_counts_parsed["absent"] = num_absent
303+
if "present" in annotation_type.encoding:
304+
complement_counts_parsed["present"] = num_present
305+
294306
elif isinstance(annotation_type, InstanceSegmentation):
295-
complement_counts_parsed: dict[InstancePossibility, int] = {
296-
"unknown": num_unknown,
297-
"absent": num_absent,
298-
}
299-
else:
300-
complement_counts_parsed = complement_counts # type: ignore
307+
if "unknown" in annotation_type.encoding:
308+
complement_counts_parsed["unknown"] = num_unknown
309+
if "absent" in annotation_type.encoding:
310+
complement_counts_parsed["absent"] = num_absent
311+
else:
312+
complement_counts_parsed = complement_counts # type: ignore
301313

302314
return cls(
303315
class_name=class_name,

Diff for: tests/test_annotation.py

+40-3
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,33 @@
99
AnnotationGroupAttrs,
1010
CropGroup,
1111
CropGroupAttrs,
12+
InstancePossibility,
13+
InstanceSegmentation,
14+
SemanticPossibility,
1215
SemanticSegmentation,
1316
wrap_attributes,
1417
)
1518
import numpy as np
1619
import zarr
1720

1821

19-
def test_cropgroup():
22+
@pytest.mark.parametrize(
23+
"annotation_cls, encoding",
24+
[
25+
(SemanticSegmentation, {"absent": 0}),
26+
(SemanticSegmentation, {"absent": 0, "present": 1}),
27+
(SemanticSegmentation, {"absent": 0, "present": 1, "unknown": 3}),
28+
(InstanceSegmentation, {"absent": 0}),
29+
(InstanceSegmentation, {"absent": 0, "unknown": 3}),
30+
],
31+
)
32+
def test_cropgroup(
33+
annotation_cls: type[SemanticSegmentation] | type[InstanceSegmentation],
34+
encoding: dict[SemanticPossibility, int] | dict[InstancePossibility, int],
35+
):
2036
ClassNamesT = Literal["foo", "bar"]
21-
class_names = ["foo", "bar"]
22-
ann_type = SemanticSegmentation(encoding={"absent": 0})
37+
class_names: list[Literal["foo", "bar"]] = ["foo", "bar"]
38+
ann_type = annotation_cls(encoding=encoding) # type: ignore
2339
arrays = [np.zeros(10) for class_name in class_names]
2440
crop_group_attrs = CropGroupAttrs[ClassNamesT](
2541
name="foo",
@@ -79,3 +95,24 @@ def test_cropgroup():
7995
)
8096
with pytest.raises(ValueError, match=match):
8197
CropGroup.from_zarr(stored)
98+
99+
100+
@pytest.mark.parametrize(
101+
"annotation_cls, encoding",
102+
[
103+
(SemanticSegmentation, {"absent": 0}),
104+
(SemanticSegmentation, {"absent": 0, "present": 1}),
105+
(SemanticSegmentation, {"absent": 0, "present": 1, "unknown": 3}),
106+
(InstanceSegmentation, {"absent": 0}),
107+
(InstanceSegmentation, {"absent": 0, "unknown": 3}),
108+
],
109+
)
110+
def test_annotation_attrs_from_array(annotation_cls, encoding) -> None:
111+
array = np.array([0, 1, 0, 3])
112+
attrs = AnnotationArrayAttrs.from_array(
113+
array,
114+
class_name="foo",
115+
annotation_type=annotation_cls(encoding=encoding),
116+
complement_counts="auto",
117+
)
118+
assert attrs.complement_counts["absent"] == 2

0 commit comments

Comments
 (0)