Commit fb54499

Authored by BenjaminBossan, sayakpaul, hlky, yiyixuxu, and stevhliu
[LoRA] Implement hot-swapping of LoRA (#9453)
* [WIP][LoRA] Implement hot-swapping of LoRA

  This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

  Description

  As of now, users can already load multiple LoRA adapters. They can offload existing adapters or they can unload them (i.e. delete them). However, they cannot "hotswap" adapters yet, i.e. substitute the weights of one LoRA adapter with the weights of another, without the need to create a separate LoRA adapter. Generally, hot-swapping may not appear super useful, but when the model is compiled, it is necessary to prevent recompilation. See #9279 for more context.

  Caveats

  To hot-swap one LoRA adapter for another, the two adapters should target exactly the same layers, and the "hyper-parameters" of the two adapters should be identical. For instance, the LoRA alpha has to be the same: given that we keep the alpha from the first adapter, the LoRA scaling would otherwise be incorrect for the second adapter. Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict would trigger a guard for recompilation, defeating the main purpose of the feature.

  I also found that compilation flags can have an impact on whether this works or not. E.g. when passing "reduce-overhead", there will be errors of the type:

  > input name: arg861_1. data pointer changed from 139647332027392 to 139647331054592

  I don't know enough about compilation to determine whether this is problematic or not.

  Current state

  This is obviously WIP right now, intended to collect feedback and discuss which direction to take this. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or there is a separate copy in diffusers to avoid the need for a minimum PEFT version to use this feature).
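The alpha constraint above follows from how the LoRA scaling factor is derived. A minimal sketch in plain Python (hypothetical numbers, not actual adapter configs) of why reusing the first adapter's scaling mis-scales a second adapter with a different alpha:

```python
def lora_scaling(lora_alpha, rank):
    # PEFT derives the LoRA scaling factor as alpha / rank
    # (for the default, non-rank-stabilized variant).
    return lora_alpha / rank

# two hypothetical adapters with the same rank but different alphas
scale_first = lora_scaling(lora_alpha=8, rank=8)
scale_second = lora_scaling(lora_alpha=16, rank=8)

# Hotswapping keeps the first adapter's scaling, so the second
# adapter's update would effectively be applied at half strength.
mismatch = scale_first != scale_second
```

Overriding the stored scaling would fix the math, but as noted above, mutating that dict trips a compilation guard, which is exactly what hotswapping is trying to avoid.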
  Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those tests will be added to PEFT. Furthermore, as of now, this is only implemented for the unet; other pipeline components have yet to implement this feature. Finally, it should be properly documented. I would like to collect feedback on the current state of the PR before putting more time into finalizing it.

* Reviewer feedback

* Reviewer feedback, adjust test

* Fix, doc

* Make fix

* Fix for possible g++ error

* Add test for recompilation w/o hotswapping

* Make hotswap work

  Requires huggingface/peft#2366. More changes to make hotswapping work. Together with the mentioned PEFT PR, the tests pass for me locally.

  List of changes:

  - docstring for hotswap
  - remove code copied from PEFT, import from PEFT now
  - adjustments to PeftAdapterMixin.load_lora_adapter (unfortunately, some state dict renaming was necessary, LMK if there is a better solution)
  - adjustments to UNet2DConditionLoadersMixin._process_lora: LMK if this is even necessary or not, I'm unsure what the overall relationship is between this and PeftAdapterMixin.load_lora_adapter
  - also in UNet2DConditionLoadersMixin._process_lora, I saw that there is no LoRA unloading when loading the adapter fails, so I added it there (in line with what happens in PeftAdapterMixin.load_lora_adapter)
  - rewritten tests to avoid shelling out, made the tests more precise by making sure that the outputs align, parametrized them
  - also checked the pipeline code mentioned in this comment: #9453 (comment); when running it inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context, there is no error, so I think hotswapping is now working with pipelines.
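Conceptually, the hot-swap itself is an in-place copy of the new adapter's weights over the old ones, so the weight containers keep their identity (and, for compiled models, the data pointers the compiled graph captured). A rough, hypothetical illustration with plain Python lists standing in for tensors:

```python
def hotswap_weights(current, new):
    """Copy new adapter weights over the current ones in place
    (the tensor analogue would be Tensor.copy_()).

    The container is mutated rather than rebound, so any reference
    held elsewhere, e.g. by a compiled graph, remains valid.
    """
    if len(current) != len(new):
        raise ValueError("Hotswap requires matching shapes.")
    current[:] = new  # in-place slice assignment mutates the object

weights = [0.1, 0.2, 0.3]  # weights of the currently loaded adapter
ref_before = id(weights)   # identity a compiled graph would hold on to

hotswap_weights(weights, [0.4, 0.5, 0.6])
ref_after = id(weights)
```

This is only a sketch of the idea; the real implementation lives in PEFT's hotswap utilities and additionally handles scaling, prefixes, and shape padding.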
* Address reviewer feedback:

  - Revert deprecated method
  - Fix PEFT doc link to main
  - Don't use private function
  - Clarify magic numbers
  - Add pipeline test

  Moreover:

  - Extend docstrings
  - Extend existing test for outputs != 0
  - Extend existing test for wrong adapter name

* Change order of test decorators

  parameterized.expand seems to ignore skip decorators if added in last place (i.e. innermost decorator).

* Split model and pipeline tests

  Also increase test coverage by also targeting conv2d layers (support for which was added recently in the PEFT PR).

* Reviewer feedback: Move decorator to test classes

  ... instead of having them on each test method.

* Apply suggestions from code review

  Co-authored-by: hlky <[email protected]>

* Reviewer feedback: version check, TODO comment

* Add enable_lora_hotswap method

* Reviewer feedback: check _lora_loadable_modules

* Revert changes in unet.py

* Add possibility to ignore enabled at wrong time

* Fix docstrings

* Log possible PEFT error, test

* Raise helpful error if hotswap not supported

  I.e. for the text encoder

* Formatting

* More linter

* More ruff

* Doc-builder complaint

* Update docstring:

  - mention no text encoder support yet
  - make it clear that LoRA is meant
  - mention that the same adapter name should be passed

* Fix error in docstring

* Update more methods with hotswap argument

  - SDXL
  - SD3
  - Flux

  No changes were made to load_lora_into_transformer.

* Add hotswap argument to load_lora_into_transformer

  For SD3 and Flux. Use shorter docstring for brevity.

* Extend docstrings

* Add version guards to tests

* Formatting

* Fix LoRA loading call to add prefix=None

  See: #10187 (comment)

* Run make fix-copies

* Add hot swap documentation to the docs

* Apply suggestions from code review

  Co-authored-by: Steven Liu <[email protected]>

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: hlky <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
1 parent 723dbdd commit fb54499

File tree

6 files changed: +1274 -23 lines changed

docs/source/en/using-diffusers/loading_adapters.md

+53
@@ -194,6 +194,59 @@ Currently, [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] only support

</Tip>

### Hotswapping LoRA adapters
A common use case when serving multiple adapters is to load one adapter first, generate images, load another adapter, generate more images, load another adapter, and so on. This workflow normally requires calling [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`], and possibly [`~loaders.peft.PeftAdapterMixin.delete_adapters`] to save memory. Moreover, if the model is compiled with `torch.compile`, performing these steps triggers recompilation, which takes time.

To better support this common workflow, you can "hotswap" a LoRA adapter, which avoids accumulating memory and, in some cases, recompilation. It requires an adapter to already be loaded; the new adapter weights are then swapped in-place for the existing adapter.

Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to indicate the name of the existing adapter to be swapped (`default_0` is the default adapter name). If you loaded the first adapter under a different name, use that name instead.
```python
pipeline = ...
# load adapter 1 as normal
pipeline.load_lora_weights(file_name_adapter_1)
# generate some images with adapter 1
...
# now hotswap the 2nd adapter
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
# generate images with adapter 2
```
<Tip warning={true}>

Hotswapping is not currently supported for LoRA adapters that target the text encoder.

</Tip>
For compiled models, it is often necessary to call [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] to avoid recompilation (it may be unnecessary if the second adapter targets identical LoRA ranks and scales). Call [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] _before_ loading the first adapter, and call `torch.compile` _after_ loading the first adapter.
```python
pipeline = ...
# call this extra method
pipeline.enable_lora_hotswap(target_rank=max_rank)
# now load adapter 1
pipeline.load_lora_weights(file_name_adapter_1)
# now compile the unet of the pipeline
pipeline.unet = torch.compile(pipeline.unet, ...)
# generate some images with adapter 1
...
# now hotswap adapter 2
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
# generate images with adapter 2
```
The `target_rank=max_rank` argument sets the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. If in doubt, use a higher value; the default is 128.
However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.

<Tip>

Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue with [Diffusers](https://github.com/huggingface/diffusers/issues) with a reproducible example.

</Tip>
### Kohya and TheLastBen
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.

src/diffusers/loaders/lora_base.py

+25

```diff
@@ -316,6 +316,7 @@ def _load_lora_into_text_encoder(
     adapter_name=None,
     _pipeline=None,
     low_cpu_mem_usage=False,
+    hotswap: bool = False,
 ):
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")
@@ -341,6 +342,10 @@ def _load_lora_into_text_encoder(
     # their prefixes.
     prefix = text_encoder_name if prefix is None else prefix
 
+    # Safe prefix to check with.
+    if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
+        raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
+
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
         state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
@@ -908,3 +913,23 @@ def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
         # if _lora_scale has not been set, return 1
         return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+    def enable_lora_hotswap(self, **kwargs) -> None:
+        """Enables the possibility to hotswap LoRA adapters.
+
+        Calling this method is only required when hotswapping adapters and if the model is compiled or if the
+        ranks of the loaded adapters differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle the case when the model is already compiled, which should generally be avoided. The
+                options are:
+                - "error" (default): raise an error
+                - "warn": issue a warning
+                - "ignore": do nothing
+        """
+        for key, component in self.components.items():
+            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+                component.enable_lora_hotswap(**kwargs)
```
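The `enable_lora_hotswap` method added above simply forwards `**kwargs` to every LoRA-loadable component that supports hotswap preparation. A minimal, self-contained sketch of that dispatch pattern, using hypothetical stand-in classes rather than the actual diffusers types:

```python
class UNetStub:
    """Stand-in for a component that supports hotswap preparation."""
    def __init__(self):
        self.hotswap_kwargs = None

    def enable_lora_hotswap(self, **kwargs):
        self.hotswap_kwargs = kwargs  # record the forwarded arguments


class TextEncoderStub:
    """Stand-in for a component without hotswap support."""
    pass


class PipelineStub:
    _lora_loadable_modules = ["unet"]

    def __init__(self):
        self.components = {"unet": UNetStub(), "text_encoder": TextEncoderStub()}

    def enable_lora_hotswap(self, **kwargs):
        # Forward only to components that both support hotswap
        # and are listed as LoRA-loadable, mirroring the diff above.
        for key, component in self.components.items():
            if hasattr(component, "enable_lora_hotswap") and key in self._lora_loadable_modules:
                component.enable_lora_hotswap(**kwargs)


pipe = PipelineStub()
pipe.enable_lora_hotswap(target_rank=16)
```

The double check (`hasattr` plus membership in `_lora_loadable_modules`) means a component such as the text encoder is silently skipped, consistent with text encoder hotswapping not being supported yet.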
