Commit bf797e7

refactor
1 parent 3ae32b4 commit bf797e7

File tree: 2 files changed (+29 -12 lines changed)


src/diffusers/hooks/layerwise_upcasting.py

Lines changed: 24 additions & 4 deletions
@@ -77,7 +77,6 @@ def apply_layerwise_upcasting(
     skip_modules_pattern: Union[str, Tuple[str, ...]] = "default",
     skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
     non_blocking: bool = False,
-    _prefix: str = "",
 ) -> None:
     r"""
     Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
@@ -97,7 +96,7 @@ def apply_layerwise_upcasting(
     ...     transformer,
     ...     storage_dtype=torch.float8_e4m3fn,
     ...     compute_dtype=torch.bfloat16,
-    ...     skip_modules_pattern=["patch_embed", "norm"],
+    ...     skip_modules_pattern=["patch_embed", "norm", "proj_out"],
     ...     non_blocking=True,
     ... )
     ```
@@ -112,7 +111,9 @@ def apply_layerwise_upcasting(
             The dtype to cast the module to during the forward pass for computation.
         skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`):
             A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set
-            to `"default"`, the default patterns are used.
+            to `"default"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
+            alongside `skip_modules_classes` being `None`, the layerwise upcasting is applied directly to the module
+            instead of its internal submodules.
         skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
             A list of module classes to skip during the layerwise upcasting process.
         non_blocking (`bool`, defaults to `False`):
@@ -125,6 +126,25 @@ def apply_layerwise_upcasting(
         apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
 
+    _apply_layerwise_upcasting(
+        module,
+        storage_dtype,
+        compute_dtype,
+        skip_modules_pattern,
+        skip_modules_classes,
+        non_blocking,
+    )
+
+
+def _apply_layerwise_upcasting(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: Optional[Tuple[str, ...]] = None,
+    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
+    non_blocking: bool = False,
+    _prefix: str = "",
+) -> None:
     should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
         skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
     )
@@ -139,7 +159,7 @@ def apply_layerwise_upcasting(
 
     for name, submodule in module.named_children():
         layer_name = f"{_prefix}.{name}" if _prefix else name
-        apply_layerwise_upcasting(
+        _apply_layerwise_upcasting(
            submodule,
            storage_dtype,
            compute_dtype,
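
For reference, a minimal usage sketch of the public entry point after this refactor, assembled from the docstring example above. It assumes the import path implied by this file (`diffusers.hooks.layerwise_upcasting`) and the argument names shown in the diff; the recursive `_prefix` bookkeeping now lives in the private `_apply_layerwise_upcasting` helper, so callers never pass it.

```python
import torch

from diffusers import CogVideoXTransformer3DModel
# Assumed import path, based on the file location shown in this diff.
from diffusers.hooks.layerwise_upcasting import apply_layerwise_upcasting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Store weights in float8 and upcast to bfloat16 only for computation,
# skipping precision-critical layers matched by name pattern.
apply_layerwise_upcasting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["patch_embed", "norm", "proj_out"],
    non_blocking=True,
)
```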

src/diffusers/models/modeling_utils.py

Lines changed: 5 additions & 8 deletions
@@ -347,25 +347,20 @@ def enable_layerwise_upcasting(
         By default, most models in diffusers set the `_always_upcast_modules` attribute to ignore patch embedding,
         positional embedding and normalization layers. This is because these layers are most likely precision-critical
         for quality. If you wish to change this behavior, you can set the `_always_upcast_modules` attribute to `None`,
-        or call [`~apply_layerwise_upcasting`] with custom arguments.
+        or call [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] with custom arguments.
 
         Example:
             Using [`~models.ModelMixin.enable_layerwise_upcasting`]:
 
         ```python
-        >>> from diffusers import CogVideoXTransformer3DModel, apply_layerwise_upcasting
+        >>> from diffusers import CogVideoXTransformer3DModel
 
         >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
         ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
         ... )
 
         >>> # Enable layerwise upcasting via the model, which ignores certain modules by default
         >>> transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
-
-        >>> # Or, enable layerwise upcasting with custom arguments via the `apply_layerwise_upcasting` function
-        >>> apply_layerwise_upcasting(
-        ...     transformer, torch.float8_e4m3fn, torch.bfloat16, skip_modules_pattern=["patch_embed", "norm"]
-        ... )
         ```
 
         Args:
@@ -374,7 +369,9 @@ def enable_layerwise_upcasting(
             compute_dtype (`torch.dtype`):
                 The dtype to which the model weights should be cast during the forward pass.
             skip_modules_pattern (`Tuple[str, ...]`, *optional*):
-                A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
+                A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If
+                set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
+                layers.
             skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
                 A list of module classes to skip during the layerwise upcasting process.
             non_blocking (`bool`, *optional*, defaults to `False`):
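
With the direct `apply_layerwise_upcasting` call removed from this docstring example, the model-level method remains the documented entry point here. A short sketch, assuming only the method and defaults described in the docstring above:

```python
import torch

from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Precision-critical modules (patch embedding, positional embedding, normalization
# layers) are skipped by default via the model's `_always_upcast_modules` attribute.
transformer.enable_layerwise_upcasting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```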
