@@ -268,11 +268,12 @@ def load_model(self, filename, device=None):
268
268
device (torch.device, optional): The device to load the model on. Defaults to None.
269
269
"""
270
270
if (device is not None ) and (device .type != "cpu" ):
271
- state_dict = torch .load (filename , map_location = device )
271
+ state_dict = torch .load (filename , map_location = device , weights_only = True )
272
272
else :
273
273
self .__init__ (self .nbase , self .nout , self .sz , self .mkldnn , self .conv_3D ,
274
274
self .diam_mean )
275
- state_dict = torch .load (filename , map_location = torch .device ("cpu" ))
275
+ state_dict = torch .load (filename , map_location = torch .device ("cpu" ),
276
+ weights_only = True )
276
277
277
278
if state_dict ["output.2.weight" ].shape [0 ] != self .nout :
278
279
for name in self .state_dict ():
@@ -318,7 +319,8 @@ def load_model(self, filename, device=None):
318
319
else :
319
320
self .__init__ (self .nbase , self .nout , self .sz , self .mkldnn , self .conv_3D ,
320
321
self .diam_mean )
321
- state_dict = torch .load (filename , map_location = torch .device ("cpu" ), weights_only = True )
322
+ state_dict = torch .load (filename , map_location = torch .device ("cpu" ),
323
+ weights_only = True )
322
324
323
325
self .load_state_dict (state_dict )
324
326
0 commit comments