Skip to content

Commit 8f53ebe

Browse files
adding weights_only flag
1 parent 16b675c commit 8f53ebe

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

cellpose/resnet_torch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,12 @@ def load_model(self, filename, device=None):
268268
device (torch.device, optional): The device to load the model on. Defaults to None.
269269
"""
270270
if (device is not None) and (device.type != "cpu"):
271-
state_dict = torch.load(filename, map_location=device)
271+
state_dict = torch.load(filename, map_location=device, weights_only=True)
272272
else:
273273
self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D,
274274
self.diam_mean)
275-
state_dict = torch.load(filename, map_location=torch.device("cpu"))
275+
state_dict = torch.load(filename, map_location=torch.device("cpu"),
276+
weights_only=True)
276277

277278
if state_dict["output.2.weight"].shape[0] != self.nout:
278279
for name in self.state_dict():
@@ -318,7 +319,8 @@ def load_model(self, filename, device=None):
318319
else:
319320
self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D,
320321
self.diam_mean)
321-
state_dict = torch.load(filename, map_location=torch.device("cpu"), weights_only=True)
322+
state_dict = torch.load(filename, map_location=torch.device("cpu"),
323+
weights_only=True)
322324

323325
self.load_state_dict(state_dict)
324326

0 commit comments

Comments
 (0)