Commit 3681c60
Merge branch 'main' into lumina2-lora
2 parents 24837a9 + 6fe05b9

10 files changed: +89 −10 lines


src/diffusers/loaders/lora_base.py
Lines changed: 13 additions & 7 deletions

@@ -661,8 +661,20 @@ def set_adapters(
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
-        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        if isinstance(adapter_weights, dict):
+            components_passed = set(adapter_weights.keys())
+            lora_components = set(self._lora_loadable_modules)
+
+            invalid_components = sorted(components_passed - lora_components)
+            if invalid_components:
+                logger.warning(
+                    f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
+                    f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
+                    "to the invalid components will be removed and ignored."
+                )
+                adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
 
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
         adapter_weights = copy.deepcopy(adapter_weights)
 
         # Expand weights into a list, one entry per adapter
@@ -697,12 +709,6 @@ def set_adapters(
         for adapter_name, weights in zip(adapter_names, adapter_weights):
             if isinstance(weights, dict):
                 component_adapter_weights = weights.pop(component, None)
-
-                if component_adapter_weights is not None and not hasattr(self, component):
-                    logger.warning(
-                        f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
-                    )
-
                 if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
                     logger.warning(
                         (

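To see what the new validation does in practice, here is a minimal, hypothetical sketch (the model ID, LoRA repo, and adapter name are illustrative, not taken from this commit):

    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    pipe.load_lora_weights("some-user/some-lora", adapter_name="style")

    # "unet" and "text_encoder" are LoRA-loadable components of this pipeline;
    # "foo" is not, so set_adapters now warns up front and drops the "foo" entry
    # instead of silently carrying it into the per-component loop.
    pipe.set_adapters("style", adapter_weights={"unet": 0.8, "text_encoder": 0.5, "foo": 1.0})
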
src/diffusers/loaders/single_file.py
Lines changed: 2 additions & 1 deletion

@@ -19,6 +19,7 @@
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
 from packaging import version
+from typing_extensions import Self
 
 from ..utils import deprecate, is_transformers_available, logging
 from .single_file_utils import (
@@ -269,7 +270,7 @@ class FromSingleFileMixin:
 
     @classmethod
     @validate_hf_hub_args
-    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+    def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
         r"""
         Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
         format. The pipeline is set in evaluation mode (`model.eval()`) by default.

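The `-> Self` annotation (PEP 673, backported through `typing_extensions`) tells type checkers that a classmethod returns an instance of whichever class it is called on, not the mixin that defines it. A standalone sketch of the pattern, independent of diffusers:

    from typing_extensions import Self

    class Loader:
        @classmethod
        def from_single_file(cls, path: str) -> Self:
            # `Self` binds to the class the method is invoked on.
            return cls()

    class SubLoader(Loader):
        def extra(self) -> None: ...

    loader = SubLoader.from_single_file("model.safetensors")
    loader.extra()  # type-checks: loader is SubLoader, not Loader

The same annotation is applied to `FromOriginalModelMixin.from_single_file` and `DiffusionPipeline.to` below.
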
src/diffusers/loaders/single_file_model.py
Lines changed: 2 additions & 1 deletion

@@ -19,6 +19,7 @@
 
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
+from typing_extensions import Self
 
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
@@ -148,7 +149,7 @@ class FromOriginalModelMixin:
 
     @classmethod
     @validate_hf_hub_args
-    def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
+    def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
         r"""
         Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
         is set in evaluation mode (`model.eval()`) by default.

src/diffusers/pipelines/pipeline_utils.py
Lines changed: 1 addition & 1 deletion

@@ -324,7 +324,7 @@ def is_saveable_module(name, value):
             create_pr=create_pr,
         )
 
-    def to(self, *args, **kwargs):
+    def to(self, *args, **kwargs) -> Self:
         r"""
         Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
         arguments of `self.to(*args, **kwargs).`

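Annotating `to` with `Self` means chained device/dtype conversions keep the concrete pipeline type under static analysis. A hypothetical usage sketch:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    # Both calls now type-check as returning StableDiffusionPipeline rather
    # than a bare DiffusionPipeline, so pipeline-specific attributes stay visible.
    pipe = pipe.to("cuda").to(torch.float16)
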
tests/lora/test_lora_layers_cogvideox.py
Lines changed: 4 additions & 0 deletions

@@ -155,3 +155,7 @@ def test_simple_inference_with_text_lora_fused(self):
     @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
+
+    @unittest.skip("Not supported in CogVideoX.")
+    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+        pass

tests/lora/test_lora_layers_flux.py
Lines changed: 16 additions & 0 deletions

@@ -262,6 +262,10 @@ def test_lora_expansion_works_for_extra_keys(self):
             "LoRA should lead to different results.",
         )
 
+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_block_scale(self):
+        pass
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
@@ -270,6 +274,10 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
     def test_modify_padding_mode(self):
         pass
 
+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+        pass
+
 
 class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxControlPipeline
@@ -783,6 +791,10 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
         self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
 
+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_block_scale(self):
+        pass
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
@@ -791,6 +803,10 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
     def test_modify_padding_mode(self):
         pass
 
+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+        pass
+
 
 @slow
 @nightly

tests/lora/test_lora_layers_mochi.py
Lines changed: 4 additions & 0 deletions

@@ -136,3 +136,7 @@ def test_simple_inference_with_text_lora_fused(self):
     @unittest.skip("Text encoder LoRA is not supported in Mochi.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
+
+    @unittest.skip("Not supported in Mochi.")
+    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+        pass

tests/lora/test_lora_layers_sd3.py
Lines changed: 5 additions & 0 deletions

@@ -30,6 +30,7 @@
 from diffusers.utils import load_image
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
+    is_flaky,
     nightly,
     numpy_cosine_similarity_distance,
     require_big_gpu_with_torch_cuda,
@@ -128,6 +129,10 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
     def test_modify_padding_mode(self):
         pass
 
+    @is_flaky
+    def test_multiple_wrong_adapter_name_raises_error(self):
+        super().test_multiple_wrong_adapter_name_raises_error()
+
 
 @nightly
 @require_torch_gpu

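`is_flaky` re-runs a failing test a few times before reporting it, which suits the log-capture assertion in `test_multiple_wrong_adapter_name_raises_error`. As a rough illustration of the retry-decorator pattern only (the real helper lives in `diffusers.utils.testing_utils` and may differ):

    import functools

    def is_flaky(test_func, max_attempts=5):
        """Sketch: re-run a flaky test up to `max_attempts` times before failing."""
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            last_error = None
            for _ in range(max_attempts):
                try:
                    return test_func(*args, **kwargs)
                except Exception as err:  # retry on any failure
                    last_error = err
            raise last_error  # out of retries: surface the last failure
        return wrapper
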
tests/lora/test_lora_layers_sdxl.py
Lines changed: 5 additions & 0 deletions

@@ -37,6 +37,7 @@
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    is_flaky,
     load_image,
     nightly,
     numpy_cosine_similarity_distance,
@@ -111,6 +112,10 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
+    @is_flaky
+    def test_multiple_wrong_adapter_name_raises_error(self):
+        super().test_multiple_wrong_adapter_name_raises_error()
+
 
 @slow
 @nightly

tests/lora/utils.py
Lines changed: 37 additions & 0 deletions

@@ -1135,6 +1135,43 @@ def test_wrong_adapter_name_raises_error(self):
         pipe.set_adapters("adapter-1")
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
+    def test_multiple_wrong_adapter_name_raises_error(self):
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        if self.has_two_text_encoders or self.has_three_text_encoders:
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+        scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
+        logger = logging.get_logger("diffusers.loaders.lora_base")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components)
+
+        wrong_components = sorted(set(scale_with_wrong_components.keys()))
+        msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
+        self.assertTrue(msg in str(cap_logger.out))
+
+        # test this works.
+        pipe.set_adapters("adapter-1")
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches

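In the new test, `logger.setLevel(30)` sets the threshold to `logging.WARNING` (numeric value 30) so the warning is actually emitted, and `CaptureLogger` collects it for the assertion. A minimal stand-in showing how such a log-capturing context manager can work (the real helper is `diffusers.utils.testing_utils.CaptureLogger`, whose implementation may differ):

    import io
    import logging

    class CaptureLoggerSketch:
        """Minimal stand-in for a CaptureLogger-style test helper."""

        def __init__(self, logger: logging.Logger):
            self.logger = logger
            self.io = io.StringIO()
            self.handler = logging.StreamHandler(self.io)
            self.out = ""

        def __enter__(self):
            self.logger.addHandler(self.handler)  # start collecting records
            return self

        def __exit__(self, *exc):
            self.logger.removeHandler(self.handler)
            self.out = self.io.getvalue()  # captured text, as the test reads it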