
Commit 390742b

a-r-r-o-w and stevhliu authored
Apply suggestions from code review
Co-authored-by: Steven Liu <[email protected]>
1 parent f1b46d6 · commit 390742b

File tree

1 file changed: +5 -3 lines changed


docs/source/en/optimization/memory.md

Lines changed: 5 additions & 3 deletions
````diff
@@ -160,9 +160,9 @@ In order to properly offload models after they're called, it is required to run
 
 ## FP8 layerwise weight-casting
 
-PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes. This precision cannot be used for performing computation directly for many different tensor operations due to unimplemented kernel support. However, one can still use these dtypes for storing model weights in lower FP8 precision. For computation, the weights can be upcasted on-the-fly as and when layers are invoked in the forward pass.
+PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
 
-Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Applying layerwise weight-casting, by storing the weights in FP8 precision, cuts down the memory footprint of the model weights by half approximately.
+Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half.
 
 ```python
 import torch
@@ -185,7 +185,9 @@ video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
 export_to_video(video, "output.mp4", fps=8)
 ```
 
-In the above example, we enable layerwise upcasting on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. For most cases, skipping the normalization and modulation related weight parameters is a good choice (which is also the default choice). However, more control and flexibility can be obtained by directly utilizing the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function instead of using [`~ModelMixin.enable_layerwise_upcasting`].
+In the above example, layerwise upcasting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
+
+However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function instead of [`~ModelMixin.enable_layerwise_upcasting`].
 
 [[autodoc]] ModelMixin.enable_layerwise_upcasting
 
````
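As a supplement to the paragraph added above about gaining more control through [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`], below is a minimal sketch of how that function could be called directly on a model component instead of going through [`~ModelMixin.enable_layerwise_upcasting`]. The model choice (`CogVideoXTransformer3DModel`), the keyword names `storage_dtype`, `compute_dtype`, and `skip_modules_pattern`, and the skip pattern itself are assumptions for illustration rather than confirmed API details; the exact signature lives in `diffusers.hooks.layerwise_upcasting`.

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks.layerwise_upcasting import apply_layerwise_upcasting

# Load the transformer in the compute precision (bfloat16 here).
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Store the weights in fp8 and upcast them to bfloat16 on-the-fly in the forward pass.
# The keyword names and the skip pattern are illustrative assumptions; skipping
# normalization-related modules mirrors the default behavior described above.
apply_layerwise_upcasting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["norm"],
)
```

A transformer prepared this way can then be passed to the pipeline and used exactly as in the `export_to_video` example shown in the diff.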
