Skip to content

Commit

Permalink
adding weights_only flag
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Feb 7, 2025
1 parent 16b675c commit 8f53ebe
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions cellpose/resnet_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,12 @@ def load_model(self, filename, device=None):
device (torch.device, optional): The device to load the model on. Defaults to None.
"""
if (device is not None) and (device.type != "cpu"):
state_dict = torch.load(filename, map_location=device)
state_dict = torch.load(filename, map_location=device, weights_only=True)
else:
self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D,
self.diam_mean)
state_dict = torch.load(filename, map_location=torch.device("cpu"))
state_dict = torch.load(filename, map_location=torch.device("cpu"),
weights_only=True)

if state_dict["output.2.weight"].shape[0] != self.nout:
for name in self.state_dict():
Expand Down Expand Up @@ -318,7 +319,8 @@ def load_model(self, filename, device=None):
else:
self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D,
self.diam_mean)
state_dict = torch.load(filename, map_location=torch.device("cpu"), weights_only=True)
state_dict = torch.load(filename, map_location=torch.device("cpu"),
weights_only=True)

self.load_state_dict(state_dict)

Expand Down

0 comments on commit 8f53ebe

Please sign in to comment.