diff --git a/integration_tests/client/models/test_model.py b/integration_tests/client/models/test_model.py index 28d1514e0..91710e18b 100644 --- a/integration_tests/client/models/test_model.py +++ b/integration_tests/client/models/test_model.py @@ -192,6 +192,12 @@ def test_create_image_model_with_predicted_detections( assert fx_points in db_point_lists +def _convert_bitstring_to_numpy(bitmask: models.Bitmask): + return np.array([bit == "1" for bit in bitmask.value]).reshape( + (bitmask.height, bitmask.width) + ) + + def test_create_model_with_predicted_segmentations( db: Session, client: Client, @@ -217,17 +223,21 @@ def test_create_model_with_predicted_segmentations( # grab the segmentation from the db, recover the mask, and check # its equal to the mask the client sent over db_annotations = ( - db.query(models.Annotation) + db.query(models.Bitmask) + .join( + models.Annotation, + models.Annotation.bitmask_id == models.Bitmask.id, + ) .where(models.Annotation.model_id.isnot(None)) .all() ) if db_annotations[0].datum_id < db_annotations[1].datum_id: - raster_uid1 = db_annotations[0].raster - raster_uid2 = db_annotations[1].raster + raster_uid1 = _convert_bitstring_to_numpy(db_annotations[0]) + raster_uid2 = _convert_bitstring_to_numpy(db_annotations[1]) else: - raster_uid1 = db_annotations[1].raster - raster_uid2 = db_annotations[0].raster + raster_uid1 = _convert_bitstring_to_numpy(db_annotations[1]) + raster_uid2 = _convert_bitstring_to_numpy(db_annotations[0]) # test raster 1 png_from_db = db.scalar(ST_AsPNG(raster_uid1))