 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE
 from huggingface_hub.utils import validate_hf_hub_args
-from packaging import version
 from torch import nn

-from .. import __version__
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
+from ..models.modeling_utils import load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
-    convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
@@ -119,13 +116,10 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
         self.load_lora_into_unet(
             state_dict,
             network_alphas=network_alphas,
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
         )
@@ -136,7 +130,6 @@ def load_lora_weights(
             if not hasattr(self, "text_encoder")
             else self.text_encoder,
             lora_scale=self.lora_scale,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
         )
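With the `low_cpu_mem_usage` plumbing removed here, `load_lora_weights` now only forwards the state dict, `network_alphas`, and `adapter_name` to the UNet and text-encoder loaders. A minimal usage sketch of the public entry point after this change, assuming a Stable Diffusion pipeline; the checkpoint path, weight name, and adapter name are illustrative:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    # `low_cpu_mem_usage` is no longer popped or forwarded by this method; `adapter_name` is still accepted.
    pipe.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors", adapter_name="style")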
@@ -193,16 +186,8 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
-            mirror (`str`, *optional*):
-                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
-                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
-                information.
-
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
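`weight_name`, documented above, selects a specific serialized file within the repository. A short sketch of calling `lora_state_dict` directly, assuming a pipeline class that mixes in this loader; the repository id and file name are illustrative:

    # Returns the raw LoRA state dict and the parsed network alphas without touching any model.
    state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(
        "some-user/some-lora", weight_name="pytorch_lora_weights.safetensors"
    )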
@@ -383,9 +368,7 @@ def _optionally_disable_offloading(cls, _pipeline):
         return (is_model_cpu_offload, is_sequential_cpu_offload)

     @classmethod
-    def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
-    ):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -395,109 +378,30 @@ def load_lora_into_unet(
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
             network_alphas (`Dict[str, float]`):
-                See `LoRALinearLayer` for more details.
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
+        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)

-        if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+        if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-
-            unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
-            state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
-
-            if network_alphas is not None:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
-                network_alphas = {
-                    k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                }
-
-        else:
-            # Otherwise, we're dealing with the old format. This means the `state_dict` should only
-            # contain the module names of the `unet` as its keys WITHOUT any prefix.
-            if not USE_PEFT_BACKEND:
-                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-                logger.warning(warn_message)
-
-        if len(state_dict.keys()) > 0:
-            if adapter_name in getattr(unet, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
-                )
-
-            state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if network_alphas is not None:
-                # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
-                # `convert_unet_state_dict_to_peft` method.
-                network_alphas = convert_unet_state_dict_to_peft(network_alphas)
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(unet)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
-
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
-
-        unet.load_attn_procs(
-            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
-        )
+            unet.load_attn_procs(
+                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+            )

     @classmethod
     def load_lora_into_text_encoder(
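The net effect of this hunk: `load_lora_into_unet` no longer performs the PEFT key conversion and adapter injection itself; it only decides whether the state dict targets the UNet and then delegates to `unet.load_attn_procs`. A toy illustration of the new prefix check, with made-up keys and with `"unet"`/`"text_encoder"` standing in for `cls.unet_name`/`cls.text_encoder_name`:

    keys = [
        "unet.down_blocks.0.attentions.0.proj_in.lora_A.weight",
        "text_encoder.text_model.encoder.layers.0.mlp.fc1.lora_A.weight",
    ]
    only_text_encoder = all(k.startswith("text_encoder") for k in keys)            # False
    load_unet = any(k.startswith("unet") for k in keys) and not only_text_encoder  # True -> unet.load_attn_procs(...)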
@@ -507,7 +411,6 @@ def load_lora_into_text_encoder(
         text_encoder,
         prefix=None,
         lora_scale=1.0,
-        low_cpu_mem_usage=None,
         adapter_name=None,
         _pipeline=None,
     ):
@@ -527,11 +430,6 @@ def load_lora_into_text_encoder(
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
@@ -541,8 +439,6 @@ def load_lora_into_text_encoder(

         from peft import LoraConfig

-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
@@ -625,9 +521,7 @@ def load_lora_into_text_encoder(
             # Unsafe code />

     @classmethod
-    def load_lora_into_transformer(
-        cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
-    ):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -640,19 +534,12 @@ def load_lora_into_transformer(
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         keys = list(state_dict.keys())

         transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
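The transformer variant keeps its PEFT handling in the method; only the `low_cpu_mem_usage` default computation is dropped. The prefix filtering above works the same way as for the UNet; a toy illustration with made-up keys, using `"transformer"` in place of `cls.transformer_name`:

    keys = [
        "transformer.transformer_blocks.0.attn.to_q.lora_A.weight",
        "text_encoder.text_model.encoder.layers.0.mlp.fc1.lora_A.weight",
    ]
    transformer_keys = [k for k in keys if k.startswith("transformer")]  # keeps only the first key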
@@ -846,22 +733,11 @@ def unload_lora_weights(self):
         >>> ...
         ```
         """
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-
         if not USE_PEFT_BACKEND:
-            if version.parse(__version__) > version.parse("0.23"):
-                logger.warning(
-                    "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
-                    "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
-                )
+            raise ValueError("PEFT backend is required for this method.")

-            for _, module in unet.named_modules():
-                if hasattr(module, "set_lora_layer"):
-                    module.set_lora_layer(None)
-        else:
-            recurse_remove_peft_layers(unet)
-            if hasattr(unet, "peft_config"):
-                del unet.peft_config
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.unload_lora()

         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
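After this change `unload_lora_weights` requires the PEFT backend and delegates the UNet side to `unet.unload_lora()`. A short usage sketch, assuming LoRA weights were previously loaded into `pipe`; the file name is illustrative:

    pipe.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors")
    image = pipe("a prompt").images[0]
    pipe.unload_lora_weights()  # calls unet.unload_lora() and removes the text-encoder monkey patch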