diff --git a/lite/tests/semantic_segmentation/test_annotation.py b/lite/tests/semantic_segmentation/test_annotation.py index 999dd5240..b6efb2c64 100644 --- a/lite/tests/semantic_segmentation/test_annotation.py +++ b/lite/tests/semantic_segmentation/test_annotation.py @@ -1,6 +1,10 @@ import numpy as np import pytest -from valor_lite.semantic_segmentation import Bitmask, Segmentation +from valor_lite.semantic_segmentation import ( + Bitmask, + Segmentation, + generate_segmentation, +) def test_bitmask(): @@ -78,3 +82,25 @@ def test_segmentation(): predictions=[], ) assert "missing predictions" in str(e) + + +def test_generate_segmentation(): + + segmentation = generate_segmentation( + datum_uid="uid1", + number_of_unique_labels=3, + mask_height=2, + mask_width=3, + ) + + assert segmentation.uid == "uid1" + assert segmentation.shape == (2, 3) + assert segmentation.size == 6 + + assert len(segmentation.groundtruths) == 3 + assert all(gt.mask.dtype == np.bool_ for gt in segmentation.groundtruths) + assert all(gt.mask.shape == (2, 3) for gt in segmentation.groundtruths) + + assert len(segmentation.predictions) + assert all(pd.mask.dtype == np.bool_ for pd in segmentation.predictions) + assert all(pd.mask.shape == (2, 3) for pd in segmentation.predictions)