diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md
index a2150f9aa0b7..e64db46e2769 100644
--- a/docs/source/en/optimization/memory.md
+++ b/docs/source/en/optimization/memory.md
@@ -158,6 +158,39 @@ In order to properly offload models after they're called, it is required to run
+## FP8 layerwise weight-casting
+
+PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes. Many tensor operations cannot be computed directly in these dtypes because kernel support is not yet implemented, but the dtypes can still be used to store model weights in lower FP8 precision. For computation, the weights are upcast on the fly as each layer is invoked in the forward pass.
+
+Inference on most models is typically done with `torch.float16` or `torch.bfloat16` weight and computation precision. Applying layerwise weight-casting, which stores the weights in FP8 precision, roughly halves the memory footprint of the model weights.
+
+```python
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.utils import export_to_video
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+pipe.transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+
+prompt = (
+    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+    "atmosphere of this unique musical performance."
+)
+video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+export_to_video(video, "output.mp4", fps=8)
+```
+
+The example above enables layerwise upcasting on the transformer component of the pipeline. By default, certain layers are excluded from FP8 weight casting because casting them can significantly degrade generation quality. In most cases, skipping the normalization- and modulation-related weight parameters is a good choice, and this is also the default. For more control and flexibility, use the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function directly instead of [`~ModelMixin.enable_layerwise_upcasting`].
+
+[[autodoc]] ModelMixin.enable_layerwise_upcasting
+
+[[autodoc]] hooks.layerwise_upcasting.apply_layerwise_upcasting
+
 ## Channels-last memory format
 
 The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worse performance, but you should still try it and see if it works for your model.
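
For the finer-grained control the added section mentions, a minimal sketch of calling `apply_layerwise_upcasting` directly is shown below. It assumes the function accepts the target module together with `storage_dtype`, `compute_dtype`, and a `skip_modules_pattern` argument of name patterns to leave in higher precision; these keyword names are assumptions here, so check the autodoc'd API reference in the diff for the exact signature.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks.layerwise_upcasting import apply_layerwise_upcasting

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Store the transformer weights in FP8, but keep the patch embedding and
# normalization layers in bfloat16 to protect generation quality.
# `skip_modules_pattern` is an assumed keyword; consult the API reference above.
apply_layerwise_upcasting(
    pipe.transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=("patch_embed", "norm"),
)
```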