|
| 1 | +from typing import TYPE_CHECKING |
| 2 | + |
| 3 | +from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate |
| 4 | +from ..utils.import_utils import is_torch_available, is_transformers_available |
| 5 | + |
| 6 | + |
def text_encoder_lora_state_dict(text_encoder):
    """Deprecated shim: collect the LoRA linear-layer weights of a text encoder.

    Returns a flat state dict mapping
    ``"{attn_module_name}.{proj}.lora_linear_layer.{param}"`` to the parameter
    tensor, for each of the q/k/v/out projections of every attention module.
    Moved to `diffusers.models.lora`; import it from there instead.
    """
    deprecate(
        "text_encoder_load_state_dict in `models`",
        "0.27.0",
        "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
    )
    state_dict = {}

    for name, module in text_encoder_attn_modules(text_encoder):
        # One loop over the four projections instead of four copy-pasted loops;
        # key order (q, k, v, out per module) matches the original exactly.
        for proj in ("q_proj", "k_proj", "v_proj", "out_proj"):
            lora_layer = getattr(module, proj).lora_linear_layer
            for k, v in lora_layer.state_dict().items():
                state_dict[f"{name}.{proj}.lora_linear_layer.{k}"] = v

    return state_dict
| 29 | + |
| 30 | + |
if is_transformers_available():

    def text_encoder_attn_modules(text_encoder):
        """Deprecated: return ``(name, module)`` pairs for every self-attention
        module of a CLIP text encoder. Moved to `diffusers.models.lora`."""
        deprecate(
            "text_encoder_attn_modules in `models`",
            "0.27.0",
            "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
        )
        from transformers import CLIPTextModel, CLIPTextModelWithProjection

        # Only CLIP-style text encoders are supported.
        if not isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
            raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

        return [
            (f"text_model.encoder.layers.{idx}.self_attn", block.self_attn)
            for idx, block in enumerate(text_encoder.text_model.encoder.layers)
        ]
| 52 | + |
| 53 | + |
# Map of submodule name -> public names it exposes, consumed by _LazyModule.
_import_structure = {}

if is_torch_available():
    _import_structure.update(
        {
            "single_file": ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"],
            "unet": ["UNet2DConditionLoadersMixin"],
            "utils": ["AttnProcsLayers"],
        }
    )

    if is_transformers_available():
        # Loader mixins that additionally require `transformers` at runtime.
        _import_structure["single_file"].append("FromSingleFileMixin")
        _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
        _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
| 65 | + |
| 66 | + |
# Eager imports: taken by static type checkers, and at runtime when lazy
# importing is disabled via DIFFUSERS_SLOW_IMPORT.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    if is_torch_available():
        from ..models.lora import text_encoder_lora_state_dict
        from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
        from .unet import UNet2DConditionLoadersMixin
        from .utils import AttnProcsLayers

        # These loaders additionally depend on `transformers`.
        if is_transformers_available():
            from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
            from .single_file import FromSingleFileMixin
            from .textual_inversion import TextualInversionLoaderMixin
else:
    import sys

    # Replace this module with a lazy proxy so the heavy submodules listed in
    # _import_structure are only imported on first attribute access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)