
Commit 7956c36

yiyixuxu, sayakpaul, and DN6 authored
add a from_pipe method to DiffusionPipeline (#7241)
* add from_pipe

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
1 parent 5266ab7 commit 7956c36

File tree

22 files changed: +675 -63 lines changed


docs/source/en/using-diffusers/loading.md

Lines changed: 204 additions & 0 deletions
@@ -179,6 +179,210 @@ stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(
)
```

### Switch loaded pipelines

Many Diffusers pipelines use the same pretrained model as [`StableDiffusionPipeline`] and [`StableDiffusionXLPipeline`], but they implement specific features to help you achieve better generation results. This guide shows you how to use the `from_pipe` API to create multiple pipelines without increasing memory usage, so you can easily switch between them to use different features.

Let's take an example where we first create a [`StableDiffusionPipeline`] and then reuse the already loaded model components to create a [`StableDiffusionSAGPipeline`] to enhance generation quality.

We will generate an image of a bear eating pizza using Stable Diffusion with the IP-Adapter.

```python
import torch
from diffusers import DiffusionPipeline, StableDiffusionSAGPipeline
from diffusers.utils import load_image

base_repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
num_inference_steps = 50
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
prompt = "bear eats pizza"
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"

pipe_sd = DiffusionPipeline.from_pretrained(base_repo, torch_dtype=torch.float16)
pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_sd.set_ip_adapter_scale(0.6)
pipe_sd.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
```

Let's take a look at the image and also print out the memory used.

<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sd_0.png"/>
</div>

```python
def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024

print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
```

```bash
Max memory allocated: 4.406213283538818 GB
```

Now, we can use `from_pipe` to switch to the SAG pipeline.

```python
pipe_sag = StableDiffusionSAGPipeline.from_pipe(
    pipe_sd,
)
```

The IP-Adapter is already loaded on the new pipeline, so you can pass the same bear image as `ip_adapter_image`.

```python
generator = torch.Generator(device="cpu").manual_seed(33)
out_sag = pipe_sag(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
    guidance_scale=1.0,
    sag_scale=0.75,
).images[0]
```

You can see a pretty nice improvement in the output.

<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sag_1.png"/>
</div>

Now we have both `StableDiffusionPipeline` and `StableDiffusionSAGPipeline` co-existing with the same loaded model components, and you can use them interchangeably without additional memory usage.

```python
print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
```

```bash
Max memory allocated: 4.406213283538818 GB
```

Let's unload the IP-Adapter from the SAG pipeline. It's important to note that methods like `load_ip_adapter` and `unload_ip_adapter` modify the state of the model components, so using them on one pipeline affects every other pipeline that shares the same components.

```python
pipe_sag.unload_ip_adapter()
```

If you try to use the Stable Diffusion pipeline with the IP-Adapter again, it will fail.

```python
generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
```

```bash
AttributeError: 'NoneType' object has no attribute 'image_projection_layers'
```

Please note that pipeline methods may not function properly on a new pipeline created with `from_pipe`. For instance, `enable_model_cpu_offload` installs hooks on the model components based on a unique offloading sequence for each pipeline, so if the models are executed in a different order in the new pipeline, CPU offloading may not work correctly.

To ensure proper functionality, we recommend re-applying such pipeline methods on a pipeline created with `from_pipe`.
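
For example, a minimal sketch (reusing the `pipe_sd` pipeline from above): re-apply `enable_model_cpu_offload` on the new pipeline so the hooks are installed for its own offloading sequence.

```python
# A minimal sketch, assuming `pipe_sd` from above: re-apply CPU offloading on
# the pipeline created with `from_pipe` so the hooks match its own sequence.
pipe_sag = StableDiffusionSAGPipeline.from_pipe(pipe_sd)
pipe_sag.enable_model_cpu_offload()
```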

You can also add or remove model components when creating new pipelines. Let's now create an AnimateDiff pipeline with an additional `MotionAdapter` module.

```python
from diffusers import AnimateDiffPipeline, MotionAdapter, DDIMScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)

pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")
# load the IP-Adapter and LoRA weights again
pipe_animate.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
pipe_animate.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(33)
pipe_animate.set_adapters("zoom-out", adapter_weights=0.75)
out = pipe_animate(
    prompt=prompt,
    num_frames=16,
    num_inference_steps=num_inference_steps,
    ip_adapter_image=image,
    generator=generator,
).frames[0]
export_to_gif(out, "out_animate.gif")
```

<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_animate_3.gif"/>
</div>

When creating multiple pipelines with `from_pipe`, the memory requirement is determined by the pipeline with the highest memory usage, regardless of how many pipelines you create.

For example, we created three pipelines here, `StableDiffusionPipeline`, `StableDiffusionSAGPipeline`, and `AnimateDiffPipeline`, and since `AnimateDiffPipeline` has the highest memory requirement, the total memory usage is based on the memory requirement of `AnimateDiffPipeline` alone.

Creating additional pipelines therefore does not add to the total memory requirement; each pipeline can be used interchangeably without any extra memory overhead.
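
As a quick sanity check, you can reuse the `bytes_to_giga_bytes` helper from above; the peak reported by PyTorch reflects the largest pipeline rather than the sum of all three.

```python
# Quick check, reusing the helper defined earlier: the peak is dominated by
# the largest pipeline (AnimateDiff here), not the sum of all three pipelines.
print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
```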

Did you know that you can use `from_pipe` with a community pipeline? Let me show you an example using a long negative prompt and prompt weighting!

```python
pipe_lpw = DiffusionPipeline.from_pipe(
    pipe_sd,
    custom_pipeline="lpw_stable_diffusion",
).to("cuda")

prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
generator = torch.Generator(device="cpu").manual_seed(33)
out_lpw = pipe_lpw.text2img(
    prompt,
    negative_prompt=neg_prompt,
    width=512,
    height=512,
    max_embeddings_multiples=3,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
```

<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_lpw_4.png"/>
</div>

Let's run [`StableDiffusionPipeline`] with the same inputs to compare: the result from the long prompt weighting pipeline is more aligned with the text prompt.

```python
generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=neg_prompt,
    generator=generator,
    num_inference_steps=num_inference_steps,
).images[0]
out_sd
```

<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sd_5.png"/>
</div>

You can easily switch between different pipelines with the `from_pipe` method, similar to turning a feature of your pipeline on and off. To switch between tasks, use `from_pipe` together with `AutoPipeline`, which automatically identifies the pipeline class based on the task. You can find more information about this feature in the [AutoPipeline guide](https://huggingface.co/docs/diffusers/tutorials/autopipeline).
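
For instance, here is a brief sketch (assuming the `pipe_sd` pipeline from above) of switching to the image-to-image task with `AutoPipelineForImage2Image`:

```python
# A brief sketch, assuming `pipe_sd` from above: AutoPipeline picks the
# matching image-to-image pipeline class from the already loaded components.
from diffusers import AutoPipelineForImage2Image

pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_sd)
```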

## Checkpoint variants

A checkpoint variant is usually a checkpoint whose weights are:

examples/community/lpw_stable_diffusion.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -439,7 +439,9 @@ class StableDiffusionLongPromptWeightingPipeline(
         Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
 
+    model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
+    _exclude_from_cpu_offload = ["safety_checker"]
 
     def __init__(
         self,
```

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -17,7 +17,7 @@
 import torch.nn as nn
 import torch.utils.checkpoint
 
-from ...configuration_utils import ConfigMixin, register_to_config
+from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import UNet2DConditionLoadersMixin
 from ...utils import logging
 from ..attention_processor import (
@@ -393,8 +393,11 @@ def from_unet2d(
     ):
         has_motion_adapter = motion_adapter is not None
 
+        if has_motion_adapter:
+            motion_adapter.to(device=unet.device)
+
         # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
-        config = unet.config
+        config = dict(unet.config)
         config["_class_name"] = cls.__name__
 
         down_blocks = []
@@ -427,6 +430,7 @@ def from_unet2d(
         if not config.get("num_attention_heads"):
             config["num_attention_heads"] = config["attention_head_dim"]
 
+        config = FrozenDict(config)
         model = cls.from_config(config)
 
         if not load_weights:
```

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -131,7 +131,7 @@ def __init__(
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
         motion_adapter: MotionAdapter,
         scheduler: Union[
             DDIMScheduler,
```

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 36 additions & 18 deletions

```diff
@@ -292,6 +292,39 @@ def get_class_obj_and_candidates(
     return class_obj, class_candidates
 
 
+def _get_custom_pipeline_class(
+    custom_pipeline,
+    repo_id=None,
+    hub_revision=None,
+    class_name=None,
+    cache_dir=None,
+    revision=None,
+):
+    if custom_pipeline.endswith(".py"):
+        path = Path(custom_pipeline)
+        # decompose into folder & file
+        file_name = path.name
+        custom_pipeline = path.parent.absolute()
+    elif repo_id is not None:
+        file_name = f"{custom_pipeline}.py"
+        custom_pipeline = repo_id
+    else:
+        file_name = CUSTOM_PIPELINE_FILE_NAME
+
+    if repo_id is not None and hub_revision is not None:
+        # if we load the pipeline code from the Hub
+        # make sure to overwrite the `revision`
+        revision = hub_revision
+
+    return get_class_from_dynamic_module(
+        custom_pipeline,
+        module_file=file_name,
+        class_name=class_name,
+        cache_dir=cache_dir,
+        revision=revision,
+    )
+
+
 def _get_pipeline_class(
     class_obj,
     config=None,
@@ -304,25 +337,10 @@ def _get_pipeline_class(
     revision=None,
 ):
     if custom_pipeline is not None:
-        if custom_pipeline.endswith(".py"):
-            path = Path(custom_pipeline)
-            # decompose into folder & file
-            file_name = path.name
-            custom_pipeline = path.parent.absolute()
-        elif repo_id is not None:
-            file_name = f"{custom_pipeline}.py"
-            custom_pipeline = repo_id
-        else:
-            file_name = CUSTOM_PIPELINE_FILE_NAME
-
-        if repo_id is not None and hub_revision is not None:
-            # if we load the pipeline code from the Hub
-            # make sure to overwrite the `revision`
-            revision = hub_revision
-
-        return get_class_from_dynamic_module(
+        return _get_custom_pipeline_class(
             custom_pipeline,
-            module_file=file_name,
+            repo_id=repo_id,
+            hub_revision=hub_revision,
             class_name=class_name,
             cache_dir=cache_dir,
             revision=revision,
```