
Commit 3ae32b4

update
1 parent 19901e7 commit 3ae32b4

File tree

3 files changed (+47, -23 lines)

docs/source/en/optimization/memory.md

Lines changed: 9 additions & 3 deletions
@@ -166,12 +166,18 @@ Typically, inference on most models is done with `torch.float16` or `torch.bfloat16`
 
 ```python
 import torch
-from diffusers import CogVideoXPipeline
+from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
 from diffusers.utils import export_to_video
 
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+model_id = "THUDM/CogVideoX-5b"
+
+# Load the model in bfloat16 and enable layerwise upcasting
+transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+
+# Load the pipeline
+pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
 pipe.to("cuda")
-pipe.transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
 
 prompt = (
     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "

src/diffusers/hooks/layerwise_upcasting.py

Lines changed: 18 additions & 13 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import re
-from typing import Optional, Tuple, Type
+from typing import Optional, Tuple, Type, Union
 
 import torch
 
@@ -25,13 +25,13 @@
 
 
 # fmt: off
-_SUPPORTED_PYTORCH_LAYERS = (
+SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
     torch.nn.Linear,
 )
 
-_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")
+DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
 # fmt: on
 
 
@@ -74,8 +74,8 @@ def apply_layerwise_upcasting(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN,
-    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None,
+    skip_modules_pattern: Union[str, Tuple[str, ...]] = "default",
+    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
     non_blocking: bool = False,
     _prefix: str = "",
 ) -> None:
@@ -87,13 +87,14 @@ def apply_layerwise_upcasting(
 
     ```python
     >>> import torch
-    >>> from diffusers import CogVideoXPipeline, apply_layerwise_upcasting
+    >>> from diffusers import CogVideoXTransformer3DModel
 
-    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
-    >>> pipe.to("cuda")
+    >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+    ...     model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+    ... )
 
     >>> apply_layerwise_upcasting(
-    ...     pipe.transformer,
+    ...     transformer,
     ...     storage_dtype=torch.float8_e4m3fn,
     ...     compute_dtype=torch.bfloat16,
     ...     skip_modules_pattern=["patch_embed", "norm"],
@@ -109,13 +110,17 @@ def apply_layerwise_upcasting(
             The dtype to cast the module to before/after the forward pass for storage.
         compute_dtype (`torch.dtype`):
            The dtype to cast the module to during the forward pass for computation.
-        skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
-            A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
-        skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`):
+        skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`):
+            A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set
+            to `"default"`, the default patterns are used.
+        skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
             A list of module classes to skip during the layerwise upcasting process.
         non_blocking (`bool`, defaults to `False`):
             If `True`, the weight casting operations are non-blocking.
     """
+    if skip_modules_pattern == "default":
+        skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
+
     if skip_modules_classes is None and skip_modules_pattern is None:
         apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
@@ -127,7 +132,7 @@ def apply_layerwise_upcasting(
         logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
         return
 
-    if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
+    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
         logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
         apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
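
As a side note, a small self-contained sketch of how a skip-pattern tuple like the updated `DEFAULT_SKIP_MODULES_PATTERN` is typically matched against module names with `re.search`. The `should_skip` helper is hypothetical, written only to illustrate why the new anchored entries `"^proj_in$"` / `"^proj_out$"` match only a top-level module of that exact name:

```python
import re
from typing import Tuple

DEFAULT_SKIP_MODULES_PATTERN: Tuple[str, ...] = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")


def should_skip(module_name: str, patterns: Tuple[str, ...] = DEFAULT_SKIP_MODULES_PATTERN) -> bool:
    # A module is skipped when any pattern matches somewhere in its qualified name.
    return any(re.search(pattern, module_name) is not None for pattern in patterns)


# Substring patterns match nested modules by name fragment...
assert should_skip("transformer_blocks.0.norm1")
# ...while the anchored patterns only match a module named exactly "proj_in" / "proj_out" at the top level.
assert should_skip("proj_in")
assert not should_skip("transformer_blocks.0.attn.proj_in")
```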

src/diffusers/models/modeling_utils.py

Lines changed: 20 additions & 7 deletions
@@ -104,6 +104,17 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     """
     Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
     """
+    # 1. Check if we have attached any dtype modifying hooks (eg. layerwise upcasting)
+    if isinstance(parameter, nn.Module):
+        for name, submodule in parameter.named_modules():
+            if not hasattr(submodule, "_diffusers_hook"):
+                continue
+            registry = submodule._diffusers_hook
+            hook = registry.get_hook("layerwise_upcasting")
+            if hook is not None:
+                return hook.compute_dtype
+
+    # 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
     last_dtype = None
     for param in parameter.parameters():
         last_dtype = param.dtype
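
To make the intent of this check concrete, a hedged usage sketch (assumed behavior based on this diff, not output from a test run): once layerwise upcasting is enabled, most weights are stored in `float8_e4m3fn`, but the model-level dtype should keep reporting the compute dtype because the attached hook is consulted first:

```python
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

# Parameter storage is a mix of float8 (upcasted layers) and bfloat16 (skipped layers)...
print({p.dtype for p in transformer.parameters()})
# ...but the model-level dtype resolves to the hook's compute dtype rather than a storage dtype.
print(transformer.dtype)  # expected: torch.bfloat16
```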
@@ -321,8 +332,8 @@ def enable_layerwise_upcasting(
         self,
         storage_dtype: torch.dtype = torch.float8_e4m3fn,
         compute_dtype: Optional[torch.dtype] = None,
-        skip_modules_pattern: Optional[List[str]] = None,
-        skip_modules_classes: Optional[List[Type[torch.nn.Module]]] = None,
+        skip_modules_pattern: Optional[Tuple[str, ...]] = None,
+        skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
         non_blocking: bool = False,
     ) -> None:
         r"""
@@ -362,22 +373,24 @@ def enable_layerwise_upcasting(
                 The dtype to which the model should be cast for storage.
             compute_dtype (`torch.dtype`):
                 The dtype to which the model weights should be cast during the forward pass.
-            skip_modules_pattern (`List[str]`, *optional*):
+            skip_modules_pattern (`Tuple[str, ...]`, *optional*):
                 A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
-            skip_modules_classes (`List[Type[torch.nn.Module]]`, *optional*):
+            skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
                 A list of module classes to skip during the layerwise upcasting process.
             non_blocking (`bool`, *optional*, defaults to `False`):
                 If `True`, the weight casting operations are non-blocking.
         """
 
         user_provided_patterns = True
         if skip_modules_pattern is None:
-            skip_modules_pattern = []
+            from ..hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN
+
+            skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
             user_provided_patterns = False
         if self._keep_in_fp32_modules is not None:
-            skip_modules_pattern.extend(self._keep_in_fp32_modules)
+            skip_modules_pattern += tuple(self._keep_in_fp32_modules)
         if self._always_upcast_modules is not None:
-            skip_modules_pattern.extend(self._always_upcast_modules)
+            skip_modules_pattern += tuple(self._always_upcast_modules)
         skip_modules_pattern = tuple(set(skip_modules_pattern))
 
         if is_peft_available() and not user_provided_patterns:
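
A minimal sketch of the pattern-merging behavior this hunk moves to (tuple concatenation followed by `set`-based deduplication). The `_keep_in_fp32_modules` / `_always_upcast_modules` values below are hypothetical, chosen only to show the duplicates being removed:

```python
# Default patterns are now a tuple rather than a mutable list.
skip_modules_pattern = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")

# Hypothetical class-level overrides, merged the same way as the `+= tuple(...)` lines above.
_keep_in_fp32_modules = ["norm"]
_always_upcast_modules = ["pos_embed"]
skip_modules_pattern += tuple(_keep_in_fp32_modules)
skip_modules_pattern += tuple(_always_upcast_modules)

# Deduplicate; `set` discards the duplicates introduced by the merge (order is not preserved).
skip_modules_pattern = tuple(set(skip_modules_pattern))
print(sorted(skip_modules_pattern))
# ['^proj_in$', '^proj_out$', 'norm', 'patch_embed', 'pos_embed']
```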
