Skip to content

Commit e6eec15

Browse files
committed
explicitly fail loading cp4 model in cp3
1 parent 3864748 commit e6eec15

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

cellpose/resnet_torch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def save_model(self, filename):
258258
filename (str): The path to the file where the model will be saved.
259259
"""
260260
torch.save(self.state_dict(), filename)
261-
261+
262262
def load_model(self, filename, device=None):
263263
"""
264264
Load the model from a file.
@@ -272,9 +272,13 @@ def load_model(self, filename, device=None):
272272
else:
273273
self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D,
274274
self.diam_mean)
275-
state_dict = torch.load(filename, map_location=torch.device("cpu"),
275+
state_dict = torch.load(filename, map_location=torch.device("cpu"),
276276
weights_only=True)
277277

278+
# check if state_dict is for cp4 by looking for W2 key:
279+
assert 'W2' not in state_dict.keys(), f"""The model file {filename} appears to be for CP4,
280+
which is not compatible with CP3. Please use a CP3 pretrained model file. """
281+
278282
if state_dict["output.2.weight"].shape[0] != self.nout:
279283
for name in self.state_dict():
280284
if "output" not in name:

0 commit comments

Comments
 (0)