Skip to content

Commit 0ab63ff

Browse files
authored
Fix CPU Offloading Usage & Typos (huggingface#8230)
* Fix typos * Fix `pipe.enable_model_cpu_offload()` usage * Fix cpu offloading * Update numbers
1 parent db33af0 commit 0ab63ff

File tree

11 files changed

+56
-60
lines changed

11 files changed

+56
-60
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
7777

7878
## Quickstart
7979

80-
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 22000+ checkpoints):
80+
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 25.000+ checkpoints):
8181

8282
```python
8383
from diffusers import DiffusionPipeline
@@ -219,7 +219,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
219219
- https://github.com/deep-floyd/IF
220220
- https://github.com/bentoml/BentoML
221221
- https://github.com/bmaltais/kohya_ss
222-
- +9000 other amazing GitHub repositories 💪
222+
- +11.000 other amazing GitHub repositories 💪
223223

224224
Thank you for using us ❤️.
225225

docs/source/en/optimization/tgate.md

+12-12
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Before you begin, make sure you install T-GATE.
66

77
```bash
88
pip install tgate
9-
pip install -U pytorch diffusers transformers accelerate DeepCache
9+
pip install -U torch diffusers transformers accelerate DeepCache
1010
```
1111

1212

@@ -46,12 +46,12 @@ pipe = TgatePixArtLoader(
4646

4747
image = pipe.tgate(
4848
"An alpaca made of colorful building blocks, cyberpunk.",
49-
gate_step=gate_step,
49+
gate_step=gate_step,
5050
num_inference_steps=inference_step,
5151
).images[0]
5252
```
5353
</hfoption>
54-
<hfoption id="Stable Diffusion XL">
54+
<hfoption id="Stable Diffusion XL">
5555

5656
Accelerate `StableDiffusionXLPipeline` with T-GATE:
5757

@@ -78,9 +78,9 @@ pipe = TgateSDXLLoader(
7878
).to("cuda")
7979

8080
image = pipe.tgate(
81-
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
82-
gate_step=gate_step,
83-
num_inference_steps=inference_step
81+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
82+
gate_step=gate_step,
83+
num_inference_steps=inference_step
8484
).images[0]
8585
```
8686
</hfoption>
@@ -111,9 +111,9 @@ pipe = TgateSDXLDeepCacheLoader(
111111
).to("cuda")
112112

113113
image = pipe.tgate(
114-
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
115-
gate_step=gate_step,
116-
num_inference_steps=inference_step
114+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
115+
gate_step=gate_step,
116+
num_inference_steps=inference_step
117117
).images[0]
118118
```
119119
</hfoption>
@@ -151,9 +151,9 @@ pipe = TgateSDXLLoader(
151151
).to("cuda")
152152

153153
image = pipe.tgate(
154-
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
155-
gate_step=gate_step,
156-
num_inference_steps=inference_step
154+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
155+
gate_step=gate_step,
156+
num_inference_steps=inference_step
157157
).images[0]
158158
```
159159
</hfoption>

docs/source/en/using-diffusers/inference_with_tcd_lora.md

+25-25
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ image = pipe(
7878
prompt=prompt,
7979
num_inference_steps=4,
8080
guidance_scale=0,
81-
eta=0.3,
81+
eta=0.3,
8282
generator=torch.Generator(device=device).manual_seed(0),
8383
).images[0]
8484
```
@@ -156,22 +156,22 @@ image = pipe(
156156
prompt=prompt,
157157
num_inference_steps=8,
158158
guidance_scale=0,
159-
eta=0.3,
159+
eta=0.3,
160160
generator=torch.Generator(device=device).manual_seed(0),
161161
).images[0]
162162
```
163163

164164
![](https://github.com/jabir-zheng/TCD/raw/main/assets/animagine_xl.png)
165165

166-
TCD-LoRA also supports other LoRAs trained on different styles. For example, let's load the [TheLastBen/Papercut_SDXL](https://huggingface.co/TheLastBen/Papercut_SDXL) LoRA and fuse it with the TCD-LoRA with the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method.
166+
TCD-LoRA also supports other LoRAs trained on different styles. For example, let's load the [TheLastBen/Papercut_SDXL](https://huggingface.co/TheLastBen/Papercut_SDXL) LoRA and fuse it with the TCD-LoRA with the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method.
167167

168168
> [!TIP]
169169
> Check out the [Merge LoRAs](merge_loras) guide to learn more about efficient merging methods.
170170
171171
```python
172172
import torch
173173
from diffusers import StableDiffusionXLPipeline
174-
from scheduling_tcd import TCDScheduler
174+
from scheduling_tcd import TCDScheduler
175175

176176
device = "cuda"
177177
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -191,7 +191,7 @@ image = pipe(
191191
prompt=prompt,
192192
num_inference_steps=4,
193193
guidance_scale=0,
194-
eta=0.3,
194+
eta=0.3,
195195
generator=torch.Generator(device=device).manual_seed(0),
196196
).images[0]
197197
```
@@ -215,7 +215,7 @@ from PIL import Image
215215
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
216216
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
217217
from diffusers.utils import load_image, make_image_grid
218-
from scheduling_tcd import TCDScheduler
218+
from scheduling_tcd import TCDScheduler
219219

220220
device = "cuda"
221221
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
@@ -249,13 +249,13 @@ controlnet = ControlNetModel.from_pretrained(
249249
controlnet_id,
250250
torch_dtype=torch.float16,
251251
variant="fp16",
252-
).to(device)
252+
)
253253
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
254254
base_model_id,
255255
controlnet=controlnet,
256256
torch_dtype=torch.float16,
257257
variant="fp16",
258-
).to(device)
258+
)
259259
pipe.enable_model_cpu_offload()
260260

261261
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
@@ -271,9 +271,9 @@ depth_image = get_depth_map(image)
271271
controlnet_conditioning_scale = 0.5 # recommended for good generalization
272272

273273
image = pipe(
274-
prompt,
275-
image=depth_image,
276-
num_inference_steps=4,
274+
prompt,
275+
image=depth_image,
276+
num_inference_steps=4,
277277
guidance_scale=0,
278278
eta=0.3,
279279
controlnet_conditioning_scale=controlnet_conditioning_scale,
@@ -290,7 +290,7 @@ grid_image = make_image_grid([depth_image, image], rows=1, cols=2)
290290
import torch
291291
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
292292
from diffusers.utils import load_image, make_image_grid
293-
from scheduling_tcd import TCDScheduler
293+
from scheduling_tcd import TCDScheduler
294294

295295
device = "cuda"
296296
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -301,13 +301,13 @@ controlnet = ControlNetModel.from_pretrained(
301301
controlnet_id,
302302
torch_dtype=torch.float16,
303303
variant="fp16",
304-
).to(device)
304+
)
305305
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
306306
base_model_id,
307307
controlnet=controlnet,
308308
torch_dtype=torch.float16,
309309
variant="fp16",
310-
).to(device)
310+
)
311311
pipe.enable_model_cpu_offload()
312312

313313
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
@@ -322,9 +322,9 @@ canny_image = load_image("https://huggingface.co/datasets/hf-internal-testing/di
322322
controlnet_conditioning_scale = 0.5 # recommended for good generalization
323323

324324
image = pipe(
325-
prompt,
326-
image=canny_image,
327-
num_inference_steps=4,
325+
prompt,
326+
image=canny_image,
327+
num_inference_steps=4,
328328
guidance_scale=0,
329329
eta=0.3,
330330
controlnet_conditioning_scale=controlnet_conditioning_scale,
@@ -336,7 +336,7 @@ grid_image = make_image_grid([canny_image, image], rows=1, cols=2)
336336
![](https://github.com/jabir-zheng/TCD/raw/main/assets/controlnet_canny_tcd.png)
337337

338338
<Tip>
339-
The inference parameters in this example might not work for all examples, so we recommend you to try different values for `num_inference_steps`, `guidance_scale`, `controlnet_conditioning_scale` and `cross_attention_kwargs` parameters and choose the best one.
339+
The inference parameters in this example might not work for all examples, so we recommend you to try different values for `num_inference_steps`, `guidance_scale`, `controlnet_conditioning_scale` and `cross_attention_kwargs` parameters and choose the best one.
340340
</Tip>
341341

342342
</hfoption>
@@ -350,7 +350,7 @@ from diffusers import StableDiffusionXLPipeline
350350
from diffusers.utils import load_image, make_image_grid
351351

352352
from ip_adapter import IPAdapterXL
353-
from scheduling_tcd import TCDScheduler
353+
from scheduling_tcd import TCDScheduler
354354

355355
device = "cuda"
356356
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -359,8 +359,8 @@ ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
359359
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
360360

361361
pipe = StableDiffusionXLPipeline.from_pretrained(
362-
base_model_path,
363-
torch_dtype=torch.float16,
362+
base_model_path,
363+
torch_dtype=torch.float16,
364364
variant="fp16"
365365
)
366366
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
@@ -375,13 +375,13 @@ ref_image = load_image("https://raw.githubusercontent.com/tencent-ailab/IP-Adapt
375375
prompt = "best quality, high quality, wearing sunglasses"
376376

377377
image = ip_model.generate(
378-
pil_image=ref_image,
378+
pil_image=ref_image,
379379
prompt=prompt,
380380
scale=0.5,
381-
num_samples=1,
382-
num_inference_steps=4,
381+
num_samples=1,
382+
num_inference_steps=4,
383383
guidance_scale=0,
384-
eta=0.3,
384+
eta=0.3,
385385
seed=0,
386386
)[0]
387387

docs/source/en/using-diffusers/inpaint.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ from diffusers.utils import load_image, make_image_grid
230230

231231
pipeline = AutoPipelineForInpainting.from_pretrained(
232232
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
233-
).to("cuda")
233+
)
234234
pipeline.enable_model_cpu_offload()
235235
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
236236
pipeline.enable_xformers_memory_efficient_attention()
@@ -255,7 +255,7 @@ from diffusers.utils import load_image, make_image_grid
255255

256256
pipeline = AutoPipelineForInpainting.from_pretrained(
257257
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
258-
).to("cuda")
258+
)
259259
pipeline.enable_model_cpu_offload()
260260
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
261261
pipeline.enable_xformers_memory_efficient_attention()
@@ -296,7 +296,7 @@ from diffusers.utils import load_image, make_image_grid
296296

297297
pipeline = AutoPipelineForInpainting.from_pretrained(
298298
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
299-
).to("cuda")
299+
)
300300
pipeline.enable_model_cpu_offload()
301301
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
302302
pipeline.enable_xformers_memory_efficient_attention()
@@ -319,7 +319,7 @@ from diffusers.utils import load_image, make_image_grid
319319

320320
pipeline = AutoPipelineForInpainting.from_pretrained(
321321
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
322-
).to("cuda")
322+
)
323323
pipeline.enable_model_cpu_offload()
324324
# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
325325
pipeline.enable_xformers_memory_efficient_attention()

examples/community/README.md

+10-10
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,12 @@ pipeline_output = pipe(
240240
# denoising_steps=10, # (optional) Number of denoising steps of each inference pass. Default: 10.
241241
# ensemble_size=10, # (optional) Number of inference passes in the ensemble. Default: 10.
242242
# ------------------------------------------------
243-
243+
244244
# ----- recommended setting for LCM version ------
245245
# denoising_steps=4,
246246
# ensemble_size=5,
247247
# -------------------------------------------------
248-
248+
249249
# processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
250250
# match_input_res=True, # (optional) Resize depth prediction to match input resolution.
251251
# batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
@@ -1032,7 +1032,7 @@ image = pipe().images[0]
10321032

10331033
Make sure you have @crowsonkb's <https://github.com/crowsonkb/k-diffusion> installed:
10341034

1035-
```
1035+
```sh
10361036
pip install k-diffusion
10371037
```
10381038

@@ -1854,13 +1854,13 @@ To use this pipeline, you need to:
18541854

18551855
You can simply use pip to install IPEX with the latest version.
18561856

1857-
```python
1857+
```sh
18581858
python -m pip install intel_extension_for_pytorch
18591859
```
18601860

18611861
**Note:** To install a specific version, run with the following command:
18621862

1863-
```
1863+
```sh
18641864
python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
18651865
```
18661866

@@ -1958,13 +1958,13 @@ To use this pipeline, you need to:
19581958

19591959
You can simply use pip to install IPEX with the latest version.
19601960

1961-
```python
1961+
```sh
19621962
python -m pip install intel_extension_for_pytorch
19631963
```
19641964

19651965
**Note:** To install a specific version, run with the following command:
19661966

1967-
```
1967+
```sh
19681968
python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
19691969
```
19701970

@@ -3010,8 +3010,8 @@ This code implements a pipeline for the Stable Diffusion model, enabling the div
30103010

30113011
### Sample Code
30123012

3013-
```
3014-
from from examples.community.regional_prompting_stable_diffusion import RegionalPromptingStableDiffusionPipeline
3013+
```py
3014+
from examples.community.regional_prompting_stable_diffusion import RegionalPromptingStableDiffusionPipeline
30153015
pipe = RegionalPromptingStableDiffusionPipeline.from_single_file(model_path, vae=vae)
30163016

30173017
rp_args = {
@@ -4131,7 +4131,7 @@ This implementation is based on [Diffusers](https://huggingface.co/docs/diffuser
41314131

41324132
## Example Usage
41334133

4134-
```
4134+
```py
41354135
import os
41364136
import torch
41374137

examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py

-1
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,6 @@ def collate_fn(examples):
896896
images = []
897897
if args.validation_prompts is not None:
898898
logger.info("Running inference for collecting generated images...")
899-
pipeline = pipeline.to(accelerator.device)
900899
pipeline.torch_dtype = weight_dtype
901900
pipeline.set_progress_bar_config(disable=True)
902901
pipeline.enable_model_cpu_offload()

tests/lora/test_lora_layers_sd.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ def test_sd_load_civitai_empty_network_alpha(self):
642642
This test simply checks that loading a LoRA with an empty network alpha works fine
643643
See: https://github.com/huggingface/diffusers/issues/5606
644644
"""
645-
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(torch_device)
645+
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
646646
pipeline.enable_sequential_cpu_offload()
647647
civitai_path = hf_hub_download("ybelkada/test-ahi-civitai", "ahi_lora_weights.safetensors")
648648
pipeline.load_lora_weights(civitai_path, adapter_name="ahri")

tests/pipelines/i2vgen_xl/test_i2vgenxl.py

-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,6 @@ def tearDown(self):
243243

244244
def test_i2vgen_xl(self):
245245
pipe = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
246-
pipe = pipe.to(torch_device)
247246
pipe.enable_model_cpu_offload()
248247
pipe.set_progress_bar_config(disable=None)
249248
image = load_image(

tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -612,10 +612,10 @@ def test_ip_adapter_multiple_masks(self):
612612
def test_instant_style_multiple_masks(self):
613613
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
614614
"h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
615-
).to("cuda")
615+
)
616616
pipeline = StableDiffusionXLPipeline.from_pretrained(
617617
"RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, image_encoder=image_encoder, variant="fp16"
618-
).to("cuda")
618+
)
619619
pipeline.enable_model_cpu_offload()
620620

621621
pipeline.load_ip_adapter(

0 commit comments

Comments
 (0)