[core] Layerwise Upcasting #10347
Conversation
Co-Authored-By: Dhruv Nair <[email protected]>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice start 👍🏽 A few things to consider here:

Maintaining a large global list of supported ops is difficult, and it can lead to us either missing modules or not applying upcasting in cases where it could be used.

So any kind of layerwise casting on these modules runs into an error, because the parameters remain in the lower-memory dtype unless the entire module is upcast. The initial PR got around this by adding the …

This implementation seems to do something similar using the global …
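For context, the hook-based idea being discussed boils down to something like the sketch below (a rough illustration, not the PR's actual implementation; `attach_layerwise_casting` is a hypothetical name):

```python
import torch
import torch.nn as nn

def attach_layerwise_casting(module: nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype):
    # Keep weights in the low-precision storage dtype while the layer is idle.
    module.to(storage_dtype)

    def pre_hook(mod, args):
        # Upcast just-in-time, right before this layer computes.
        mod.to(compute_dtype)

    def post_hook(mod, args, output):
        # Downcast back to the storage dtype once computation is done.
        mod.to(storage_dtype)

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)
```

The PR's `enable_layerwise_upcasting`/`apply_layerwise_upcasting` entry points essentially layer granularity control and skip patterns on top of this basic mechanism.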
Layerwise upcasting of text encoders is possible by leveraging apply_layerwise_upcasting directly:
Code

```python
import gc

import torch
from diffusers import CogVideoXPipeline, apply_layerwise_upcasting
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from transformers.models.t5.modeling_t5 import T5DenseGatedActDense

set_verbosity_debug()


def main(apply_layerwise_upcasting_text_encoder: bool = False):
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    pipe.transformer.enable_layerwise_upcasting(
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        granularity="pytorch_layer",
        skip_modules_pattern=["patch_embed", "norm"],
    )

    if apply_layerwise_upcasting_text_encoder:
        for name, module in pipe.text_encoder.named_modules():
            if isinstance(module, T5DenseGatedActDense):
                module.forward = T5DenseGatedActDense_forward.__get__(module)
        pipe.text_encoder = apply_layerwise_upcasting(
            pipe.text_encoder,
            storage_dtype=torch.float8_e4m3fn,
            compute_dtype=torch.bfloat16,
            granularity="pytorch_layer",
            skip_modules_pattern=["norm"],
        )

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)

    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")


# If we don't overwrite the forward method of T5DenseGatedActDense, the original forward would downcast hidden_states
# to torch.float8_e4m3fn, which would cause an error.
# Line in question: https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/t5/modeling_t5.py#L292
def T5DenseGatedActDense_forward(self, hidden_states):
    hidden_gelu = self.act(self.wi_0(hidden_states))
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.wo(hidden_states)
    return hidden_states


main()

# Without text encoder fp8 layerwise upcasting:
# Model memory: 15.228 GB
# Inference memory: 30.010 GB
# With text encoder fp8 layerwise upcasting:
# Model memory: 10.915 GB
# Inference memory: 25.705 GB
```

When layerwise upcasting is not enabled in T5, the memory required is about 30 GB. When enabled, the memory usage is about 25.7 GB.
I'm not sure how to get around this easily. We could probably use a simple context manager in the hook implementations to disable any tensor casts in the internal model forwards, but it seems a bit too hacky to me :/ Open to suggestions.

```python
class DisableTensorTo:
    def __enter__(self):
        # Save the original method and monkey-patch torch.Tensor.to with a no-op.
        self.original_to = torch.Tensor.to

        def noop_to(self, *args, **kwargs):
            return self

        torch.Tensor.to = noop_to

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original method on exit.
        torch.Tensor.to = self.original_to
```
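Hypothetical usage, suppressing the internal casts around a single forward call (note that this would also swallow legitimate device transfers, which is part of why it feels hacky; `module` and `input_tensor` are placeholders):

```python
# Illustrative only: any hidden_states.to(...) inside the module becomes a no-op.
with DisableTensorTo():
    output = module(input_tensor)
```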
For the reasons mentioned above, we are unable to run LoRA inference either. You can overwrite the forward methods for one (to get it to work without too much thinking), but maybe the context manager solution for disabling tensor casts could handle it too. There are actually two problems that need to be dealt with for LoRA. Whether you load the LoRA weights before or after enabling layerwise upcasting, they will be loaded in the correct dtype, same as the transformer.
Code

```python
import gc
from typing import Any

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from peft.tuners.lora.layer import Linear as LoRALinear

set_verbosity_debug()


def main():
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    pipe.set_adapters(["cogvideox-lora"], [1.0])

    pipe.transformer.enable_layerwise_upcasting(
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        granularity="pytorch_layer",
        skip_modules_pattern=["patch_embed", "norm"],
    )

    # Post layerwise upcasting does load lora weights in torch.float8_e4m3fn but due to no hooks, it errors during inference
    # pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    # pipe.set_adapters(["cogvideox-lora"], [1.0])

    for name, parameter in pipe.transformer.named_parameters():
        if "lora" in name:
            assert parameter.dtype == torch.float8_e4m3fn

    LoRALinear.forward = LoRALinear_forward

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = "walgro1. The scene begins with a close-up of Gromit's face, his expressive eyes filling the frame. His brow furrows slightly, ears perked forward in concentration. The soft lighting highlights the subtle details of his fur, every strand catching the warm sunlight filtering in from a nearby window. His dark, round nose twitches ever so slightly, sensing something in the air, and his gaze darts to the side, following an unseen movement. The camera lingers on Gromit’s face, capturing the subtleties of his expression—a quirked eyebrow and a knowing look that suggests he’s piecing together something clever. His silent, thoughtful demeanor speaks volumes as he watches the scene unfold with quiet intensity. The background remains out of focus, drawing all attention to the sharp intelligence in his eyes and the slight tilt of his head. In the claymation style of Wallace and Gromit."

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)

    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")


def LoRALinear_forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif adapter_names is not None:
        result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
    elif self.merged:
        result = self.base_layer(x, *args, **kwargs)
    else:
        result = self.base_layer(x, *args, **kwargs)
        torch_result_dtype = result.dtype

        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            # x = x.to(lora_A.weight.dtype)

            if not self.use_dora[active_adapter]:
                result = result + lora_B(lora_A(dropout(x))) * scaling
            else:
                if isinstance(dropout, torch.nn.Identity) or not self.training:
                    base_result = result
                else:
                    x = dropout(x)
                    base_result = None

                result = result + self.lora_magnitude_vector[active_adapter](
                    x,
                    lora_A=lora_A,
                    lora_B=lora_B,
                    scaling=scaling,
                    base_layer=self.get_base_layer(),
                    base_result=base_result,
                )

        result = result.to(torch_result_dtype)

    return result


main()

# Model memory: 15.351 GB
# Inference memory: 30.134 GB
```

cogvideox-layerwise-lora.mp4
Just documenting for now because I don't know what approach we're going to take to deal with these kinds of problems yet. If we enable layerwise upcasting for something like HunyuanVideo, the prompt_embeds and prompt_attention_mask are moved to the fp8 dtype, and the latent input passed into the model is also fp8. This is because of lines in the pipeline that cast inputs to the transformer's dtype.
One solution would be to overwrite the …

Edit: Ah, so the reason why it worked in the benchmark script is because the …
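Sketching the failure mode and one possible guard (illustrative only, not the fix this PR settles on; `safe_cast` is a hypothetical helper):

```python
import torch

# Pattern found in pipelines: inputs are cast to the transformer's dtype.
# With layerwise upcasting active, that dtype reports the fp8 storage dtype,
# so prompt_embeds/latents end up in fp8 too:
#   prompt_embeds = prompt_embeds.to(transformer.dtype)

_FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def safe_cast(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Avoid downcasting activations into a storage-only dtype.
    return tensor if dtype in _FP8_DTYPES else tensor.to(dtype)
```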
…ets because .dtype on ModelMixin should be able to handle fp8 weight case
```python
def test_layerwise_upcasting(storage_dtype, compute_dtype):
    torch.manual_seed(0)
    config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
```
This is required because in the model tests, we create the inputs in torch.float32. When the compute_dtype is different, the inputs (such as hidden_states or encoder_hidden_states) must be cast appropriately. It is done this way so that inputs such as timestep, which are supposed to be torch.int32 or torch.int64, are not cast to compute_dtype (the invoked function only casts tensors whose dtype is already torch.float32).
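Based on that description, the helper behaves roughly like this (a sketch; the actual implementation in the test utilities may differ):

```python
import torch

def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype):
    # Only tensors already in current_dtype (torch.float32 in the tests) are cast,
    # so integer inputs such as timestep keep their dtype.
    if torch.is_tensor(maybe_tensor) and maybe_tensor.dtype == current_dtype:
        return maybe_tensor.to(target_dtype)
    if isinstance(maybe_tensor, dict):
        return {k: cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for k, v in maybe_tensor.items()}
    if isinstance(maybe_tensor, (list, tuple)):
        return type(maybe_tensor)(cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for v in maybe_tensor)
    return maybe_tensor
```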
```diff
@@ -1787,7 +1787,7 @@ def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embeddi
 def forward(self, timestep, caption_feat, caption_mask):
     # timestep embedding:
     time_freq = self.time_proj(timestep)
-    time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
+    time_embed = self.timestep_embedder(time_freq.to(dtype=caption_feat.dtype))
```
@DN6 This change is needed for the Lumina test to not fail. I think it should be a safe change
…rns based on feedback
```python
@unittest.skip(
    "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
    "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
    "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
    "2. Unskip this test."
)
def test_layerwise_casting_inference(self):
    pass
```
For the tests that don't pass, it is due to the following reasons:

- PyTorch has certain operations unimplemented when deterministic algorithms are enabled and float8 types are used.
- One of our Autoencoder implementations (AutoencoderOobleck) uses the deprecated torch.nn.utils.weight_norm. From some light debugging, it seems like nn.Module.to is a no-op for some reason and the weight dtypes don't change.
- The forward pass of AutoencoderTiny ends up casting the module inputs to torch.float32 when using compute_dtype=torch.bfloat16. From some light debugging, it is happening somewhere in the nn.Sequential's, but due to potentially low usage, I've skipped the test for now with instructions on how to fix.
xfail might be a better option here since you mentioned it might be fixed in a future PyTorch release.
Reference:

Line 1528 in a1f9a71:

@pytest.mark.xfail(
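A hedged sketch of what the conversion could look like (reason text and strict setting are illustrative, not the exact change made):

```python
import pytest

@pytest.mark.xfail(
    reason=(
        "'fill_out' not implemented for 'Float8_e4m3fn' when deterministic "
        "algorithms are enabled; see https://github.com/pytorch/pytorch/issues/137160"
    ),
    strict=False,  # reported as XPASS instead of failing once a fixed PyTorch release lands
)
def test_layerwise_casting_inference(self):
    ...
```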
Updated
…ntly unimplemented operation support
Co-authored-by: Dhruv Nair <[email protected]>
…-fp32 comparison (required for a few models' test to pass)
Thanks for the reviews everyone 🤗
Thanks for shipping this banger of a feature, Aryan!
[...continuation of #9177]
PyTorch has had support for float8_e4m3fn and float8_e5m2 as storage dtypes for a while now. This allows one to store model weights in a lower-precision dtype and upcast them on-the-fly when a layer is required for proceeding with computation.
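A minimal illustration of the storage/compute split (tensor shape arbitrary):

```python
import torch

weight = torch.randn(4096, 4096, dtype=torch.bfloat16)

# Store in fp8: half the memory of bf16 (1 byte vs 2 bytes per element).
stored = weight.to(torch.float8_e4m3fn)
print(stored.element_size())  # 1

# Upcast on the fly when the layer actually needs to compute.
compute = stored.to(torch.bfloat16)
print(compute.element_size())  # 2
```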
Flux visual results
CogVideoX visual results
cogvideox-1.0---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
Hunyuan Video visual results
hunyuan_video---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
hunyuan_video---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
hunyuan_video---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
Mochi visual results
mochi---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
mochi---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
mochi---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
Assumptions made so far:

- Computation is performed in compute_dtype.
- Weights at rest are stored in storage_dtype.

Why is there no memory saving in the initial load memory?

We are first moving weights to VRAM and then performing the lower-dtype casting. We should maybe look into directly allowing loading of weights in the lower dtype.
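A rough sketch of why: the full-precision weights hit VRAM before the cast, so peak memory is unchanged (a single large layer stands in for a real model):

```python
import torch
import torch.nn as nn

# Stand-in for a real model: one large linear layer.
model = nn.Linear(4096, 4096, bias=False).to("cuda", dtype=torch.bfloat16)
print(f"peak after load: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")  # ~32 MB in bf16

# The fp8 cast only happens after the bf16 weights already occupied VRAM.
for p in model.parameters():
    p.data = p.data.to(torch.float8_e4m3fn)
torch.cuda.empty_cache()
print(f"steady state: {torch.cuda.memory_allocated() / 1024**2:.0f} MB")  # ~16 MB in fp8
```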
Why a different approach from #9177?

While providing the API to use this via ModelMixin is okay, it puts a restriction that requires all implementations to derive from it. As this method can be generally applied to any modeling component, at any level of granularity, implementing it independently of ModelMixin allows for its use in other modeling components like text encoders (which come from transformers), and any downstream research work or library can use it directly for their demos on Spaces without having to reinvent the wheel. I'm not opposed to the idea of having enable_layerwise_upcasting in ModelMixin, but let's do it in a way that does not impose any restrictions on how it can be used.

Also, the original PR typecast all leaf nodes to the storage dtype, but this may not be ideal for things like normalization and modulation, so supporting parameters like skip_modules_pattern and skip_modules_classes helps ignore a few layers. We can default to sensible values, while maintaining another per-class parameter for layers that should not be upcast/downcast. This is also one of the places where it helps to follow a common naming convention across all our models.

Fixes #9949
cc @vladmandic @asomoza
TODOs:

- Explore non_blocking and CUDA streams for overlapping weight casting with computation (no real impact on time unless weight casting is combined with device casting)
- Test tensor caching in lower precision for methods like [core] Pyramid Attention Broadcast #9562 and [core] FasterCache #10163

Nice reading material for the interested: