refactor
a-r-r-o-w committed Jan 17, 2025
1 parent 3ae32b4 commit bf797e7
Showing 2 changed files with 29 additions and 12 deletions.
28 changes: 24 additions & 4 deletions src/diffusers/hooks/layerwise_upcasting.py
@@ -77,7 +77,6 @@ def apply_layerwise_upcasting(
skip_modules_pattern: Union[str, Tuple[str, ...]] = "default",
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
- _prefix: str = "",
) -> None:
r"""
Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
@@ -97,7 +96,7 @@ def apply_layerwise_upcasting(
... transformer,
... storage_dtype=torch.float8_e4m3fn,
... compute_dtype=torch.bfloat16,
- ... skip_modules_pattern=["patch_embed", "norm"],
+ ... skip_modules_pattern=["patch_embed", "norm", "proj_out"],
... non_blocking=True,
... )
```
@@ -112,7 +111,9 @@ def apply_layerwise_upcasting(
The dtype to cast the module to during the forward pass for computation.
skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`):
A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set
to `"default"`, the default patterns are used.
to `"default"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
alongside `skip_modules_classes` being `None`, the layerwise upcasting is applied directly to the module
instead of its internal submodules.
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
A list of module classes to skip during the layerwise upcasting process.
non_blocking (`bool`, defaults to `False`):
@@ -125,6 +126,25 @@ def apply_layerwise_upcasting(
apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
return

+ _apply_layerwise_upcasting(
+ module,
+ storage_dtype,
+ compute_dtype,
+ skip_modules_pattern,
+ skip_modules_classes,
+ non_blocking,
+ )
+
+
+ def _apply_layerwise_upcasting(
+ module: torch.nn.Module,
+ storage_dtype: torch.dtype,
+ compute_dtype: torch.dtype,
+ skip_modules_pattern: Optional[Tuple[str, ...]] = None,
+ skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
+ non_blocking: bool = False,
+ _prefix: str = "",
+ ) -> None:
should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
)
@@ -139,7 +159,7 @@ def apply_layerwise_upcasting(

for name, submodule in module.named_children():
layer_name = f"{_prefix}.{name}" if _prefix else name
- apply_layerwise_upcasting(
+ _apply_layerwise_upcasting(
submodule,
storage_dtype,
compute_dtype,
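To make the refactor's control flow concrete, here is a minimal, self-contained sketch of the same public/private split: a public entry point handles the case where both skip filters are `None` by hooking the module directly (the early-return branch shown above), while a private recursive helper carries the dotted module path in `_prefix` so skip patterns can be matched with `re.search`. The names `apply_upcasting` and `_attach_upcasting_hook` and the leaf-layer check are illustrative stand-ins, not the diffusers implementation.

```python
import re
from typing import Optional, Tuple, Type

import torch


def _attach_upcasting_hook(module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
    # Placeholder for diffusers' apply_layerwise_upcasting_hook: here we only cast the
    # weights down to the storage dtype; the real hook also upcasts to compute_dtype
    # around each forward pass.
    module.to(storage_dtype)


def apply_upcasting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
    compute_dtype: torch.dtype,
    skip_modules_pattern: Optional[Tuple[str, ...]] = None,
    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
) -> None:
    # Public entry point: if no skip filters are given, hook the module itself and stop,
    # mirroring the early-return branch added in this commit.
    if skip_modules_pattern is None and skip_modules_classes is None:
        _attach_upcasting_hook(module, storage_dtype, compute_dtype)
        return
    _apply_upcasting(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)


def _apply_upcasting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
    compute_dtype: torch.dtype,
    skip_modules_pattern: Optional[Tuple[str, ...]],
    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]],
    _prefix: str = "",
) -> None:
    # Private recursive helper: `_prefix` is the dotted module path ("blocks.0.norm", ...)
    # that the skip patterns are matched against.
    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
        skip_modules_pattern is not None and any(re.search(p, _prefix) for p in skip_modules_pattern)
    )
    if should_skip:
        return
    # Assumption for this sketch: only leaf compute layers receive the hook.
    if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
        _attach_upcasting_hook(module, storage_dtype, compute_dtype)
        return
    for name, submodule in module.named_children():
        layer_name = f"{_prefix}.{name}" if _prefix else name
        _apply_upcasting(submodule, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, layer_name)
```

Keeping `_prefix` out of the public signature is the point of the split: callers never see the recursion bookkeeping.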
13 changes: 5 additions & 8 deletions src/diffusers/models/modeling_utils.py
@@ -347,25 +347,20 @@ def enable_layerwise_upcasting(
By default, most models in diffusers set the `_always_upcast_modules` attribute to ignore patch embedding,
positional embedding and normalization layers. This is because these layers are most likely precision-critical
for quality. If you wish to change this behavior, you can set the `_always_upcast_modules` attribute to `None`,
- or call [`~apply_layerwise_upcasting`] with custom arguments.
+ or call [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] with custom arguments.
Example:
Using [`~models.ModelMixin.enable_layerwise_upcasting`]:
```python
- >>> from diffusers import CogVideoXTransformer3DModel, apply_layerwise_upcasting
+ >>> from diffusers import CogVideoXTransformer3DModel
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> # Enable layerwise upcasting via the model, which ignores certain modules by default
>>> transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
- >>> # Or, enable layerwise upcasting with custom arguments via the `apply_layerwise_upcasting` function
- >>> apply_layerwise_upcasting(
- ... transformer, torch.float8_e4m3fn, torch.bfloat16, skip_modules_pattern=["patch_embed", "norm"]
- ... )
```
Args:
@@ -374,7 +369,9 @@ def enable_layerwise_upcasting(
compute_dtype (`torch.dtype`):
The dtype to which the model weights should be cast during the forward pass.
skip_modules_pattern (`Tuple[str, ...]`, *optional*):
- A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
+ A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If
+ set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
+ layers.
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
A list of module classes to skip during the layerwise upcasting process.
non_blocking (`bool`, *optional*, defaults to `False`):
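As a hedged illustration of the two entry points this docstring describes, the sketch below pairs the model-level call with the function-level call using the skip patterns from the hook's own example. The checkpoint and dtypes mirror the docstring, while the deep import path for `apply_layerwise_upcasting` is an assumption based on the file location shown in this commit.

```python
import torch
from diffusers import CogVideoXTransformer3DModel
# Import path assumed from src/diffusers/hooks/layerwise_upcasting.py; it may also be
# re-exported elsewhere in the package.
from diffusers.hooks.layerwise_upcasting import apply_layerwise_upcasting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Option A: model-level API; relies on the model's default skip patterns
# (`_always_upcast_modules`) to protect precision-critical layers.
transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

# Option B: function-level API with explicit skip patterns, as in the hook's docstring example.
# In practice you would pick one of the two options, not apply both to the same model.
apply_layerwise_upcasting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["patch_embed", "norm", "proj_out"],
    non_blocking=True,
)
```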
