Skip to content

Commit f08576d

Browse files
committed
adjust tests to compare image sizes
1 parent ef9821f commit f08576d

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

tests/test_output.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from cellpose import io, metrics, utils, models
22
import pytest
33
from subprocess import check_output, STDOUT
4+
from pathlib import Path
45
import os
56
import numpy as np
67

@@ -43,7 +44,9 @@ def clear_output(data_dir, image_names):
4344
(True, True, 40),
4445
(True, True, None),
4546
(False, True, None),
46-
(False, False, None)
47+
(False, False, None),
48+
(True, False, None),
49+
(True, False, 40)
4750
]
4851
)
4952
def test_class_2D_one_img(data_dir, image_names, cellposemodel_fixture_24layer, compute_masks, resample, diameter):
@@ -56,11 +59,21 @@ def test_class_2D_one_img(data_dir, image_names, cellposemodel_fixture_24layer,
5659

5760
masks_pred, _, _ = cellposemodel_fixture_24layer.eval(img, normalize=True, compute_masks=compute_masks, resample=resample, diameter=diameter)
5861

59-
if not compute_masks or diameter:
62+
if not compute_masks:
6063
# not compute_masks won't return masks so can't check
61-
# different diameter will give different masks, so can't check
6264
return
6365

66+
if diameter and compute_masks:
67+
# size of masks will be different, so need to adjust calculation
68+
masks_gt_file = Path(str(img_file).replace('_tif.tif', '_tif_cp4_gt_masks.png'))
69+
masks_gt = io.imread_2D(masks_gt_file)
70+
71+
masks_pred_shape = [int(s * diameter/30) for s in masks_pred.shape]
72+
assert [a == b for a, b in zip(masks_gt.shape[:2], masks_pred_shape[:2])]
73+
74+
# don't compare the images, because they are different sizes and won't match
75+
return
76+
6477
io.imsave(data_dir / '2D' / (img_file.stem + "_cp_masks.png"), masks_pred)
6578
# flowsp_pred = np.concatenate([flows_pred[1], flows_pred[2][None, ...]], axis=0)
6679
# mse = np.sqrt((flowsp_pred - flowps) ** 2).sum()

0 commit comments

Comments
 (0)