
Commit 3681c60

Merge branch 'main' into lumina2-lora

sayakpaul committed Feb 19, 2025
2 parents 24837a9 + 6fe05b9
Showing 10 changed files with 89 additions and 10 deletions.
20 changes: 13 additions & 7 deletions src/diffusers/loaders/lora_base.py
@@ -661,8 +661,20 @@ def set_adapters(
adapter_names: Union[List[str], str],
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
if isinstance(adapter_weights, dict):
components_passed = set(adapter_weights.keys())
lora_components = set(self._lora_loadable_modules)

invalid_components = sorted(components_passed - lora_components)
if invalid_components:
logger.warning(
f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
"to the invalid components will be removed and ignored."
)
adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}

adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
adapter_weights = copy.deepcopy(adapter_weights)

# Expand weights into a list, one entry per adapter
@@ -697,12 +709,6 @@ def set_adapters(
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
component_adapter_weights = weights.pop(component, None)

if component_adapter_weights is not None and not hasattr(self, component):
logger.warning(
f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
)

if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
logger.warning(
(
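In effect, set_adapters now validates the keys of a dict-valued adapter_weights against the pipeline's _lora_loadable_modules once, up front, warns about any strays, and drops them, instead of warning per component deep inside the expansion loop. A minimal usage sketch (the LoRA path and adapter name are hypothetical; any LoRA-capable pipeline behaves the same way):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe.load_lora_weights("path/to/lora.safetensors", adapter_name="pixel-art")

    # "unet" and "text_encoder" are LoRA-loadable components here; "foo" is not.
    pipe.set_adapters(
        "pixel-art", adapter_weights={"unet": 0.8, "text_encoder": 0.5, "foo": 1.0}
    )
    # Logs a warning that ['foo'] is not part of the pipeline, then applies
    # only the remaining weights.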
3 changes: 2 additions & 1 deletion src/diffusers/loaders/single_file.py
@@ -19,6 +19,7 @@
from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
from packaging import version
from typing_extensions import Self

from ..utils import deprecate, is_transformers_available, logging
from .single_file_utils import (
@@ -269,7 +270,7 @@ class FromSingleFileMixin:

@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
r"""
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
3 changes: 2 additions & 1 deletion src/diffusers/loaders/single_file_model.py
@@ -19,6 +19,7 @@

import torch
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self

from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
@@ -148,7 +149,7 @@ class FromOriginalModelMixin:

@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
r"""
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
is set in evaluation mode (`model.eval()`) by default.
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pipeline_utils.py
@@ -324,7 +324,7 @@ def is_saveable_module(name, value):
create_pr=create_pr,
)

def to(self, *args, **kwargs):
def to(self, *args, **kwargs) -> Self:
r"""
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs).`
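All three `-> Self` annotations rely on typing_extensions.Self, a backport of Python 3.11's typing.Self. Compared with annotating the base class as the return type, Self lets type checkers keep the concrete subclass type through classmethod constructors and chainable calls like pipe.to(...). A self-contained sketch with hypothetical classes:

    from typing_extensions import Self

    class BasePipeline:
        @classmethod
        def from_single_file(cls, path: str, **kwargs) -> Self:
            # Construct an instance of whichever subclass the call was made on.
            return cls(**kwargs)

        def to(self, device: str) -> Self:
            return self  # chainable

    class MyPipeline(BasePipeline):
        def run(self) -> None:
            print("running")

    pipe = MyPipeline.from_single_file("model.safetensors").to("cuda")
    pipe.run()  # inferred as MyPipeline, not BasePipeline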
4 changes: 4 additions & 0 deletions tests/lora/test_lora_layers_cogvideox.py
@@ -155,3 +155,7 @@ def test_simple_inference_with_text_lora_fused(self):
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_save_load(self):
pass

@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
16 changes: 16 additions & 0 deletions tests/lora/test_lora_layers_flux.py
@@ -262,6 +262,10 @@ def test_lora_expansion_works_for_extra_keys(self):
"LoRA should lead to different results.",
)

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@@ -270,6 +274,10 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
def test_modify_padding_mode(self):
pass

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass


class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline
@@ -783,6 +791,10 @@ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@@ -791,6 +803,10 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
def test_modify_padding_mode(self):
pass

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass


@slow
@nightly
4 changes: 4 additions & 0 deletions tests/lora/test_lora_layers_mochi.py
@@ -136,3 +136,7 @@ def test_simple_inference_with_text_lora_fused(self):
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_save_load(self):
pass

@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
5 changes: 5 additions & 0 deletions tests/lora/test_lora_layers_sd3.py
@@ -30,6 +30,7 @@
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
is_flaky,
nightly,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
@@ -128,6 +129,10 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
def test_modify_padding_mode(self):
pass

@is_flaky
def test_multiple_wrong_adapter_name_raises_error(self):
super().test_multiple_wrong_adapter_name_raises_error()


@nightly
@require_torch_gpu
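The @is_flaky marker re-runs a nondeterministic test instead of failing it on the first attempt. As an illustration of the pattern only (diffusers' real implementation in testing_utils may differ in signature and features), a bare-decorator version could look like:

    import functools

    def is_flaky(test_func, max_attempts=5):
        # Re-run the wrapped test; only the last attempt's failure propagates.
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return test_func(*args, **kwargs)
                except Exception:
                    if attempt == max_attempts - 1:
                        raise
        return wrapper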
5 changes: 5 additions & 0 deletions tests/lora/test_lora_layers_sdxl.py
@@ -37,6 +37,7 @@
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
CaptureLogger,
is_flaky,
load_image,
nightly,
numpy_cosine_similarity_distance,
@@ -111,6 +112,10 @@ def tearDown(self):
gc.collect()
torch.cuda.empty_cache()

@is_flaky
def test_multiple_wrong_adapter_name_raises_error(self):
super().test_multiple_wrong_adapter_name_raises_error()


@slow
@nightly
37 changes: 37 additions & 0 deletions tests/lora/utils.py
@@ -1135,6 +1135,43 @@ def test_wrong_adapter_name_raises_error(self):
pipe.set_adapters("adapter-1")
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]

def test_multiple_wrong_adapter_name_raises_error(self):
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)

scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
logger = logging.get_logger("diffusers.loaders.lora_base")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components)

wrong_components = sorted(set(scale_with_wrong_components.keys()))
msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
self.assertTrue(msg in str(cap_logger.out))

# test this works.
pipe.set_adapters("adapter-1")
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]

def test_simple_inference_with_text_denoiser_block_scale(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
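A note on the new test: logger.setLevel(30) sets the threshold to logging.WARNING (numeric value 30), so CaptureLogger records the warning that set_adapters emits for the bogus component names ("foo", "bar", "tik") before the final call confirms the pipeline still runs with a valid adapter selection.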
