Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Layerwise Upcasting #10347

Merged
merged 55 commits into from
Jan 22, 2025
Merged

[core] Layerwise Upcasting #10347

merged 55 commits into from
Jan 22, 2025

Conversation

a-r-r-o-w
Copy link
Member

@a-r-r-o-w a-r-r-o-w commented Dec 23, 2024

[...continuation of #9177]

Pytorch has had support for float8_e4m3fn and float8_e5m2 as storage dtypes for a while now. This allows one to store model weights in a lower precision dtype and upcast them on-the-fly when a layer is required for proceeding with computation.

Code
import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import AllegroPipeline, CogVideoXPipeline, LattePipeline, FluxPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, MochiPipeline, LTXPipeline
from diffusers.utils import export_to_video, load_image
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/home/aryan/work/diffusers")
branch = repo.active_branch


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output


def prepare_flux(dtype: torch.dtype, compile: bool = False, **kwargs) -> None:
    model_id = "black-forest-labs/Flux.1-Dev"
    cache_dir = "/raid/.cache/huggingface"

    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 768,
        "width": 768,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_cogvideox_1_0(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "THUDM/CogVideoX-5b"
    cache_dir = None

    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": (
            "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
            "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
            "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
            "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
            "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
            "atmosphere of this unique musical performance."
        ),
        "height": 480,
        "width": 720,
        "num_frames": 49,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_latte(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "maxin-cn/Latte-1"
    cache_dir = None

    pipe = LattePipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "a cat wearing sunglasses and working as a lifeguard at pool.",
        "height": 512,
        "width": 512,
        "video_length": 16,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_ltx_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "a-r-r-o-w/LTX-Video-diffusers"
    cache_dir = None

    pipe = LTXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)
    
    generation_kwargs = {
        "prompt": "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage",
        "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
        "width": 768,
        "height": 512,
        "num_frames": 161,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_allegro(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "rhymes-ai/Allegro"
    cache_dir = None

    pipe = AllegroPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats.",
        "height": 720,
        "width": 1280,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_hunyuan_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "hunyuanvideo-community/HunyuanVideo"
    cache_dir = None

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(
        model_id, transformer=transformer, torch_dtype=torch.float16, cache_dir=cache_dir
    )
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A cat wearing sunglasses and working as a lifeguard at pool.",
        "height": 320,
        "width": 512,
        "num_frames": 61,
        "num_inference_steps": 30,
    }

    return pipe, generation_kwargs


def prepare_mochi(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "genmo/mochi-1-preview"
    cache_dir = None

    pipe = MochiPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
        "height": 480,
        "width": 848,
        "num_frames": 85,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


def decode_cogvideox_1_0(pipe: CogVideoXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_latte(pipe: LattePipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents, video_length=kwargs["video_length"])
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_allegro(pipe: AllegroPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_hunyuan_video(pipe: HunyuanVideoPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_mochi(pipe: MochiPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_ltx_video(pipe: LTXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latent_num_frames = (kwargs["num_frames"] - 1) // pipe.vae_temporal_compression_ratio + 1
    latent_height = kwargs["height"] // pipe.vae_spatial_compression_ratio
    latent_width = kwargs["width"] // pipe.vae_spatial_compression_ratio

    latents = pipe._unpack_latents(
        latents,
        latent_num_frames,
        latent_height,
        latent_width,
        pipe.transformer_spatial_patch_size,
        pipe.transformer_temporal_patch_size,
    )
    latents = pipe._denormalize_latents(
        latents, pipe.vae.latents_mean, pipe.vae.latents_std, pipe.vae.config.scaling_factor
    )
    latents = latents.to(pipe.vae.dtype)

    timestep = None
    video = pipe.vae.decode(latents, timestep, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=24)
    return filename


def clean_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()


MODEL_MAPPING = {
    "flux": {
        "prepare": prepare_flux,
        "decode": decode_flux,
    },
    "cogvideox-1.0": {
        "prepare": prepare_cogvideox_1_0,
        "decode": decode_cogvideox_1_0,
    },
    "latte": {
        "prepare": prepare_latte,
        "decode": decode_latte,
    },
    "allegro": {
        "prepare": prepare_allegro,
        "decode": decode_allegro,
    },
    "hunyuan_video": {
        "prepare": prepare_hunyuan_video,
        "decode": decode_hunyuan_video,
    },
    "mochi": {
        "prepare": prepare_mochi,
        "decode": decode_mochi,
    },
    "ltx_video": {
        "prepare": prepare_ltx_video,
        "decode": decode_ltx_video,
    },
}

STR_TO_DTYPE = {
    "float8_e4m3fn": torch.float8_e4m3fn,
    "float8_e5m2": torch.float8_e5m2,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator("cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(
    model_id: str, apply_layerwise_upcasting: str, output_dir: str, storage_dtype: str, compute_dtype: str, compile: bool = False
):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    pytorch_storage_dtype = STR_TO_DTYPE[storage_dtype]
    pytorch_compute_dtype = STR_TO_DTYPE[compute_dtype]
    model = MODEL_MAPPING[model_id]

    try:
        clean_memory()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=pytorch_compute_dtype, compile=compile)

        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)

        # 2. Apply layerwise upcasting technique
        if apply_layerwise_upcasting:
            pipe.transformer.enable_layerwise_upcasting(
                storage_dtype=pytorch_storage_dtype,
                compute_dtype=pytorch_compute_dtype,
                skip_modules_pattern=["pos_embed", "patch_embed", "norm"],
            )

        downcast_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)

        # 3. Warmup
        num_warmups = 1
        original_num_inference_steps = generation_kwargs["num_inference_steps"]
        generation_kwargs["num_inference_steps"] = 2
        for _ in range(num_warmups):
            run_inference(pipe, generation_kwargs)
        generation_kwargs["num_inference_steps"] = original_num_inference_steps

        # 4. Benchmark
        clean_memory()
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = (
            output_dir
            / f"{model_id}---storage_dtype-{storage_dtype}---compute_dtype-{compute_dtype}---compile-{compile}"
        )
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            num_frames=generation_kwargs.get("num_frames", None),
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "upcasting": apply_layerwise_upcasting,
            "time": time,
            "initial_memory": model_memory,
            "model_memory": downcast_memory,
            "inference_memory": inference_memory,
            "storage_dtype": storage_dtype,
            "compute_dtype": compute_dtype,
            "compile": compile,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "upcasting": apply_layerwise_upcasting,
            "time": None,
            "initial_memory": None,
            "model_memory": None,
            "inference_memory": None,
            "storage_dtype": storage_dtype,
            "compute_dtype": compute_dtype,
            "compile": compile,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux", "cogvideox-1.0", "latte", "allegro", "hunyuan_video", "mochi", "ltx_video"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--apply_layerwise_upcasting",
        action="store_true",
        help="Whether to apply layerwise upcasting to the transformer.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument(
        "--storage_dtype",
        type=str,
        choices=["float8_e4m3fn", "float8_e5m2", "bfloat16", "float16", "float32"],
        help="Storage torch.dtype to use for transformer",
    )
    parser.add_argument(
        "--compute_dtype",
        type=str,
        choices=["bfloat16", "float16", "float32"],
        help="Compute torch.dtype to use for transformer",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether to torch.compile the denoiser.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.apply_layerwise_upcasting, args.output_dir, args.storage_dtype, args.compute_dtype, args.compile)
model_id granularity storage_dtype time initial_memory model_max_memory inference_max_memory
flux none bfloat16 16.939 31.438 31.447 32.02
flux diffusers_model float8_e4m3fn 23.866 31.438 21.178 33.963
flux diffusers_layer float8_e4m3fn 18.125 31.438 28.779 29.291
flux pytorch_layer float8_e4m3fn 20.339 31.438 24.449 24.945
flux diffusers_model float8_e5m2 22.097 31.438 21.18 33.949
flux diffusers_layer float8_e5m2 18.013 31.438 28.797 29.309
flux pytorch_layer float8_e5m2 20.084 31.44 24.451 24.947
cogvideox-1.0 none bfloat16 244.255 19.661 19.678 24.426
cogvideox-1.0 diffusers_model float8_e4m3fn 243.65 19.661 14.531 25.217
cogvideox-1.0 diffusers_layer float8_e4m3fn 243.541 19.66 16.76 21.469
cogvideox-1.0 pytorch_layer float8_e4m3fn 243.346 19.661 15.281 19.992
cogvideox-1.0 diffusers_model float8_e5m2 243.899 19.661 14.531 25.217
cogvideox-1.0 diffusers_layer float8_e5m2 243.182 19.661 16.76 21.469
cogvideox-1.0 pytorch_layer float8_e5m2 243.136 19.661 15.281 19.992
hunyuan_video none bfloat16 71.748 38.584 38.613 41.141
hunyuan_video diffusers_layer float8_e4m3fn 71.933 38.574 35.904 38.314
hunyuan_video pytorch_layer float8_e4m3fn 72.869 38.573 31.33 33.719
latte none bfloat16 27.986 11.005 11.314 12.471
latte diffusers_layer float8_e4m3fn 27.921 11.005 10.75 11.889
latte pytorch_layer float8_e4m3fn 28.079 11.005 10.879 12.018
mochi none bfloat16 431.799 28.411 28.648 36.059
mochi diffusers_layer float8_e4m3fn 432.142 28.411 24.424 31.934
mochi pytorch_layer float8_e4m3fn 431.947 28.411 21.988 29.441
Flux visual results
Baseline
diffusers_model-float8_e4m3 diffusers_model-float8_e5m2
diffusers_layer-float8_e4m3 diffusers_layer-float8_e5m2
pytorch_layer-float8_e4m3 pytorch_layer-float8_e5m2
CogVideoX visual results
Baseline
cogvideox-1.0---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
diffusers_model-float8_e4m3 diffusers_model-float8_e5m2
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
diffusers_layer-float8_e4m3 diffusers_layer-float8_e5m2
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
pytorch_layer-float8_e4m3 pytorch_layer-float8_e5m2
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
Hunyuan Video visual results
Baseline
hunyuan_video---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
diffusers_layer-float8_e4m3
hunyuan_video---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
pytorch_layer-float8_e4m3
hunyuan_video---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
Mochi visual results
Baseline
mochi---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
diffusers_layer-float8_e4m3
mochi---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
pytorch_layer-float8_e4m3
mochi---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4

Assumptions made so far:

  • The input to the models with a hook are not casted, and are expected to already be in compute_dtype
  • Weight casting learned parameters of normalization layers can lead to poor quality as we've seen in the past few integrations. By default, layers for normalization and modulation are not downcasted to storage_dtype.
  • Sensible default names to avoid embedding, normalization and modulation layers. This is still configurable so users can choose to typecast them if they want.

Why is there no memory savings in the initial load memory?

We are first moving weights to VRAM and then performing the lower dtype casting. We should maybe look into directly allowing loading of weights of lower dtype


Why a different approach from #9177?

While providing the API to use this via ModelMixin is okay, it puts a restriction that requires all implementations to derive from it to use it. As this method can be generally applied to any modeling component, at any level of granularity, implementing it independent of ModelMixin allows for its use in other modeling components like text encoders, which come from transformers, and any downstream research work or library can directly use it for their demos on Spaces without having to reimplement the wheel.

Not opposed to the idea of having enable_layerwise_upcasting in ModelMixin, but let's do it in a way that does not impose any restrictions on how it's possible to use it.

Also, the original PR typecasted all leaf nodes to storage dtype, but this may not be ideal for things like normalization and modulation, so supporting parameters like skip_modules_pattern and skip_modules_classes helps ignore a few layers. We can default to sensible values, while to maintain another parameter per class for layers to not upcast/downcast. This is also one of the places where it helps to follow a common naming convention across all our models.


Fixes #9949

cc @vladmandic @asomoza

TODOs:

  • Explore non_blocking and cuda streams for overlapping weight casting with computation No real impact on time unless weight casting is combined with device casting
  • Try to make torch compile work (edit: works if we increase the cache_size_limit but still recompiles multiple times)
  • Test with LoRAs
  • Test with training
  • Test tensor caching in lower precision for methods like [core] Pyramid Attention Broadcast #9562 and [core] FasterCache #10163
  • Tests
  • Docs

Nice reading material for the interested:

@a-r-r-o-w a-r-r-o-w requested review from DN6, sayakpaul and hlky December 23, 2024 00:14
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6
Copy link
Collaborator

DN6 commented Dec 27, 2024

Nice start 👍🏽 A few things to consider here

  1. Rather than relying on standard names for keys to ignore for upcasting, it would be better to have them be class attributes in the models themselves. e.g something like this
    _always_upcast_modules = ["Decoder"]

It is difficult to maintain a large global list of supported ops and can lead to us either missing modules or not applying upcasting in cases where it can be used.

  1. Upcasting should also account for _keep_in_fp32_modules the way we do with quantization.

  2. There are model components that have casting operations internally such as:

    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

So any kind of layerwise casting on these modules runs into an error because the parameters remain in a lower memory dtype unless the entire module is upcast. The initial PR got around this by adding the _always_upcast_modules attribute that would apply the hook to the top level module instead of the individual layers.

if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:

This implementation seems to do something similar using the global _SUPPORTED_DIFFUSERS_LAYERS list, but this should also be a class attribute IMO.

  1. It's fine to add the hooks via the functions defined here, but enabling and disabling upcasting should be done through the ModelMixin IMO. If users want to apply them to other models we can include a section in the docs about importing the relevant functions from the hooks module.

@a-r-r-o-w a-r-r-o-w added the roadmap Add to current release roadmap label Dec 30, 2024
@a-r-r-o-w
Copy link
Member Author

a-r-r-o-w commented Jan 4, 2025

Layerwise upcasting of text encoders are possible by leveraging apply_layerwise_upcasting. Tested with many different existing models, as well as ComfyUI nodes and it seems to work really well unless:

  • There is dtype casting of weights within the original forward (seems to be uncommon). Since we overwrite the forward method ourselves, any casting in the original forward is going to be very problematic to deal with. This is the case with T5Encoder from transformers, so I'm not quite sure how to deal with it without workarounds like in the example below.
  • There are model_weight-based casting of input tensors. This is used a lot in PEFT and requires workarounds too.
Code
import gc
import torch
from diffusers import CogVideoXPipeline, apply_layerwise_upcasting
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from transformers.models.t5.modeling_t5 import T5DenseGatedActDense

set_verbosity_debug()


def main(apply_layerwise_upcasting_text_encoder: bool = False):
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    pipe.transformer.enable_layerwise_upcasting(
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        granularity="pytorch_layer",
        skip_modules_pattern=["patch_embed", "norm"]
    )

    if apply_layerwise_upcasting_text_encoder:
        for name, module in pipe.text_encoder.named_modules():
            if isinstance(module, T5DenseGatedActDense):
                module.forward = T5DenseGatedActDense_forward.__get__(module)

        pipe.text_encoder = apply_layerwise_upcasting(
            pipe.text_encoder,
            storage_dtype=torch.float8_e4m3fn,
            compute_dtype=torch.bfloat16,
            granularity="pytorch_layer",
            skip_modules_pattern=["norm"]
        )

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42)
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)

    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")


# If we don't overwrite the forward method of T5DenseGatedActDense, the original forward would downcast hidden_states
# to torch.float8_e4m3fn, which would cause an error.
# Line in question: https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/t5/modeling_t5.py#L292
def T5DenseGatedActDense_forward(self, hidden_states):
    hidden_gelu = self.act(self.wi_0(hidden_states))
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.wo(hidden_states)
    return hidden_states


main()

# Without text encoder fp8 layerwise upcasting:
# Model memory: 15.228 GB
# Inference memory: 30.010 GB

# With text encoder fp8 layerwise upcasting:
# Model memory: 10.915 GB
# Inference memory: 25.705 GB

When layerwise upcasting is not enable in T5, the memory required is about 30 GB. When enabled, the memory usage is about 25.7 GB.

Without text encoder fp8 With text encoder fp8
cogvideox-layerwise-without-text-encoder.mp4
cogvideox-layerwise-with-text-encoder.mp4

I'm not sure how to get around this easily. We could probably use a simple context manager in the hook implementations to disable any tensor casts in the internal model forwards but it seems a bit too hacky to me :/ Open to suggestions

class DisableTensorTo:
    def __enter__(self):
        self.original_to = torch.Tensor.to
        
        def noop_to(self, *args, **kwargs):
            return self
    
        torch.Tensor.to = noop_to
    
    def __exit__(self, exc_type, exc_value, traceback):
        torch.Tensor.to = self.original_to

@a-r-r-o-w
Copy link
Member Author

For the reasons mentioned above, we are unable to run LoRA inference either. peft.tuners.lora.layer::Linear casts the model input based on the weight dtype, which is problematic too. Our hooks assumes the input to already be in compute_dtype, but peft will cast it to lower dtype here, which will be lossy cast (if we were to align input dtypes ourselves in the hook forward)

You can overwrite the forward methods for one (to get it to work without too much thinking), but maybe the context manager solution for disabling torch.Tensor::to works here as well :/

There are actually two problems that need to be dealt with for LoRA. Whether you load the lora weights before or after enabling layerwise upcasting, they will be loaded in the correct dtype same as the transformer (torch.float8_e4m3fn for example). This is good because we don't have to add any additional code to handle this. But:

  • if we load lora before enabling layerwise upcasting, then all lora linears get the upcasting hook attached, which is what we want. All is well in this case, except that input tensors are casted to float8 in peft.
  • if we load lora after enabling layerwise upcasting, then the lora linears don't get a hook attached (with the current implementation). This leaves the weights in float8 always. Open to suggestions on how we want to go about this - do I add some logic in load_lora_weights to check if the model already has upcasting hooks, and if so attach to lora layers as well?
Code
import gc
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from peft.tuners.lora.layer import Linear as LoRALinear

set_verbosity_debug()


def main():
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    pipe.set_adapters(["cogvideox-lora"], [1.0])
    
    pipe.transformer.enable_layerwise_upcasting(
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        granularity="pytorch_layer",
        skip_modules_pattern=["patch_embed", "norm"]
    )

    # Post layerwise upcasting does load lora weights in torch.float8_e4m3fn but due to no hooks, it errors during inference
    # pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    # pipe.set_adapters(["cogvideox-lora"], [1.0])
    
    for name, parameter in pipe.transformer.named_parameters():
        if "lora" in name:
            assert(parameter.dtype == torch.float8_e4m3fn)
    
    LoRALinear.forward = LoRALinear_forward

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = "walgro1. The scene begins with a close-up of Gromit's face, his expressive eyes filling the frame. His brow furrows slightly, ears perked forward in concentration. The soft lighting highlights the subtle details of his fur, every strand catching the warm sunlight filtering in from a nearby window. His dark, round nose twitches ever so slightly, sensing something in the air, and his gaze darts to the side, following an unseen movement. The camera lingers on Gromit’s face, capturing the subtleties of his expression—a quirked eyebrow and a knowing look that suggests he’s piecing together something clever. His silent, thoughtful demeanor speaks volumes as he watches the scene unfold with quiet intensity. The background remains out of focus, drawing all attention to the sharp intelligence in his eyes and the slight tilt of his head. In the claymation style of Wallace and Gromit."

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42)
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)

    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")


from typing import Any

def LoRALinear_forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif adapter_names is not None:
        result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
    elif self.merged:
        result = self.base_layer(x, *args, **kwargs)
    else:
        result = self.base_layer(x, *args, **kwargs)
        torch_result_dtype = result.dtype
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            # x = x.to(lora_A.weight.dtype)

            if not self.use_dora[active_adapter]:
                result = result + lora_B(lora_A(dropout(x))) * scaling
            else:
                if isinstance(dropout, torch.nn.Identity) or not self.training:
                    base_result = result
                else:
                    x = dropout(x)
                    base_result = None

                result = result + self.lora_magnitude_vector[active_adapter](
                    x,
                    lora_A=lora_A,
                    lora_B=lora_B,
                    scaling=scaling,
                    base_layer=self.get_base_layer(),
                    base_result=base_result,
                )

        result = result.to(torch_result_dtype)

    return result


main()

# Model memory: 15.351 GB
# Inference memory: 30.134 GB
cogvideox-layerwise-lora.mp4

@a-r-r-o-w
Copy link
Member Author

a-r-r-o-w commented Jan 6, 2025

Just documenting for now because I don't know what approach we're going to take to deal with these kinds of problems yet.

If we enable layerwise upcasting for something like HunyuanVideo, the prompt_embeds and prompt_attention_mask are moved to fp8 dtype, and the latent input passed into the model is also fp8. This is because of lines like this:

One solution would be to overwrite the dtype method for ModelMixin to detect if a LayerwiseUpcastingHook is attached to the submodules, and if so, just read out the compute_dtype and return it. Open to suggestions and will be on the lookout for more such things. Surprisingly, this does not error out on A100s (Ampere), but does on H100 (Hopper). This was discovered during fp8 lora training run

Edit: Ah, so the reason why it worked in the benchmark script is because the x_embedder layer is skipped from layerwise upcasting. Our dtype() method therefore simply reads bfloat16 as the first parameter dtype and works with it. If we were to enable upcasting in x_embedder as well, it errors on A100 too

def test_layerwise_upcasting(storage_dtype, compute_dtype):
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
Copy link
Member Author

@a-r-r-o-w a-r-r-o-w Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is required because in the model tests, we create the inputs in torch.float32. When the compute_dtype is different, the inputs (such as hidden_states or encoder_hidden_states) must be casted appropriately. It is done this way so that inputs such as timestep, which is supposed to be torch.int32 or torch.int64, is not casted to compute_dtype (as the invoked function checks if the castee tensor's dtype is torch.float32 already)

@@ -1787,7 +1787,7 @@ def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embeddi
def forward(self, timestep, caption_feat, caption_mask):
# timestep embedding:
time_freq = self.time_proj(timestep)
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
time_embed = self.timestep_embedder(time_freq.to(dtype=caption_feat.dtype))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DN6 This change is needed for the Lumina test to not fail. I think it should be a safe change

Comment on lines 155 to 162
@unittest.skip(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
)
def test_layerwise_casting_inference(self):
pass
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the tests that don't pass, it is due to the following reasons:

  • PyTorch has certain operations unimplemented when deterministic algorithms are enabled and float8 types are used.
  • One of our Autoencoder implementations (AutoencoderOobleck) uses the deprecated torch.nn.utils.weight_norm. From some light debugging, it seems like nn::Module::to is a null-op for some reason and the weight types don't change
  • The forward pass of AutoencoderTiny ends up casting the module inputs to torch.float32 when using compute_dtype=torch.bfloat16. It is happening somewhere in the nn.Sequential's from some light debugging, but due to potentially low usage, I've skipped the test for now with instructions on how to fix

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xfail might be a better option here sinch you mentioned it might be fixed in a future PyTorch release.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reference:

@pytest.mark.xfail(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

@a-r-r-o-w
Copy link
Member Author

Thanks for the reviews everyone 🤗

@a-r-r-o-w a-r-r-o-w merged commit beacaa5 into main Jan 22, 2025
14 of 15 checks passed
@a-r-r-o-w a-r-r-o-w deleted the layerwise-upcasting-hook branch January 22, 2025 14:19
@sayakpaul
Copy link
Member

Thanks for shipping this banger of a feature, Aryan!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
roadmap Add to current release roadmap
Projects
Development

Successfully merging this pull request may close these issues.

[Experimental] expose dynamic upcasting of layers as experimental APIs
6 participants