Commit 74fd735

Add draft for lora text encoder scale (#3626)
* Add draft for lora text encoder scale
* Improve naming
* fix: training dreambooth lora script.
* Apply suggestions from code review
* Update examples/dreambooth/train_dreambooth_lora.py
* Apply suggestions from code review
* Apply suggestions from code review
* add lora mixin when fit
* add lora mixin when fit
* add lora mixin when fit
* fix more
* fix more

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 2de9e2d commit 74fd735

29 files changed, +406 -78 lines changed

CONTRIBUTING.md

+4-4
Original file line number | Diff line number | Diff line change
@@ -125,14 +125,14 @@ Awesome! Tell us what problem it solved for you.
125125

126126
You can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=).
127127

128-
#### 2.3 Feedback.
128+
#### 2.3 Feedback.
129129

130130
Feedback about the library design and why it is good or not good helps the core maintainers immensely to build a user-friendly library. To understand the philosophy behind the current design philosophy, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too much, hence restricting use cases, explain why and how it should be changed.
131131
If a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions.
132132

133133
You can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
134134

135-
#### 2.4 Technical questions.
135+
#### 2.4 Technical questions.
136136

137137
Technical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. Please make sure to link to the code in question and please provide detail on
138138
why this part of the code is difficult to understand.
@@ -394,8 +394,8 @@ passes. You should run the tests impacted by your changes like this:
394394
```bash
395395
$ pytest tests/<TEST_TO_RUN>.py
396396
```
397-
398-
Before you run the tests, please make sure you install the dependencies required for testing. You can do so
397+
398+
Before you run the tests, please make sure you install the dependencies required for testing. You can do so
399399
with this command:
400400

401401
```bash

PHILOSOPHY.md

+10-10
@@ -27,18 +27,18 @@ In a nutshell, Diffusers is built to be a natural extension of PyTorch. Therefor
2727

2828
## Simple over easy
2929

30-
As PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library:
30+
As PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library:
3131
- We follow PyTorch's API with methods like [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to) to let the user handle device management.
3232
- Raising concise error messages is preferred to silently correct erroneous input. Diffusers aims at teaching the user, rather than making the library as easy to use as possible.
3333
- Complex model vs. scheduler logic is exposed instead of magically handled inside. Schedulers/Samplers are separated from diffusion models with minimal dependencies on each other. This forces the user to write the unrolled denoising loop. However, the separation allows for easier debugging and gives the user more control over adapting the denoising process or switching out diffusion models or schedulers.
34-
- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the unet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. Dreambooth or textual inversion training
34+
- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the unet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. Dreambooth or textual inversion training
3535
is very simple thanks to diffusers' ability to separate single components of the diffusion pipeline.
3636

3737
## Tweakable, contributor-friendly over abstraction
3838

39-
For large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself).
39+
For large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself).
4040
In short, just like Transformers does for modeling files, diffusers prefers to keep an extremely low level of abstraction and very self-contained code for pipelines and schedulers.
41-
Functions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable.
41+
Functions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable.
4242
**However**, this design has proven to be extremely successful for Transformers and makes a lot of sense for community-driven, open-source machine learning libraries because:
4343
- Machine Learning is an extremely fast-moving field in which paradigms, model architectures, and algorithms are changing rapidly, which therefore makes it very difficult to define long-lasting code abstractions.
4444
- Machine Learning practitioners like to be able to quickly tweak existing code for ideation and research and therefore prefer self-contained code over one that contains many abstractions.
@@ -47,10 +47,10 @@ Functions, long code blocks, and even classes can be copied across multiple file
4747
At Hugging Face, we call this design the **single-file policy** which means that almost all of the code of a certain class should be written in a single, self-contained file. To read more about the philosophy, you can have a look
4848
at [this blog post](https://huggingface.co/blog/transformers-design-philosophy).
4949

50-
In diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such
50+
In diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such
5151
as [DDPM](https://huggingface.co/docs/diffusers/v0.12.0/en/api/pipelines/ddpm), [Stable Diffusion](https://huggingface.co/docs/diffusers/v0.12.0/en/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines), [UnCLIP (Dalle-2)](https://huggingface.co/docs/diffusers/v0.12.0/en/api/pipelines/unclip#overview) and [Imagen](https://imagen.research.google/) all rely on the same diffusion model, the [UNet](https://huggingface.co/docs/diffusers/api/models#diffusers.UNet2DConditionModel).
5252

53-
Great, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗.
53+
Great, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗.
5454
We try to apply these design principles consistently across the library. Nevertheless, there are some minor exceptions to the philosophy or some unlucky design choices. If you have feedback regarding the design, we would ❤️ to hear it [directly on GitHub](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
5555

5656
## Design Philosophy in Details
@@ -89,17 +89,17 @@ The following design principles are followed:
8989
- Models should by default have the highest precision and lowest performance setting.
9090
- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different.
9191
- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and "foreseeing" future changes, *e.g.* it is usually better to add `string` "...type" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work.
92-
- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and
92+
- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and
9393
readable longterm, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
9494

9595
### Schedulers
9696

9797
Schedulers are responsible to guide the denoising process for inference as well as to define a noise schedule for training. They are designed as individual classes with loadable configuration files and strongly follow the **single-file policy**.
9898

9999
The following design principles are followed:
100-
- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
101-
- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.
102-
- One scheduler python file corresponds to one scheduler algorithm (as might be defined in a paper).
100+
- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
101+
- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.
102+
- One scheduler python file corresponds to one scheduler algorithm (as might be defined in a paper).
103103
- If schedulers share similar functionalities, we can make use of the `#Copied from` mechanism.
104104
- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`.
105105
- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](./using-diffusers/schedulers.mdx).

README.md

+9-9
@@ -30,7 +30,7 @@ We recommend installing 🤗 Diffusers in a virtual environment from PyPi or Con
3030
### PyTorch
3131

3232
With `pip` (official package):
33-
33+
3434
```bash
3535
pip install --upgrade diffusers[torch]
3636
```
@@ -107,7 +107,7 @@ Check out the [Quickstart](https://huggingface.co/docs/diffusers/quicktour) to l
107107
| [Training](https://huggingface.co/docs/diffusers/training/overview) | Guides for how to train a diffusion model for different tasks with different training techniques. |
108108
## Contribution
109109

110-
We ❤️ contributions from the open-source community!
110+
We ❤️ contributions from the open-source community!
111111
If you want to contribute to this library, please check out our [Contribution guide](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md).
112112
You can look out for [issues](https://github.com/huggingface/diffusers/issues) you'd like to tackle to contribute to the library.
113113
- See [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) for general opportunities to contribute
@@ -128,7 +128,7 @@ just hang out ☕.
128128
</tr>
129129
<tr style="border-top: 2px solid black">
130130
<td>Unconditional Image Generation</td>
131-
<td><a href="https://huggingface.co/docs/diffusers/api/pipelines/ddpm"> DDPM </a></td>
131+
<td><a href="https://huggingface.co/docs/diffusers/api/pipelines/ddpm"> DDPM </a></td>
132132
<td><a href="https://huggingface.co/google/ddpm-ema-church-256"> google/ddpm-ema-church-256 </a></td>
133133
</tr>
134134
<tr style="border-top: 2px solid black">
@@ -185,13 +185,13 @@ just hang out ☕.
185185

186186
## Popular libraries using 🧨 Diffusers
187187

188-
- https://github.com/microsoft/TaskMatrix
189-
- https://github.com/invoke-ai/InvokeAI
190-
- https://github.com/apple/ml-stable-diffusion
191-
- https://github.com/Sanster/lama-cleaner
188+
- https://github.com/microsoft/TaskMatrix
189+
- https://github.com/invoke-ai/InvokeAI
190+
- https://github.com/apple/ml-stable-diffusion
191+
- https://github.com/Sanster/lama-cleaner
192192
- https://github.com/IDEA-Research/Grounded-Segment-Anything
193-
- https://github.com/ashawkey/stable-dreamfusion
194-
- https://github.com/deep-floyd/IF
193+
- https://github.com/ashawkey/stable-dreamfusion
194+
- https://github.com/deep-floyd/IF
195195
- https://github.com/bentoml/BentoML
196196
- https://github.com/bmaltais/kohya_ss
197197
- +3000 other amazing GitHub repositories 💪

docs/source/_config.py

+1-1
@@ -6,4 +6,4 @@
66
# ! pip install git+https://github.com/huggingface/diffusers.git
77
"""
88

9-
notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]
9+
notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]

docs/source/en/training/lora.mdx

+8
@@ -260,6 +260,14 @@ pipe.load_lora_weights(lora_model_id)
260260
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
261261
```
262262

263+
<Tip>
264+
265+
If your LoRA parameters involve the UNet as well as the Text Encoder, then passing
266+
`cross_attention_kwargs={"scale": 0.5}` will apply the `scale` value to both the UNet
267+
and the Text Encoder.
268+
269+
</Tip>
270+
263271
Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is preferred to [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`] for loading LoRA parameters. This is because
264272
[`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] can handle the following situations:
265273

src/diffusers/loaders.py

+11-1
Original file line numberDiff line numberDiff line change

@@ -852,6 +852,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
852852
weight_name = kwargs.pop("weight_name", None)
853853
use_safetensors = kwargs.pop("use_safetensors", None)
854854

855+
# set lora scale to a reasonable default
856+
self._lora_scale = 1.0
857+
855858
if use_safetensors and not is_safetensors_available():
856859
raise ValueError(
857860
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
@@ -953,6 +956,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
953956
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
954957
warnings.warn(warn_message)
955958

959+
@property
960+
def lora_scale(self) -> float:
961+
# property function that returns the lora scale which can be set at run time by the pipeline.
962+
# if _lora_scale has not been set, return 1
963+
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
964+
956965
@property
957966
def text_encoder_lora_attn_procs(self):
958967
if hasattr(self, "_text_encoder_lora_attn_procs"):
@@ -1000,7 +1009,8 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
10001009
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
10011010
def make_new_forward(old_forward, lora_layer):
10021011
def new_forward(x):
1003-
return old_forward(x) + lora_layer(x)
1012+
result = old_forward(x) + self.lora_scale * lora_layer(x)
1013+
return result
10041014

10051015
return new_forward
10061016
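
Note on the `loaders.py` change: the patched forward reads `self.lora_scale` each time it is called rather than capturing a value when the patch is applied, so the scale can be changed between calls without re-patching. The sketch below illustrates that behavior in plain Python under stated assumptions: `TinyLoraHost` and the lambda "layers" are hypothetical stand-ins, not the library's classes.

```python
class TinyLoraHost:
    """Hypothetical stand-in for an object that mixes in LoRA loading."""

    @property
    def lora_scale(self) -> float:
        # Fall back to 1.0 when no scale has been set yet, as in the diff.
        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0

    def make_new_forward(self, old_forward, lora_layer):
        def new_forward(x):
            # The scale is looked up at call time, not at patch time.
            return old_forward(x) + self.lora_scale * lora_layer(x)

        return new_forward


host = TinyLoraHost()
# Toy "layers": the base forward doubles its input, the LoRA branch halves it.
forward = host.make_new_forward(lambda x: 2.0 * x, lambda x: 0.5 * x)

full = forward(4.0)      # no scale set, so the LoRA branch is added unscaled
host._lora_scale = 0.5
half = forward(4.0)      # same patched function, LoRA branch now halved
host._lora_scale = 0.0
off = forward(4.0)       # LoRA branch contributes nothing
```

A scale of 0.0 in this sketch recovers the base output, and 1.0 matches the pre-commit behavior of `old_forward(x) + lora_layer(x)`.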

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

+14-2
@@ -24,7 +24,7 @@
2424

2525
from ...configuration_utils import FrozenDict
2626
from ...image_processor import VaeImageProcessor
27-
from ...loaders import TextualInversionLoaderMixin
27+
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
2828
from ...models import AutoencoderKL, UNet2DConditionModel
2929
from ...schedulers import KarrasDiffusionSchedulers
3030
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
@@ -52,7 +52,7 @@
5252

5353

5454
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
55-
class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
55+
class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
5656
r"""
5757
Pipeline for text-to-image generation using Alt Diffusion.
5858
@@ -291,6 +291,7 @@ def _encode_prompt(
291291
negative_prompt=None,
292292
prompt_embeds: Optional[torch.FloatTensor] = None,
293293
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
294+
lora_scale: Optional[float] = None,
294295
):
295296
r"""
296297
Encodes the prompt into text encoder hidden states.
@@ -315,7 +316,14 @@ def _encode_prompt(
315316
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
316317
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
317318
argument.
319+
lora_scale (`float`, *optional*):
320+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
318321
"""
322+
# set lora scale so that monkey patched LoRA
323+
# function of text encoder can correctly access it
324+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
325+
self._lora_scale = lora_scale
326+
319327
if prompt is not None and isinstance(prompt, str):
320328
batch_size = 1
321329
elif prompt is not None and isinstance(prompt, list):
@@ -653,6 +661,9 @@ def __call__(
653661
do_classifier_free_guidance = guidance_scale > 1.0
654662

655663
# 3. Encode input prompt
664+
text_encoder_lora_scale = (
665+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
666+
)
656667
prompt_embeds = self._encode_prompt(
657668
prompt,
658669
device,
@@ -661,6 +672,7 @@ def __call__(
661672
negative_prompt,
662673
prompt_embeds=prompt_embeds,
663674
negative_prompt_embeds=negative_prompt_embeds,
675+
lora_scale=text_encoder_lora_scale,
664676
)
665677

666678
# 4. Prepare timesteps
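
Note on the pipeline change: `__call__` pulls `scale` out of `cross_attention_kwargs` and forwards it to `_encode_prompt`, which stores it only when it is not `None`. The sketch below mirrors that flow in plain Python. `text_encoder_lora_scale` and `SketchPipeline` are hypothetical names, and the `isinstance(self, LoraLoaderMixin)` check from the diff is left out for brevity.

```python
from typing import Any, Dict, Optional


def text_encoder_lora_scale(cross_attention_kwargs: Optional[Dict[str, Any]]) -> Optional[float]:
    """Hypothetical helper mirroring the expression added to `__call__`."""
    return cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None


class SketchPipeline:
    """Hypothetical stand-in; the real code also checks for LoraLoaderMixin."""

    def encode_prompt(self, lora_scale: Optional[float] = None) -> float:
        # Only overwrite the stored scale when a value was actually passed.
        if lora_scale is not None:
            self._lora_scale = lora_scale
        # Value the patched text encoder forward would read (1.0 if never set).
        return getattr(self, "_lora_scale", 1.0)


pipe = SketchPipeline()
default_scale = pipe.encode_prompt(text_encoder_lora_scale(None))             # nothing set yet
explicit_scale = pipe.encode_prompt(text_encoder_lora_scale({"scale": 0.5}))  # scale passed
later_scale = pipe.encode_prompt(text_encoder_lora_scale({}))                 # no "scale" key
```

In this sketch the last call still sees 0.5, because a missing `scale` leaves the stored value untouched. In the diff, `load_lora_weights` resets `_lora_scale` to 1.0.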
