Skip to content

Commit 4da810b

Browse files
authored
Remove insecure torch.load calls (#7393)
update
1 parent 161c6e1 commit 4da810b

File tree

5 files changed

+14
-8
lines changed

5 files changed

+14
-8
lines changed

src/diffusers/loaders/ip_adapter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from huggingface_hub.utils import validate_hf_hub_args
2020
from safetensors import safe_open
2121

22-
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
22+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
2323
from ..utils import (
2424
_get_model_file,
2525
is_accelerate_available,
@@ -182,7 +182,7 @@ def load_ip_adapter(
182182
elif key.startswith("ip_adapter."):
183183
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
184184
else:
185-
state_dict = torch.load(model_file, map_location="cpu")
185+
state_dict = load_state_dict(model_file)
186186
else:
187187
state_dict = pretrained_model_name_or_path_or_dict
188188

src/diffusers/loaders/lora.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from torch import nn
2626

2727
from .. import __version__
28-
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
28+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
2929
from ..utils import (
3030
USE_PEFT_BACKEND,
3131
_get_model_file,
@@ -281,7 +281,7 @@ def lora_state_dict(
281281
subfolder=subfolder,
282282
user_agent=user_agent,
283283
)
284-
state_dict = torch.load(model_file, map_location="cpu")
284+
state_dict = load_state_dict(model_file)
285285
else:
286286
state_dict = pretrained_model_name_or_path_or_dict
287287

src/diffusers/loaders/textual_inversion.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from huggingface_hub.utils import validate_hf_hub_args
1919
from torch import nn
2020

21+
from ..models.modeling_utils import load_state_dict
2122
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
2223

2324

@@ -100,7 +101,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
100101
subfolder=subfolder,
101102
user_agent=user_agent,
102103
)
103-
state_dict = torch.load(model_file, map_location="cpu")
104+
state_dict = load_state_dict(model_file)
104105
else:
105106
state_dict = pretrained_model_name_or_path
106107

src/diffusers/loaders/unet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
IPAdapterPlusImageProjection,
3232
MultiIPAdapterImageProjection,
3333
)
34-
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
34+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
3535
from ..utils import (
3636
USE_PEFT_BACKEND,
3737
_get_model_file,
@@ -214,7 +214,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
214214
subfolder=subfolder,
215215
user_agent=user_agent,
216216
)
217-
state_dict = torch.load(model_file, map_location="cpu")
217+
state_dict = load_state_dict(model_file)
218218
else:
219219
state_dict = pretrained_model_name_or_path_or_dict
220220

src/diffusers/models/modeling_utils.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
108108
if file_extension == SAFETENSORS_FILE_EXTENSION:
109109
return safetensors.torch.load_file(checkpoint_file, device="cpu")
110110
else:
111-
return torch.load(checkpoint_file, map_location="cpu")
111+
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
112+
return torch.load(
113+
checkpoint_file,
114+
map_location="cpu",
115+
**weights_only_kwarg,
116+
)
112117
except Exception as e:
113118
try:
114119
with open(checkpoint_file) as f:

0 commit comments

Comments
 (0)