@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import re
-from typing import Optional, Tuple, Type
+from typing import Optional, Tuple, Type, Union
 
 import torch
 
@@ -25,13 +25,13 @@
 
 
 # fmt: off
-_SUPPORTED_PYTORCH_LAYERS = (
+SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
     torch.nn.Linear,
 )
 
-_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")
+DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
 # fmt: on
 
 
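The two new skip entries are anchored regexes, unlike the bare substrings that were already there. Assuming the patterns are matched with `re.search` against each module's qualified name (the file imports `re` and logs the name via `_prefix`), `"norm"` skips any module whose name merely contains it, while `"^proj_in$"` and `"^proj_out$"` only skip a module whose full name is exactly that. A minimal check of that assumption:

```python
import re

# Assumption: skip patterns are applied with re.search against the qualified
# module name, as the `re` import and the `_prefix` logging suggest.
patterns = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")

def is_skipped(name):
    return any(re.search(pattern, name) for pattern in patterns)

print(is_skipped("transformer_blocks.0.norm1"))    # True:  "norm" matches as a substring
print(is_skipped("proj_out"))                      # True:  exact, anchored match
print(is_skipped("transformer_blocks.0.proj_out")) # False: the anchors exclude nested names
```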
@@ -74,8 +74,8 @@ def apply_layerwise_upcasting(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN,
-    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None,
+    skip_modules_pattern: Union[str, Tuple[str, ...]] = "default",
+    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
     non_blocking: bool = False,
     _prefix: str = "",
 ) -> None:
@@ -87,13 +87,14 @@ def apply_layerwise_upcasting(
 
     ```python
     >>> import torch
-    >>> from diffusers import CogVideoXPipeline, apply_layerwise_upcasting
+    >>> from diffusers import CogVideoXTransformer3DModel
 
-    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
-    >>> pipe.to("cuda")
+    >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+    ...     model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+    ... )
 
     >>> apply_layerwise_upcasting(
-    ...     pipe.transformer,
+    ...     transformer,
     ...     storage_dtype=torch.float8_e4m3fn,
     ...     compute_dtype=torch.bfloat16,
     ...     skip_modules_pattern=["patch_embed", "norm"],
@@ -109,13 +110,17 @@ def apply_layerwise_upcasting(
             The dtype to cast the module to before/after the forward pass for storage.
         compute_dtype (`torch.dtype`):
             The dtype to cast the module to during the forward pass for computation.
-        skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
-            A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
-        skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`):
+        skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`):
+            A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set
+            to `"default"`, the default patterns are used.
+        skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
             A list of module classes to skip during the layerwise upcasting process.
         non_blocking (`bool`, defaults to `False`):
             If `True`, the weight casting operations are non-blocking.
     """
+    if skip_modules_pattern == "default":
+        skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
+
     if skip_modules_classes is None and skip_modules_pattern is None:
         apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
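With the `"default"` sentinel in place, the top of the function distinguishes three call styles: omit the argument to get the default skip list, pass explicit patterns, or pass `None` for both skip arguments to upcast everything (the early-return branch above attaches the hook to the root module directly). A small sketch of just that resolution step, with `_resolve_skip_pattern` as a hypothetical stand-in for the inlined logic:

```python
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")

def _resolve_skip_pattern(skip_modules_pattern="default", skip_modules_classes=None):
    # Hypothetical helper mirroring the two lines added in this hunk.
    if skip_modules_pattern == "default":
        skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
    # When both skip arguments are None, the real function short-circuits and
    # hooks the whole module tree without skipping anything.
    skip_nothing = skip_modules_classes is None and skip_modules_pattern is None
    return skip_modules_pattern, skip_nothing

print(_resolve_skip_pattern())                        # (default patterns, False)
print(_resolve_skip_pattern(("patch_embed", "norm"))) # (explicit patterns, False)
print(_resolve_skip_pattern(None, None))              # (None, True) -> nothing is skipped
```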
@@ -127,7 +132,7 @@ def apply_layerwise_upcasting(
         logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
         return
 
-    if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
+    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
         logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
         apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
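Taken together, the renamed constants and the unchanged control flow suggest the overall shape of the function: the parts not shown in this diff presumably recurse over `module.named_children()`, threading each child's qualified name through `_prefix`, skipping anything that matches the patterns, and attaching the casting hook to every supported leaf layer. A toy re-implementation under those assumptions (not the library code), useful for predicting which layers end up stored in `storage_dtype`:

```python
import re
import torch

SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")

def collect_upcast_targets(module, patterns=DEFAULT_SKIP_MODULES_PATTERN, _prefix=""):
    # Toy walk (assumed traversal): yield the qualified names of the leaf
    # layers that would receive the layerwise-upcasting hook.
    if _prefix and any(re.search(p, _prefix) for p in patterns):
        return  # mirrors the "Skipping layerwise upcasting" branch
    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
        yield _prefix  # mirrors the "Applying layerwise upcasting" branch
        return
    for name, child in module.named_children():
        child_prefix = f"{_prefix}.{name}" if _prefix else name
        yield from collect_upcast_targets(child, patterns, child_prefix)

toy = torch.nn.ModuleDict({
    "proj_in": torch.nn.Linear(8, 8),
    "norm1": torch.nn.LayerNorm(8),
    "ff": torch.nn.Linear(8, 8),
})
print(list(collect_upcast_targets(toy)))  # ['ff'] -- proj_in and norm1 are skipped
```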