Skip to content

Commit 8b3930e

Browse files
committed
Merge remote-tracking branch 'origin' into fix_fill_holes_and_remove_small_masks
2 parents baf9d1c + 7a0f2c4 commit 8b3930e

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

cellpose/gui/gui.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,7 +1770,7 @@ def initialize_model(self, model_name=None, custom=False):
17701770
if model_name is None or custom:
17711771
self.get_model_path(custom=custom)
17721772
if not os.path.exists(self.current_model_path):
1773-
raise ValueError("need to specify model (use dropdown)")
1773+
raise ValueError("Model file not found: need to specify model (use dropdown)")
17741774

17751775
if model_name is None or not isinstance(model_name, str):
17761776
self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
@@ -1867,7 +1867,7 @@ def compute_cprob(self):
18671867
self.logger.error("Flows don't exist, try running model again.")
18681868
return
18691869

1870-
maski = dynamics.resize_and_compute_masks(
1870+
maski = dynamics.compute_masks_and_clean(
18711871
dP=dP,
18721872
cellprob=cellprob,
18731873
niter=niter,

cellpose/vit_sam.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,15 @@ def forward(self, x):
8181

8282
return x1, torch.zeros((x.shape[0], 256), device=x.device)
8383

84-
def load_model(self, PATH, device, strict = False):
84+
def load_model(self, PATH, device, strict = False):
8585
state_dict = torch.load(PATH, map_location = device, weights_only=True)
8686
keys = [k for k in state_dict.keys()]
87+
88+
# loudly fail on attempt to load not cp4 model:
89+
w2_data = state_dict.get('W2', None)
90+
if w2_data == None:
91+
raise ValueError('This model does not appear to be a CP4 model. CP3 models are not compatible with CP4.')
92+
8793
if keys[0][:7] == "module.":
8894
from collections import OrderedDict
8995
new_state_dict = OrderedDict()

tests/test_import.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import pytest
2+
3+
14
def test_cellpose_imports_without_error():
25
import cellpose
36
from cellpose import models, core
@@ -28,3 +31,18 @@ def itest_model_dir():
2831
model = models.CellposeModel(pretrained_model='cpsam')
2932
masks = model.eval(np.random.randn(256, 256))[0]
3033
assert masks.shape == (256, 256)
34+
35+
36+
def test_load_cp3_fail():
37+
from cellpose.models import CellposeModel, MODEL_DIR
38+
from cellpose import utils
39+
40+
cyto3_model_path = (MODEL_DIR / 'cyto3').absolute()
41+
42+
if not cyto3_model_path.exists():
43+
url = 'https://www.cellpose.org/models/cyto3'
44+
utils.download_url_to_file(url, cyto3_model_path, progress=False)
45+
46+
with pytest.raises(ValueError):
47+
# using `pretrained_model=cyto3` just loads the cpsam model unless the path is given
48+
model = CellposeModel(pretrained_model=str(cyto3_model_path))

0 commit comments

Comments
 (0)