Skip to content

Commit 32d6492

Browse files
authored
[Core] Tear apart from_pretrained() of DiffusionPipeline (#8967)
* break from_pretrained part i. * part ii. * init_kwargs * remove _fetch_init_kwargs * type annotation * dtyle * switch to _check_and_update_init_kwargs_for_missing_modules. * remove _check_and_update_init_kwargs_for_missing_modules. * use pipeline_loading_kwargs. * remove _determine_current_device_map. * remove _filter_null_components. * device_map fix. * fix _update_init_kwargs_with_connected_pipeline. * better handle custom pipeline. * explain _maybe_raise_warning_for_inpainting. * add example for model variant. * fix
1 parent 43f1090 commit 32d6492

File tree

2 files changed

+126
-92
lines changed

2 files changed

+126
-92
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

+91-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import Any, Dict, List, Optional, Union
2323

2424
import torch
25-
from huggingface_hub import model_info
25+
from huggingface_hub import ModelCard, model_info
2626
from huggingface_hub.utils import validate_hf_hub_args
2727
from packaging import version
2828

@@ -33,6 +33,7 @@
3333
ONNX_WEIGHTS_NAME,
3434
SAFETENSORS_WEIGHTS_NAME,
3535
WEIGHTS_NAME,
36+
deprecate,
3637
get_class_from_dynamic_module,
3738
is_accelerate_available,
3839
is_peft_available,
@@ -746,3 +747,92 @@ def _fetch_class_library_tuple(module):
746747
class_name = not_compiled_module.__class__.__name__
747748

748749
return (library, class_name)
750+
751+
752+
def _identify_model_variants(folder: str, variant: str, config: dict) -> dict:
753+
model_variants = {}
754+
if variant is not None:
755+
for folder in os.listdir(folder):
756+
folder_path = os.path.join(folder, folder)
757+
is_folder = os.path.isdir(folder_path) and folder in config
758+
variant_exists = is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
759+
if variant_exists:
760+
model_variants[folder] = variant
761+
return model_variants
762+
763+
764+
def _resolve_custom_pipeline_and_cls(folder, config, custom_pipeline):
765+
custom_class_name = None
766+
if os.path.isfile(os.path.join(folder, f"{custom_pipeline}.py")):
767+
custom_pipeline = os.path.join(folder, f"{custom_pipeline}.py")
768+
elif isinstance(config["_class_name"], (list, tuple)) and os.path.isfile(
769+
os.path.join(folder, f"{config['_class_name'][0]}.py")
770+
):
771+
custom_pipeline = os.path.join(folder, f"{config['_class_name'][0]}.py")
772+
custom_class_name = config["_class_name"][1]
773+
774+
return custom_pipeline, custom_class_name
775+
776+
777+
def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or_path: str, config: dict):
    """Emit a deprecation warning for legacy Stable Diffusion inpainting checkpoints.

    Triggers when ``StableDiffusionInpaintPipeline`` is requested for a checkpoint saved
    with ``diffusers`` <= 0.5.1; such checkpoints belong to the deprecated
    ``StableDiffusionInpaintPipelineLegacy`` format.

    Args:
        pipeline_class: The pipeline class resolved so far.
        pretrained_model_name_or_path: Repo id / path, interpolated into the warning text.
        config: Pipeline config; ``config["_diffusers_version"]`` decides legacy status.

    Returns:
        None — this helper only warns. NOTE(review): the inlined original also swapped
        ``pipeline_class`` to the legacy class; a local rebinding has no effect from a
        helper, so the caller must perform that swap itself — confirm call sites.
    """
    if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
        version.parse(config["_diffusers_version"]).base_version
    ) <= version.parse("0.5.1"):
        # Imported lazily: only needed on the deprecated path.
        from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

        deprecation_message = (
            "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
            f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
            " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
            " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
            f" checkpoint {pretrained_model_name_or_path} to the format of"
            " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
            # BUG FIX: this segment lacked the `f` prefix, so the literal text
            # "{StableDiffusionInpaintPipelineLegacy}" leaked into the user-facing warning.
            f" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
        )
        deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
795+
796+
797+
def _update_init_kwargs_with_connected_pipeline(
    init_kwargs: dict, passed_pipe_kwargs: dict, passed_class_objs: dict, folder: str, **pipeline_loading_kwargs
) -> dict:
    """Load the connected pipelines declared in the repo's model card and merge their
    components into ``init_kwargs`` under ``<prefix>_<component_name>`` keys.

    Args:
        init_kwargs: Component dict being assembled by ``DiffusionPipeline.from_pretrained``;
            updated in place and also returned.
        passed_pipe_kwargs: Pipeline kwargs the user passed, possibly namespaced with a
            connected-pipe prefix (e.g. ``prior_torch_dtype``).
        passed_class_objs: Pre-instantiated component objects the user passed, possibly
            prefix-namespaced (e.g. ``prior_text_encoder``).
        folder: Cached repo folder; must contain a ``README.md`` model card.
        **pipeline_loading_kwargs: Loading kwargs forwarded to each connected pipeline's
            ``from_pretrained`` call (scheduler-related keys are dropped, see below).

    Returns:
        The updated ``init_kwargs`` dict.
    """
    # Imported lazily to avoid a circular import (pipeline_utils imports this module).
    from .pipeline_utils import DiffusionPipeline

    # The model card metadata lists one connected repo id per prefix in
    # CONNECTED_PIPES_KEYS; absent prefixes resolve to None and are skipped later.
    modelcard = ModelCard.load(os.path.join(folder, "README.md"))
    connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}

    # We don't pass any `scheduler`-related argument, to match the existing logic:
    # https://github.com/huggingface/diffusers/blob/867e0c919e1aa7ef8b03c8eb1460f4f875a683ae/src/diffusers/pipelines/pipeline_utils.py#L906C13-L925C14
    pipeline_loading_kwargs_cp = pipeline_loading_kwargs.copy()
    if pipeline_loading_kwargs_cp is not None and len(pipeline_loading_kwargs_cp) >= 1:
        # Iterate the original dict while popping from the copy so the mapping
        # is not mutated during iteration.
        for k in pipeline_loading_kwargs:
            if "scheduler" in k:
                _ = pipeline_loading_kwargs_cp.pop(k)

    def get_connected_passed_kwargs(prefix):
        # Strip the `<prefix>_` namespace from user-passed objects/kwargs so they can
        # be forwarded directly to the connected pipeline's `from_pretrained`.
        connected_passed_class_obj = {
            k.replace(f"{prefix}_", ""): w for k, w in passed_class_objs.items() if k.split("_")[0] == prefix
        }
        connected_passed_pipe_kwargs = {
            k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
        }

        connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
        return connected_passed_kwargs

    # Recursively load each declared connected pipeline (may hit the network/cache).
    connected_pipes = {
        prefix: DiffusionPipeline.from_pretrained(
            repo_id, **pipeline_loading_kwargs_cp, **get_connected_passed_kwargs(prefix)
        )
        for prefix, repo_id in connected_pipes.items()
        if repo_id is not None
    }

    for prefix, connected_pipe in connected_pipes.items():
        # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
        init_kwargs.update(
            {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
        )

    return init_kwargs

src/diffusers/pipelines/pipeline_utils.py

+35-91
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@
7575
_get_custom_pipeline_class,
7676
_get_final_device_map,
7777
_get_pipeline_class,
78+
_identify_model_variants,
79+
_maybe_raise_warning_for_inpainting,
80+
_resolve_custom_pipeline_and_cls,
7881
_unwrap_model,
82+
_update_init_kwargs_with_connected_pipeline,
7983
is_safetensors_compatible,
8084
load_sub_model,
8185
maybe_raise_or_warn,
@@ -622,6 +626,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
622626
>>> pipeline.scheduler = scheduler
623627
```
624628
"""
629+
# Copy the kwargs to re-use during loading connected pipeline.
630+
kwargs_copied = kwargs.copy()
631+
625632
cache_dir = kwargs.pop("cache_dir", None)
626633
force_download = kwargs.pop("force_download", False)
627634
proxies = kwargs.pop("proxies", None)
@@ -722,33 +729,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
722729
config_dict.pop("_ignore_files", None)
723730

724731
# 2. Define which model components should load variants
725-
# We retrieve the information by matching whether variant
726-
# model checkpoints exist in the subfolders
727-
model_variants = {}
728-
if variant is not None:
729-
for folder in os.listdir(cached_folder):
730-
folder_path = os.path.join(cached_folder, folder)
731-
is_folder = os.path.isdir(folder_path) and folder in config_dict
732-
variant_exists = is_folder and any(
733-
p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
734-
)
735-
if variant_exists:
736-
model_variants[folder] = variant
732+
# We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
733+
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
734+
# with variant being `"fp16"`.
735+
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
737736

738737
# 3. Load the pipeline class, if using custom module then load it from the hub
739738
# if we load from explicit class, let's use it
740-
custom_class_name = None
741-
if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
742-
custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
743-
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
744-
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
745-
):
746-
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
747-
custom_class_name = config_dict["_class_name"][1]
748-
739+
custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
740+
folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
741+
)
749742
pipeline_class = _get_pipeline_class(
750743
cls,
751-
config_dict,
744+
config=config_dict,
752745
load_connected_pipeline=load_connected_pipeline,
753746
custom_pipeline=custom_pipeline,
754747
class_name=custom_class_name,
@@ -760,23 +753,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
760753
raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
761754

762755
# DEPRECATED: To be removed in 1.0.0
763-
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
764-
version.parse(config_dict["_diffusers_version"]).base_version
765-
) <= version.parse("0.5.1"):
766-
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
767-
768-
pipeline_class = StableDiffusionInpaintPipelineLegacy
769-
770-
deprecation_message = (
771-
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
772-
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
773-
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
774-
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
775-
f" checkpoint {pretrained_model_name_or_path} to the format of"
776-
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
777-
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
778-
)
779-
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
756+
# we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
757+
# when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
758+
_maybe_raise_warning_for_inpainting(
759+
pipeline_class=pipeline_class,
760+
pretrained_model_name_or_path=pretrained_model_name_or_path,
761+
config=config_dict,
762+
)
780763

781764
# 4. Define expected modules given pipeline signature
782765
# and define non-None initialized modules (=`init_kwargs`)
@@ -787,7 +770,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
787770
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
788771
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
789772
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
790-
791773
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
792774

793775
# define init kwargs and make sure that optional component modules are filtered out
@@ -847,22 +829,23 @@ def load_module(name, value):
847829
# 7. Load each module in the pipeline
848830
current_device_map = None
849831
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
832+
# 7.1 device_map shenanigans
850833
if final_device_map is not None and len(final_device_map) > 0:
851834
component_device = final_device_map.get(name, None)
852835
if component_device is not None:
853836
current_device_map = {"": component_device}
854837
else:
855838
current_device_map = None
856839

857-
# 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
840+
# 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
858841
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
859842

860-
# 7.2 Define all importable classes
843+
# 7.3 Define all importable classes
861844
is_pipeline_module = hasattr(pipelines, library_name)
862845
importable_classes = ALL_IMPORTABLE_CLASSES
863846
loaded_sub_model = None
864847

865-
# 7.3 Use passed sub model or load class_name from library_name
848+
# 7.4 Use passed sub model or load class_name from library_name
866849
if name in passed_class_obj:
867850
# if the model is in a pipeline module, then we load it from the pipeline
868851
# check that passed_class_obj has correct parent class
@@ -900,56 +883,17 @@ def load_module(name, value):
900883

901884
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
902885

886+
# 8. Handle connected pipelines.
903887
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
904-
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
905-
connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
906-
load_kwargs = {
907-
"cache_dir": cache_dir,
908-
"force_download": force_download,
909-
"proxies": proxies,
910-
"local_files_only": local_files_only,
911-
"token": token,
912-
"revision": revision,
913-
"torch_dtype": torch_dtype,
914-
"custom_pipeline": custom_pipeline,
915-
"custom_revision": custom_revision,
916-
"provider": provider,
917-
"sess_options": sess_options,
918-
"device_map": device_map,
919-
"max_memory": max_memory,
920-
"offload_folder": offload_folder,
921-
"offload_state_dict": offload_state_dict,
922-
"low_cpu_mem_usage": low_cpu_mem_usage,
923-
"variant": variant,
924-
"use_safetensors": use_safetensors,
925-
}
926-
927-
def get_connected_passed_kwargs(prefix):
928-
connected_passed_class_obj = {
929-
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
930-
}
931-
connected_passed_pipe_kwargs = {
932-
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
933-
}
934-
935-
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
936-
return connected_passed_kwargs
937-
938-
connected_pipes = {
939-
prefix: DiffusionPipeline.from_pretrained(
940-
repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
941-
)
942-
for prefix, repo_id in connected_pipes.items()
943-
if repo_id is not None
944-
}
945-
946-
for prefix, connected_pipe in connected_pipes.items():
947-
# add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
948-
init_kwargs.update(
949-
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
950-
)
888+
init_kwargs = _update_init_kwargs_with_connected_pipeline(
889+
init_kwargs=init_kwargs,
890+
passed_pipe_kwargs=passed_pipe_kwargs,
891+
passed_class_objs=passed_class_obj,
892+
folder=cached_folder,
893+
**kwargs_copied,
894+
)
951895

952-
# 8. Potentially add passed objects if expected
896+
# 9. Potentially add passed objects if expected
953897
missing_modules = set(expected_modules) - set(init_kwargs.keys())
954898
passed_modules = list(passed_class_obj.keys())
955899
optional_modules = pipeline_class._optional_components

0 commit comments

Comments
 (0)