
Commit 7037133

update doc page

1 parent 64e6c9c

1 file changed: +33, -0 lines

docs/source/en/optimization/memory.md

Lines changed: 33 additions & 0 deletions
@@ -158,6 +158,39 @@ In order to properly offload models after they're called, it is required to run
</Tip>

## FP8 layerwise weight-casting

PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes. Many tensor operations cannot compute directly in these precisions because the required kernels are not implemented, but the dtypes can still be used to store model weights in lower FP8 precision. For computation, the weights are upcast on the fly as each layer is invoked during the forward pass.

Inference on most models is typically run with `torch.float16` or `torch.bfloat16` weight and computation precision. Applying layerwise weight-casting, by storing the weights in FP8 precision, approximately halves the memory footprint of the model weights.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Store the transformer weights in FP8 and upcast them to bfloat16 on the fly during the forward pass.
pipe.transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```

In the example above, layerwise upcasting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from FP8 weight casting because casting them can significantly degrade generation quality. For most cases, skipping the normalization- and modulation-related weight parameters is a good choice, and it is also the default. For more control and flexibility, use the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function directly instead of [`~ModelMixin.enable_layerwise_upcasting`].
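
For reference, a rough sketch of calling the hook-level function directly on the transformer might look like the following. The `storage_dtype` and `compute_dtype` arguments mirror the pipeline example above; the `skip_modules_pattern` argument and the `("norm",)` pattern are assumptions used only to illustrate keeping normalization-related layers in their original precision, so consult the API reference below for the exact signature.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks.layerwise_upcasting import apply_layerwise_upcasting

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Apply FP8 weight storage with bfloat16 compute directly to the transformer,
# keeping modules whose names match "norm" in their original precision.
# `skip_modules_pattern` is an assumed argument name; see the API reference below.
apply_layerwise_upcasting(
    pipe.transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=("norm",),
)
```
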
[[autodoc]] ModelMixin.enable_layerwise_upcasting

[[autodoc]] hooks.layerwise_upcasting.apply_layerwise_upcasting

## Channels-last memory format

The channels-last memory format is an alternative way of ordering NCHW tensors in memory that preserves dimension ordering. Channels-last tensors are ordered so that the channels become the densest dimension (storing images pixel by pixel). Since not all operators currently support the channels-last format, it may result in worse performance, but you should still try it to see whether it works for your model.
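
As a rough sketch of how to try it, the snippet below switches a pipeline's UNet to channels-last and checks a convolution weight's strides; the Stable Diffusion checkpoint and the `conv_out` attribute are illustrative assumptions rather than part of this page.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.to("cuda")

# Reorder the UNet weights to channels-last in place; this changes only the memory layout.
pipe.unet.to(memory_format=torch.channels_last)

# A stride of 1 in the channel dimension of a conv weight indicates the new layout took effect.
print(pipe.unet.conv_out.state_dict()["weight"].stride())
```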
