Skip to content

Commit 5c29b53

Browse files
[Refactor] refactor loaders.py to make it cleaner and leaner. (huggingface#5771)
* refactor loaders.py to make it cleaner and leaner. * refactor loaders init * inits. * textual inversion to the init. * inits. * remove certain modules from the main init. * AttnProcsLayers * fix imports * avoid circular import. * fix circular import pt 2. * address PR comments * imports * fix: imports. * remove from main init for avoiding circular deps. * remove spurious deps. * fix-copies. * fix imports. * more debug * more debug * Apply suggestions from code review * Apply suggestions from code review --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 5d83951 commit 5c29b53

File tree

9 files changed

+3607
-3384
lines changed

9 files changed

+3607
-3384
lines changed

__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
"VQModel",
9595
]
9696
)
97+
9798
_import_structure["optimization"] = [
9899
"get_constant_schedule",
99100
"get_constant_schedule_with_warmup",
@@ -103,7 +104,6 @@
103104
"get_polynomial_decay_schedule_with_warmup",
104105
"get_scheduler",
105106
]
106-
107107
_import_structure["pipelines"].extend(
108108
[
109109
"AudioPipelineOutput",

loaders.py

-3,382
This file was deleted.

loaders/__init__.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from typing import TYPE_CHECKING
2+
3+
from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
4+
from ..utils.import_utils import is_torch_available, is_transformers_available
5+
6+
7+
def text_encoder_lora_state_dict(text_encoder):
    """Deprecated: gather the LoRA parameters of a text encoder into a flat state dict.

    Walks every self-attention module returned by ``text_encoder_attn_modules`` and
    collects the ``lora_linear_layer`` parameters of its ``q_proj``/``k_proj``/
    ``v_proj``/``out_proj`` projections, keyed by their full module path.

    This shim only exists for backward compatibility; the implementation moved to
    ``diffusers.models.lora`` and this alias is scheduled for removal in 0.27.0.

    Args:
        text_encoder: a ``transformers`` CLIP text encoder carrying LoRA-wrapped
            attention projections (each projection must expose ``lora_linear_layer``).

    Returns:
        dict: mapping ``"<attn path>.<proj>.lora_linear_layer.<param>"`` -> tensor.
    """
    deprecate(
        # Fix: the original warning named a non-existent
        # "text_encoder_load_state_dict"; the deprecated symbol is
        # `text_encoder_lora_state_dict`.
        "text_encoder_lora_state_dict in `models`",
        "0.27.0",
        "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
    )
    state_dict = {}

    for name, module in text_encoder_attn_modules(text_encoder):
        # All four attention projections carry an identical `lora_linear_layer`;
        # collect them uniformly instead of one copy-pasted loop per projection.
        # Iteration order (q, k, v, out per module) matches the original code.
        for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            proj = getattr(module, proj_name)
            for k, v in proj.lora_linear_layer.state_dict().items():
                state_dict[f"{name}.{proj_name}.lora_linear_layer.{k}"] = v

    return state_dict
29+
30+
31+
if is_transformers_available():

    def text_encoder_attn_modules(text_encoder):
        """Deprecated: return ``(name, module)`` pairs for a CLIP text encoder's self-attention blocks.

        Moved to ``diffusers.models.lora``; this alias is scheduled for removal
        in 0.27.0.

        Args:
            text_encoder: a ``transformers`` ``CLIPTextModel`` or
                ``CLIPTextModelWithProjection`` instance.

        Returns:
            list[tuple[str, torch.nn.Module]]: one entry per encoder layer,
            named ``"text_model.encoder.layers.{i}.self_attn"``.

        Raises:
            ValueError: if ``text_encoder`` is not a supported CLIP text model.
        """
        deprecate(
            "text_encoder_attn_modules in `models`",
            "0.27.0",
            # Fix: the original message pointed users at
            # `text_encoder_lora_state_dict`, but the symbol that moved is
            # `text_encoder_attn_modules`.
            "`text_encoder_attn_modules` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_attn_modules`.",
        )
        # Imported lazily so merely importing this module does not require
        # `transformers` beyond the availability check above.
        from transformers import CLIPTextModel, CLIPTextModelWithProjection

        attn_modules = []

        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
                name = f"text_model.encoder.layers.{i}.self_attn"
                attn_modules.append((name, layer.self_attn))
        else:
            raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

        return attn_modules
52+
53+
54+
# Lazy-import machinery: `_import_structure` maps submodule name -> list of
# public names. `_LazyModule` (below) consumes it so that importing
# `diffusers.loaders` stays cheap unless eager imports are requested.
_import_structure = {}

if is_torch_available():
    _import_structure.update(
        {
            "single_file": ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"],
            "unet": ["UNet2DConditionLoadersMixin"],
            "utils": ["AttnProcsLayers"],
        }
    )

    if is_transformers_available():
        # These loaders additionally require `transformers` at runtime.
        _import_structure["single_file"].append("FromSingleFileMixin")
        _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
        _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: taken by static type checkers and when the user opts into
    # slow imports; mirrors `_import_structure` exactly.
    if is_torch_available():
        from ..models.lora import text_encoder_lora_state_dict
        from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
        from .unet import UNet2DConditionLoadersMixin
        from .utils import AttnProcsLayers

        if is_transformers_available():
            from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
            from .single_file import FromSingleFileMixin
            from .textual_inversion import TextualInversionLoaderMixin
else:
    import sys

    # Replace this module object with a lazy proxy that materializes
    # attributes on first access according to `_import_structure`.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

0 commit comments

Comments
 (0)