@@ -77,7 +77,6 @@ def apply_layerwise_upcasting(
     skip_modules_pattern: Union[str, Tuple[str, ...]] = "default",
     skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
     non_blocking: bool = False,
-    _prefix: str = "",
 ) -> None:
     r"""
     Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
@@ -97,7 +96,7 @@ def apply_layerwise_upcasting(
     ...     transformer,
     ...     storage_dtype=torch.float8_e4m3fn,
     ...     compute_dtype=torch.bfloat16,
-    ...     skip_modules_pattern=["patch_embed", "norm"],
+    ...     skip_modules_pattern=["patch_embed", "norm", "proj_out"],
     ...     non_blocking=True,
     ... )
     ```
@@ -112,7 +111,9 @@ def apply_layerwise_upcasting(
             The dtype to cast the module to during the forward pass for computation.
         skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`):
             A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set
-            to `"default"`, the default patterns are used.
+            to `"default"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
+            alongside `skip_modules_classes` being `None`, the layerwise upcasting is applied directly to the module
+            instead of its internal submodules.
         skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
             A list of module classes to skip during the layerwise upcasting process.
         non_blocking (`bool`, defaults to `False`):
@@ -125,6 +126,25 @@ def apply_layerwise_upcasting(
         apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
 
+    _apply_layerwise_upcasting(
+        module,
+        storage_dtype,
+        compute_dtype,
+        skip_modules_pattern,
+        skip_modules_classes,
+        non_blocking,
+    )
+
+
+def _apply_layerwise_upcasting(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: Optional[Tuple[str, ...]] = None,
+    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
+    non_blocking: bool = False,
+    _prefix: str = "",
+) -> None:
     should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
         skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
     )
@@ -139,7 +159,7 @@ def apply_layerwise_upcasting(
 
     for name, submodule in module.named_children():
         layer_name = f"{_prefix}.{name}" if _prefix else name
-        apply_layerwise_upcasting(
+        _apply_layerwise_upcasting(
             submodule,
             storage_dtype,
             compute_dtype,
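For context on the refactor above: the public `apply_layerwise_upcasting` no longer carries `_prefix` in its signature; the private `_apply_layerwise_upcasting` now accumulates the fully qualified submodule name as it recurses over `named_children()` and skips any submodule whose name matches one of the patterns or whose class is in `skip_modules_classes`. Below is a minimal, self-contained sketch of that traversal only; `_walk_and_report` and `TinyBlock` are hypothetical names, and the actual hook attachment (`apply_layerwise_upcasting_hook`) is elided.

```python
import re
from typing import Optional, Tuple, Type

import torch


def _walk_and_report(
    module: torch.nn.Module,
    skip_modules_pattern: Optional[Tuple[str, ...]] = None,
    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
    _prefix: str = "",
) -> None:
    # Mirror of the skip test in the diff: skip by class, or by regex match
    # against the fully qualified submodule name accumulated in `_prefix`.
    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
        skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
    )
    if should_skip:
        print(f"skipped: {_prefix or '<root>'}")
        return

    # In the real helper an upcasting hook would be attached for eligible layers;
    # this sketch only reports the decision and keeps descending.
    print(f"would process: {_prefix or '<root>'}")
    for name, submodule in module.named_children():
        layer_name = f"{_prefix}.{name}" if _prefix else name
        _walk_and_report(submodule, skip_modules_pattern, skip_modules_classes, layer_name)


class TinyBlock(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj_out = torch.nn.Linear(4, 4)
        self.norm = torch.nn.LayerNorm(4)


# "norm" and "proj_out" match the qualified submodule names, so both children
# are skipped while the root module itself would still be processed.
_walk_and_report(TinyBlock(), skip_modules_pattern=("norm", "proj_out"))
```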