Skip to content

Commit 1275eb7

Browse files
committed
migrating to bitstring rasters
1 parent 99dc75b commit 1275eb7

File tree

16 files changed

+412
-162
lines changed

16 files changed

+412
-162
lines changed

api/tests/functional-tests/backend/core/test_geometry.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
_convert_polygon_to_box,
1212
_convert_raster_to_box,
1313
_convert_raster_to_polygon,
14-
_raster_to_png_b64,
1514
convert_geometry,
1615
get_annotation_type,
1716
)
@@ -206,7 +205,7 @@ def test_convert_from_raster(
206205
and_(
207206
models.Annotation.box.is_(None),
208207
models.Annotation.polygon.is_(None),
209-
models.Annotation.raster.isnot(None),
208+
models.Annotation.bitmask_id.isnot(None),
210209
)
211210
)
212211
)
@@ -227,7 +226,7 @@ def test_convert_from_raster(
227226

228227
assert annotation.box is not None
229228
assert annotation.polygon is not None
230-
assert annotation.raster is not None
229+
assert annotation.bitmask_id is not None
231230

232231
converted_box = _load_box(db, annotation.box)
233232
converted_polygon = _load_polygon(db, annotation.polygon)
@@ -257,7 +256,7 @@ def test_convert_polygon_to_box(
257256
and_(
258257
models.Annotation.box.is_(None),
259258
models.Annotation.polygon.isnot(None),
260-
models.Annotation.raster.is_(None),
259+
models.Annotation.bitmask_id.is_(None),
261260
)
262261
)
263262
)
@@ -275,7 +274,7 @@ def test_convert_polygon_to_box(
275274

276275
assert annotation.box is not None
277276
assert annotation.polygon is not None
278-
assert annotation.raster is None
277+
assert annotation.bitmask_id is None
279278

280279
converted_box = _load_box(db, annotation.box)
281280

@@ -317,8 +316,8 @@ def test_create_raster_from_polygons(
317316

318317
# verify all rasters are equal
319318
raster_arrs = [
320-
Raster(mask=_raster_to_png_b64(db, r)).to_numpy()
321-
for r in db.scalars(select(models.Annotation.raster)).all()
319+
np.array([bit == "1" for bit in r.value]).reshape((r.height, r.width))
320+
for r in db.query(models.Bitmask).all()
322321
]
323322
assert len(raster_arrs) == 3
324323

api/tests/functional-tests/backend/metrics/test_segmentation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,10 @@ def test__count_true_positives(
199199

200200
expected = _help_count_true_positives(
201201
gt_semantic_segs_create,
202-
[pred_semantic_segs_img1_create, pred_semantic_segs_img2_create],
202+
[
203+
pred_semantic_segs_img1_create,
204+
pred_semantic_segs_img2_create,
205+
],
203206
schemas.Label(key=k, value=v),
204207
)
205208

api/tests/functional-tests/crud/test_create_delete.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
import pytest
7-
from geoalchemy2.functions import ST_AsText, ST_Count, ST_Polygon
7+
from geoalchemy2.functions import ST_AsText, ST_Polygon
88
from PIL import Image
99
from sqlalchemy import func, select
1010
from sqlalchemy.orm import Session
@@ -815,9 +815,12 @@ def test_create_predicted_segmentations_check_area_and_delete_model(
815815

816816
raster_counts = set(
817817
db.scalars(
818-
select(ST_Count(models.Annotation.raster)).where(
819-
models.Annotation.model_id.isnot(None)
818+
select(func.bit_count(models.Bitmask.value))
819+
.join(
820+
models.Annotation,
821+
models.Annotation.bitmask_id == models.Bitmask.id,
820822
)
823+
.where(models.Annotation.model_id.isnot(None))
821824
)
822825
)
823826

@@ -879,7 +882,9 @@ def test_segmentation_area_no_hole(
879882
],
880883
)
881884

882-
segmentation_count = db.scalar(select(ST_Count(models.Annotation.raster)))
885+
segmentation_count = db.scalar(
886+
select(func.bit_count(models.Bitmask.value))
887+
)
883888

884889
assert segmentation_count == math.ceil(45.5) # area of mask will be an int
885890

@@ -920,11 +925,9 @@ def test_segmentation_area_with_hole(
920925
],
921926
)
922927

923-
segmentation = db.scalar(select(models.Annotation))
924-
925928
# give tolerance of 2 pixels because of poly -> mask conversion
926-
assert segmentation
927-
assert (db.scalar(ST_Count(segmentation.raster)) - 92) <= 2
929+
assert db.scalar(select(models.Annotation))
930+
assert (db.scalar(func.bit_count(models.Bitmask.value)) - 92) <= 2
928931

929932

930933
def test_segmentation_area_multi_polygon(
@@ -966,13 +969,14 @@ def test_segmentation_area_multi_polygon(
966969
],
967970
)
968971

969-
segmentation = db.scalar(select(models.Annotation))
970-
971972
# the two shapes don't intersect so area should be sum of the areas
972973
# give tolerance of 2 pixels because of poly -> mask conversion
973-
assert segmentation
974+
assert db.scalar(select(models.Annotation))
974975
assert (
975-
abs(db.scalar(ST_Count(segmentation.raster)) - (math.ceil(45.5) + 92))
976+
abs(
977+
db.scalar(func.bit_count(models.Bitmask.value))
978+
- (math.ceil(45.5) + 92)
979+
)
976980
<= 2
977981
)
978982

@@ -1037,7 +1041,15 @@ def test_gt_seg_as_mask_or_polys(
10371041

10381042
shapes = db.scalars(
10391043
select(
1040-
ST_AsText(ST_Polygon(models.Annotation.raster)),
1044+
ST_AsText(
1045+
ST_Polygon(
1046+
func.bitstring_to_raster(
1047+
models.Bitmask.value,
1048+
models.Bitmask.height,
1049+
models.Bitmask.width,
1050+
)
1051+
)
1052+
),
10411053
)
10421054
).all()
10431055
assert len(shapes) == 2

api/valor_api/backend/core/annotation.py

Lines changed: 67 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from geoalchemy2.functions import (
23
ST_AddBand,
34
ST_AsGeoJSON,
@@ -6,13 +7,12 @@
67
ST_MakeEmptyRaster,
78
ST_MapAlgebra,
89
)
9-
from sqlalchemy import CTE, ScalarSelect, and_, delete, func, insert, select
10+
from sqlalchemy import CTE, and_, delete, func, insert, select
1011
from sqlalchemy.exc import IntegrityError
1112
from sqlalchemy.orm import Session
1213

1314
from valor_api import schemas
1415
from valor_api.backend import models
15-
from valor_api.backend.core.geometry import _raster_to_png_b64
1616
from valor_api.backend.query import generate_query
1717
from valor_api.enums import ModelStatus, TableStatus, TaskType
1818

@@ -59,25 +59,39 @@ def _format_polygon(polygon: schemas.Polygon | None) -> str | None:
5959
return polygon.to_wkt() if polygon else None
6060

6161

62-
def _format_raster(
63-
raster: schemas.Raster | None,
64-
) -> ScalarSelect | bytes | None:
65-
return raster.to_psql() if raster else None
62+
# def _format_raster(
63+
# raster: schemas.Raster | None,
64+
# ) -> ScalarSelect | bytes | None:
65+
# return raster.to_psql() if raster else None
6666

6767

68-
def _format_bitmask(
68+
def _create_bitmask(
6969
db: Session,
7070
raster: schemas.Raster | None,
71-
) -> str | None:
71+
) -> int | None:
7272
"""
73-
Converts a Raster schema into a bitmask.
73+
Creates a bitmask from a raster schema.
7474
"""
7575
if raster is None:
7676
return None
7777
elif raster and raster.geometry:
7878
r = _create_raster_from_multipolygon(raster)
79-
return db.scalar(func.raster_to_bitstring(r.c.raster))
80-
return "".join("1" if b else "0" for b in raster.array.flatten())
79+
bitstring = db.scalar(func.raster_to_bitstring(r.c.raster))
80+
else:
81+
bitstring = "".join("1" if b else "0" for b in raster.array.flatten())
82+
83+
try:
84+
row = models.Bitmask(
85+
value=bitstring,
86+
height=raster.height,
87+
width=raster.width,
88+
)
89+
db.add(row)
90+
db.commit()
91+
except IntegrityError as e:
92+
db.rollback()
93+
raise e
94+
return row.id
8195

8296

8397
def _create_embedding(
@@ -159,8 +173,7 @@ def create_annotations(
159173
"meta": annotation.metadata,
160174
"box": _format_box(annotation.bounding_box),
161175
"polygon": _format_polygon(annotation.polygon),
162-
"raster": _format_raster(annotation.raster),
163-
"bitmask": _format_bitmask(db, annotation.raster),
176+
"bitmask_id": _create_bitmask(db, annotation.raster),
164177
"embedding_id": _create_embedding(
165178
db=db, value=annotation.embedding
166179
),
@@ -222,8 +235,7 @@ def create_skipped_annotations(
222235
meta=dict(),
223236
box=None,
224237
polygon=None,
225-
raster=None,
226-
bitmask=None,
238+
bitmask_id=None,
227239
embedding_id=None,
228240
text=None,
229241
context_list=None,
@@ -305,16 +317,19 @@ def get_annotation(
305317
)
306318

307319
# raster
308-
if annotation.raster is not None:
309-
datum = db.scalar(
310-
select(models.Datum).where(models.Datum.id == annotation.datum_id)
320+
if annotation.bitmask_id is not None:
321+
bitmask = (
322+
db.query(models.Bitmask)
323+
.where(models.Bitmask.id == annotation.bitmask_id)
324+
.scalar()
311325
)
312-
if datum is None:
313-
raise RuntimeError(
314-
"psql unexpectedly returned None instead of a Datum."
315-
)
316-
raster = schemas.Raster(
317-
mask=_raster_to_png_b64(db=db, raster=annotation.raster),
326+
if bitmask is None:
327+
raise ValueError("Expected bitmask to contain a value.")
328+
raster = schemas.Raster.from_numpy(
329+
mask=np.array(
330+
[int(bit) == 1 for bit in bitmask.value],
331+
dtype=bool,
332+
).reshape(bitmask.height, bitmask.width)
318333
)
319334

320335
# embedding
@@ -381,6 +396,28 @@ def get_annotations(
381396
]
382397

383398

399+
def _delete_linked_annotations(db: Session):
400+
# delete embeddings that are no longer referenced
401+
existing_ids = select(models.Annotation.embedding_id).where(
402+
models.Annotation.embedding_id.isnot(None)
403+
)
404+
db.execute(
405+
delete(models.Embedding).where(
406+
models.Embedding.id.not_in(existing_ids)
407+
)
408+
)
409+
db.commit()
410+
411+
# delete bitmasks that are no longer referenced
412+
existing_ids = select(models.Annotation.bitmask_id).where(
413+
models.Annotation.bitmask_id.isnot(None)
414+
)
415+
db.execute(
416+
delete(models.Bitmask).where(models.Bitmask.id.not_in(existing_ids))
417+
)
418+
db.commit()
419+
420+
384421
def delete_dataset_annotations(
385422
db: Session,
386423
dataset: models.Dataset,
@@ -421,16 +458,9 @@ def delete_dataset_annotations(
421458
)
422459
db.commit()
423460

424-
# delete embeddings (if they exist)
425-
existing_ids = select(models.Annotation.embedding_id).where(
426-
models.Annotation.embedding_id.isnot(None)
427-
)
428-
db.execute(
429-
delete(models.Embedding).where(
430-
models.Embedding.id.not_in(existing_ids)
431-
)
432-
)
433-
db.commit()
461+
# delete linked annotations
462+
_delete_linked_annotations(db=db)
463+
434464
except IntegrityError as e:
435465
db.rollback()
436466
raise e
@@ -475,16 +505,9 @@ def delete_model_annotations(
475505
)
476506
db.commit()
477507

478-
# delete embeddings (if they exist)
479-
existing_ids = select(models.Annotation.embedding_id).where(
480-
models.Annotation.embedding_id.isnot(None)
481-
)
482-
db.execute(
483-
delete(models.Embedding).where(
484-
models.Embedding.id.not_in(existing_ids)
485-
)
486-
)
487-
db.commit()
508+
# delete linked annotations
509+
_delete_linked_annotations(db=db)
510+
488511
except IntegrityError as e:
489512
db.rollback()
490513
raise e

api/valor_api/backend/core/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def get_n_groundtruth_rasters_in_dataset(db: Session, name: str) -> int:
375375
and_(
376376
models.Dataset.name == name,
377377
models.Dataset.status != enums.TableStatus.DELETING,
378-
models.Annotation.raster.isnot(None),
378+
models.Annotation.bitmask_id.isnot(None),
379379
)
380380
)
381381
.distinct()

0 commit comments

Comments
 (0)