Commit a0542c1

[LoRA] Remove legacy LoRA code and related adjustments (huggingface#8316)
* remove legacy code from load_attn_procs.
* finish first draft
* fix more.
* fix more
* add test
* add serialization support.
* fix-copies
* require peft backend for lora tests
* style
* fix test
* fix loading.
* empty
* address benjamin's feedback.
1 parent a8ad666 commit a0542c1
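
After this change, the pipeline-level LoRA entry points rely exclusively on the PEFT backend (they raise `ValueError("PEFT backend is required for this method.")` otherwise). A minimal usage sketch of that flow; the model and LoRA repository IDs below are illustrative placeholders, not taken from this commit:

```python
# Minimal sketch of the PEFT-backed LoRA flow this commit standardizes on.
# Requires `peft` to be installed; the non-PEFT legacy path is removed.
# The checkpoint IDs below are illustrative placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Loads LoRA weights into the UNet (and text encoder, if present in the file)
# through the PEFT backend; `adapter_name` is optional.
pipe.load_lora_weights("some-user/some-lora-adapter", adapter_name="style")

image = pipe("a corgi wearing sunglasses", num_inference_steps=25).images[0]

# Unloading now goes through the PEFT-aware `unet.unload_lora()` helper.
pipe.unload_lora_weights()
```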

File tree

7 files changed: +392 -425 lines changed


.github/workflows/pr_test_peft_backend.yml

Lines changed: 18 additions & 0 deletions
@@ -111,3 +111,21 @@ jobs:
             -s -v \
             --make-reports=tests_${{ matrix.config.report }} \
             tests/lora/
+          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+            -s -v \
+            --make-reports=tests_models_lora_${{ matrix.config.report }} \
+            tests/models/ -k "lora"
+
+
+      - name: Failure short reports
+        if: ${{ failure() }}
+        run: |
+          cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+          cat reports/tests_models_lora_${{ matrix.config.report }}_failures_short.txt
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v2
+        with:
+          name: pr_${{ matrix.config.report }}_test_reports
+          path: reports
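
The step added above runs the LoRA-related model tests under `tests/models/` filtered by `-k "lora"`. A rough local equivalent via pytest's Python entry point; this is a sketch that assumes `pytest` and `pytest-xdist` are installed, that you run it from the diffusers repository root, that the `--make-reports` option is provided by the repository's own conftest, and that the `local` report suffix is a placeholder:

```python
# Rough local equivalent of the new CI step (a sketch, not the workflow itself).
# Assumes pytest + pytest-xdist are installed; "local" is a placeholder suffix.
import sys

import pytest

exit_code = pytest.main(
    [
        "-n", "4",                  # parallel workers (pytest-xdist)
        "--max-worker-restart=0",
        "--dist=loadfile",
        "-s",
        "-v",
        "--make-reports=tests_models_lora_local",  # option added by diffusers' conftest
        "-k", "lora",               # keep only model tests whose id mentions "lora"
        "tests/models/",
    ]
)
sys.exit(int(exit_code))
```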

.github/workflows/push_tests.yml

Lines changed: 5 additions & 0 deletions
@@ -189,12 +189,17 @@ jobs:
            -s -v -k "not Flax and not Onnx and not PEFTLoRALoading" \
            --make-reports=tests_peft_cuda \
            tests/lora/
+          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+            -s -v -k "lora and not Flax and not Onnx and not PEFTLoRALoading" \
+            --make-reports=tests_peft_cuda_models_lora \
+            tests/models/
 
       - name: Failure short reports
         if: ${{ failure() }}
         run: |
           cat reports/tests_peft_cuda_stats.txt
           cat reports/tests_peft_cuda_failures_short.txt
+          cat reports/tests_peft_cuda_models_lora_failures_short.txt
 
       - name: Test suite reports artifacts
         if: ${{ always() }}

src/diffusers/loaders/lora.py

Lines changed: 16 additions & 140 deletions
@@ -22,17 +22,14 @@
 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE
 from huggingface_hub.utils import validate_hf_hub_args
-from packaging import version
 from torch import nn
 
-from .. import __version__
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
+from ..models.modeling_utils import load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
-    convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
@@ -119,13 +116,10 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
         self.load_lora_into_unet(
             state_dict,
             network_alphas=network_alphas,
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
         )
@@ -136,7 +130,6 @@ def load_lora_weights(
             if not hasattr(self, "text_encoder")
             else self.text_encoder,
             lora_scale=self.lora_scale,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
         )
@@ -193,16 +186,8 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
-            mirror (`str`, *optional*):
-                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
-                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
-                information.
-
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -383,9 +368,7 @@ def _optionally_disable_offloading(cls, _pipeline):
         return (is_model_cpu_offload, is_sequential_cpu_offload)
 
     @classmethod
-    def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
-    ):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
 
@@ -395,109 +378,30 @@ def load_lora_into_unet
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
             network_alphas (`Dict[str, float]`):
-                See `LoRALinearLayer` for more details.
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
+        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
 
-        if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+        if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-
-            unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
-            state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
-
-            if network_alphas is not None:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
-                network_alphas = {
-                    k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                }
-
-        else:
-            # Otherwise, we're dealing with the old format. This means the `state_dict` should only
-            # contain the module names of the `unet` as its keys WITHOUT any prefix.
-            if not USE_PEFT_BACKEND:
-                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-                logger.warning(warn_message)
-
-        if len(state_dict.keys()) > 0:
-            if adapter_name in getattr(unet, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
-                )
-
-            state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if network_alphas is not None:
-                # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
-                # `convert_unet_state_dict_to_peft` method.
-                network_alphas = convert_unet_state_dict_to_peft(network_alphas)
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(unet)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
-
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
-
-        unet.load_attn_procs(
-            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
-        )
+            unet.load_attn_procs(
+                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+            )
 
     @classmethod
     def load_lora_into_text_encoder(
@@ -507,7 +411,6 @@ def load_lora_into_text_encoder(
         text_encoder,
         prefix=None,
         lora_scale=1.0,
-        low_cpu_mem_usage=None,
         adapter_name=None,
         _pipeline=None,
     ):
@@ -527,11 +430,6 @@ def load_lora_into_text_encoder(
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
@@ -541,8 +439,6 @@ def load_lora_into_text_encoder(
 
         from peft import LoraConfig
 
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
@@ -625,9 +521,7 @@ def load_lora_into_text_encoder(
         # Unsafe code />
 
     @classmethod
-    def load_lora_into_transformer(
-        cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
-    ):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -640,19 +534,12 @@ def load_lora_into_transformer(
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
 
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         keys = list(state_dict.keys())
 
         transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
@@ -846,22 +733,11 @@ def unload_lora_weights(self):
         >>> ...
         ```
         """
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-
         if not USE_PEFT_BACKEND:
-            if version.parse(__version__) > version.parse("0.23"):
-                logger.warning(
-                    "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
-                    "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
-                )
+            raise ValueError("PEFT backend is required for this method.")
 
-            for _, module in unet.named_modules():
-                if hasattr(module, "set_lora_layer"):
-                    module.set_lora_layer(None)
-        else:
-            recurse_remove_peft_layers(unet)
-            if hasattr(unet, "peft_config"):
-                del unet.peft_config
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.unload_lora()
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
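
For reference, the branch that replaces the removed block keys off the prefixed serialization format introduced in huggingface/diffusers#2918, where keys carry a `unet.` or `text_encoder.` prefix. A standalone sketch of that prefix check; the keys and tensor shapes below are made up for illustration:

```python
# Standalone illustration of the prefix check in the simplified
# load_lora_into_unet. Keys in the new serialization format are prefixed
# with "unet." / "text_encoder."; the entries below are made up.
import torch

unet_name, text_encoder_name = "unet", "text_encoder"

state_dict = {
    "unet.down_blocks.0.attentions.0.proj_in.lora_A.weight": torch.zeros(4, 320),
    "unet.down_blocks.0.attentions.0.proj_in.lora_B.weight": torch.zeros(320, 4),
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 768),
}

keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(text_encoder_name) for key in keys)

if any(key.startswith(unet_name) for key in keys) and not only_text_encoder:
    # In the real method this is where prefix handling and PEFT injection
    # now happen, delegated entirely to unet.load_attn_procs(...).
    print("UNet LoRA keys found: would call unet.load_attn_procs(...)")
```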
