docs/source/en/optimization/memory.md
33 lines changed: 33 additions & 0 deletions
@@ -158,6 +158,39 @@ In order to properly offload models after they're called, it is required to run
</Tip>
## FP8 layerwise weight-casting
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes. These precisions cannot be used directly for computation in many tensor operations because the required kernels are not implemented. However, they can still be used to store model weights in lower FP8 precision, and the weights can be upcast on-the-fly to a higher precision as each layer is invoked in the forward pass.
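At a lower level, the idea can be illustrated with a minimal PyTorch-only sketch (this is not the diffusers implementation, only the storage-versus-compute pattern it relies on):

```python
import torch

# Store the weight in FP8 to roughly halve its memory footprint relative to bfloat16.
weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
weight_fp8 = weight.to(torch.float8_e4m3fn)

x = torch.randn(1, 4096, dtype=torch.bfloat16)

# Most operations have no FP8 kernels, so `x @ weight_fp8` would raise an error.
# Instead, upcast the weight on-the-fly to the compute dtype and run the op as usual.
y = x @ weight_fp8.to(torch.bfloat16)
```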
Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Applying layerwise weight-casting, by storing the weights in FP8 precision, cuts the memory footprint of the model weights roughly in half.
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
178
+
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
179
+
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
180
+
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
181
+
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
182
+
"atmosphere of this unique musical performance."
183
+
)
184
+
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
185
+
export_to_video(video, "output.mp4", fps=8)
186
+
```
In the above example, we enable layerwise upcasting on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because casting them can significantly degrade generation quality. In most cases, skipping the normalization- and modulation-related weight parameters is a good choice (and is the default). For more control and flexibility, use the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function directly instead of [`~ModelMixin.enable_layerwise_upcasting`], as sketched below.
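The sketch below shows what that lower-level usage could look like; the import path and the `skip_modules_pattern` argument are assumptions based on the function reference above, so consult the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] API documentation for the exact signature.

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks.layerwise_upcasting import apply_layerwise_upcasting  # assumed import path

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Store all weights in FP8 except modules whose names match the skip patterns
# (normalization and modulation layers are kept in the compute dtype).
apply_layerwise_upcasting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=("norm", "modulation"),  # assumed argument; adjust to the actual API
)
```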
## Channels-last memory format

The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered such that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worse performance, but you should still try it and see if it works for your model.
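For example, a pipeline's UNet can be converted to channels-last as in the minimal sketch below (the Stable Diffusion checkpoint is used only for illustration; any UNet-based pipeline works the same way):

```python
import torch
from diffusers import StableDiffusionPipeline

# Example checkpoint for illustration.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)

print(pipe.unet.conv_out.state_dict()["weight"].stride())  # contiguous NCHW, e.g. (2880, 9, 3, 1)
pipe.unet.to(memory_format=torch.channels_last)  # in-place operation
# A stride of 1 in the channel dimension confirms the weights are now channels-last.
print(pipe.unet.conv_out.state_dict()["weight"].stride())  # e.g. (2880, 1, 960, 320)
```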