@@ -268,11 +268,12 @@ def load_model(self, filename, device=None):
268268 device (torch.device, optional): The device to load the model on. Defaults to None.
269269 """
270270 if (device is not None ) and (device .type != "cpu" ):
271- state_dict = torch .load (filename , map_location = device )
271+ state_dict = torch .load (filename , map_location = device , weights_only = True )
272272 else :
273273 self .__init__ (self .nbase , self .nout , self .sz , self .mkldnn , self .conv_3D ,
274274 self .diam_mean )
275- state_dict = torch .load (filename , map_location = torch .device ("cpu" ))
275+ state_dict = torch .load (filename , map_location = torch .device ("cpu" ),
276+ weights_only = True )
276277
277278 if state_dict ["output.2.weight" ].shape [0 ] != self .nout :
278279 for name in self .state_dict ():
@@ -318,7 +319,8 @@ def load_model(self, filename, device=None):
318319 else :
319320 self .__init__ (self .nbase , self .nout , self .sz , self .mkldnn , self .conv_3D ,
320321 self .diam_mean )
321- state_dict = torch .load (filename , map_location = torch .device ("cpu" ), weights_only = True )
322+ state_dict = torch .load (filename , map_location = torch .device ("cpu" ),
323+ weights_only = True )
322324
323325 self .load_state_dict (state_dict )
324326
0 commit comments