From bf797e7746d6bcbc59948d0a8c3151a8a8e9e3cd Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 17 Jan 2025 08:45:59 +0100
Subject: [PATCH] refactor

---
 src/diffusers/hooks/layerwise_upcasting.py | 28 ++++++++++++++++++----
 src/diffusers/models/modeling_utils.py     | 13 ++++------
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/hooks/layerwise_upcasting.py b/src/diffusers/hooks/layerwise_upcasting.py
index 7ee920c1e002..3d85b56db72b 100644
--- a/src/diffusers/hooks/layerwise_upcasting.py
+++ b/src/diffusers/hooks/layerwise_upcasting.py
@@ -77,7 +77,6 @@ def apply_layerwise_upcasting(
     skip_modules_pattern: Union[str, Tuple[str, ...]] = "default",
     skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
     non_blocking: bool = False,
-    _prefix: str = "",
 ) -> None:
     r"""
     Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
@@ -97,7 +96,7 @@ def apply_layerwise_upcasting(
     ...     transformer,
     ...     storage_dtype=torch.float8_e4m3fn,
     ...     compute_dtype=torch.bfloat16,
-    ...     skip_modules_pattern=["patch_embed", "norm"],
+    ...     skip_modules_pattern=["patch_embed", "norm", "proj_out"],
     ...     non_blocking=True,
     ... )
     ```
@@ -112,7 +111,9 @@ def apply_layerwise_upcasting(
             The dtype to cast the module to during the forward pass for computation.
         skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`):
             A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set
-            to `"default"`, the default patterns are used.
+            to `"default"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
+            alongside `skip_modules_classes` being `None`, the layerwise upcasting is applied directly to the module
+            instead of its internal submodules.
         skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
             A list of module classes to skip during the layerwise upcasting process.
         non_blocking (`bool`, defaults to `False`):
@@ -125,6 +126,25 @@ def apply_layerwise_upcasting(
         apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
 
+    _apply_layerwise_upcasting(
+        module,
+        storage_dtype,
+        compute_dtype,
+        skip_modules_pattern,
+        skip_modules_classes,
+        non_blocking,
+    )
+
+
+def _apply_layerwise_upcasting(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: Optional[Tuple[str, ...]] = None,
+    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
+    non_blocking: bool = False,
+    _prefix: str = "",
+) -> None:
     should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
         skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
     )
@@ -139,7 +159,7 @@ def apply_layerwise_upcasting(
 
     for name, submodule in module.named_children():
         layer_name = f"{_prefix}.{name}" if _prefix else name
-        apply_layerwise_upcasting(
+        _apply_layerwise_upcasting(
             submodule,
             storage_dtype,
             compute_dtype,
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 995560ecd0ec..402ad7788e5f 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -347,13 +347,13 @@ def enable_layerwise_upcasting(
         By default, most models in diffusers set the `_always_upcast_modules` attribute to ignore patch embedding,
         positional embedding and normalization layers. This is because these layers are most likely precision-critical
         for quality. If you wish to change this behavior, you can set the `_always_upcast_modules` attribute to `None`,
-        or call [`~apply_layerwise_upcasting`] with custom arguments.
+        or call [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] with custom arguments.
 
         Example:
             Using [`~models.ModelMixin.enable_layerwise_upcasting`]:
 
             ```python
-            >>> from diffusers import CogVideoXTransformer3DModel, apply_layerwise_upcasting
+            >>> from diffusers import CogVideoXTransformer3DModel
 
             >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
             ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
@@ -361,11 +361,6 @@ def enable_layerwise_upcasting(
 
             >>> # Enable layerwise upcasting via the model, which ignores certain modules by default
             >>> transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
-
-            >>> # Or, enable layerwise upcasting with custom arguments via the `apply_layerwise_upcasting` function
-            >>> apply_layerwise_upcasting(
-            ...     transformer, torch.float8_e4m3fn, torch.bfloat16, skip_modules_pattern=["patch_embed", "norm"]
-            ... )
             ```
 
         Args:
@@ -374,7 +369,9 @@ def enable_layerwise_upcasting(
             compute_dtype (`torch.dtype`):
                 The dtype to which the model weights should be cast during the forward pass.
             skip_modules_pattern (`Tuple[str, ...]`, *optional*):
-                A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
+                A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If
+                set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
+                layers.
             skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
                 A list of module classes to skip during the layerwise upcasting process.
            non_blocking (`bool`, *optional*, defaults to `False`):
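
For reference, a minimal usage sketch of the two entry points touched by this patch. It assumes the import path `diffusers.hooks.layerwise_upcasting` shown in the docstring cross-reference; the checkpoint, dtypes, and skip patterns mirror the docstring examples. The two calls are alternatives (the functional variant is commented out), not meant to be stacked on the same model.

```python
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks.layerwise_upcasting import apply_layerwise_upcasting

# Load the transformer in the compute dtype, as in the docstring example.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Alternative 1: the ModelMixin convenience method, which keeps the default
# skip patterns (patch/positional embeddings, normalization and PEFT layers).
transformer.enable_layerwise_upcasting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)

# Alternative 2: the functional entry point with explicit skip patterns,
# matching the updated example in layerwise_upcasting.py.
# apply_layerwise_upcasting(
#     transformer,
#     storage_dtype=torch.float8_e4m3fn,
#     compute_dtype=torch.bfloat16,
#     skip_modules_pattern=["patch_embed", "norm", "proj_out"],
#     non_blocking=True,
# )
```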