     _get_custom_pipeline_class,
     _get_final_device_map,
     _get_pipeline_class,
+    _identify_model_variants,
+    _maybe_raise_warning_for_inpainting,
+    _resolve_custom_pipeline_and_cls,
     _unwrap_model,
+    _update_init_kwargs_with_connected_pipeline,
     is_safetensors_compatible,
     load_sub_model,
     maybe_raise_or_warn,
@@ -622,6 +626,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         >>> pipeline.scheduler = scheduler
         ```
         """
+        # Copy the kwargs to re-use during loading connected pipeline.
+        kwargs_copied = kwargs.copy()
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -722,33 +729,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         config_dict.pop("_ignore_files", None)

         # 2. Define which model components should load variants
-        # We retrieve the information by matching whether variant
-        # model checkpoints exist in the subfolders
-        model_variants = {}
-        if variant is not None:
-            for folder in os.listdir(cached_folder):
-                folder_path = os.path.join(cached_folder, folder)
-                is_folder = os.path.isdir(folder_path) and folder in config_dict
-                variant_exists = is_folder and any(
-                    p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
-                )
-                if variant_exists:
-                    model_variants[folder] = variant
+        # We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
+        # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
+        # with variant being `"fp16"`.
+        model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)

         # 3. Load the pipeline class, if using custom module then load it from the hub
         # if we load from explicit class, let's use it
-        custom_class_name = None
-        if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
-            custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
-        elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
-            os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
-        ):
-            custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
-            custom_class_name = config_dict["_class_name"][1]
-
+        custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
+            folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
+        )
         pipeline_class = _get_pipeline_class(
             cls,
-            config_dict,
+            config=config_dict,
             load_connected_pipeline=load_connected_pipeline,
             custom_pipeline=custom_pipeline,
             class_name=custom_class_name,
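For reference, here is a minimal sketch of what the new `_identify_model_variants` helper might look like, reconstructed from the inline logic removed in the hunk above. The exact signature and the implementation that actually lives in `pipeline_loading_utils.py` are assumptions, not taken from this diff:

```python
import os


def _identify_model_variants(folder: str, variant: str, config: dict) -> dict:
    """Map each component subfolder to `variant` when a matching variant checkpoint exists in it."""
    model_variants = {}
    if variant is not None:
        for sub_folder in os.listdir(folder):
            folder_path = os.path.join(folder, sub_folder)
            # only consider subfolders that correspond to components listed in the pipeline config (model_index.json)
            is_folder = os.path.isdir(folder_path) and sub_folder in config
            # e.g. "diffusion_pytorch_model.fp16.safetensors".split(".")[1] == "fp16"
            variant_exists = is_folder and any(
                p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
            )
            if variant_exists:
                model_variants[sub_folder] = variant
    return model_variants
```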
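Likewise, a hedged sketch of `_resolve_custom_pipeline_and_cls`, mirroring the removed inline code; treat it as illustrative rather than the library's actual implementation:

```python
import os
from typing import Optional, Tuple


def _resolve_custom_pipeline_and_cls(folder: str, config: dict, custom_pipeline) -> Tuple[Optional[str], Optional[str]]:
    custom_class_name = None
    # A `<custom_pipeline>.py` file shipped inside the cached repo takes precedence.
    if os.path.isfile(os.path.join(folder, f"{custom_pipeline}.py")):
        custom_pipeline = os.path.join(folder, f"{custom_pipeline}.py")
    # Otherwise `_class_name` may be a (module, class) pair pointing at a local module in the repo.
    elif isinstance(config["_class_name"], (list, tuple)) and os.path.isfile(
        os.path.join(folder, f"{config['_class_name'][0]}.py")
    ):
        custom_pipeline = os.path.join(folder, f"{config['_class_name'][0]}.py")
        custom_class_name = config["_class_name"][1]
    return custom_pipeline, custom_class_name
```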
@@ -760,23 +753,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")

         # DEPRECATED: To be removed in 1.0.0
-        if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
-            version.parse(config_dict["_diffusers_version"]).base_version
-        ) <= version.parse("0.5.1"):
-            from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
-
-            pipeline_class = StableDiffusionInpaintPipelineLegacy
-
-            deprecation_message = (
-                "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
-                f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
-                " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
-                " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
-                f" checkpoint {pretrained_model_name_or_path} to the format of"
-                " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
-                " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
-            )
-            deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
+        # we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
+        # when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
+        _maybe_raise_warning_for_inpainting(
+            pipeline_class=pipeline_class,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            config=config_dict,
+        )

         # 4. Define expected modules given pipeline signature
         # and define non-None initialized modules (=`init_kwargs`)
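Since the new call site discards any return value, `_maybe_raise_warning_for_inpainting` presumably only emits the deprecation warning, while the switch to the legacy class happens elsewhere (e.g. inside `_get_pipeline_class`). A rough sketch under that assumption, reusing the condition from the removed block and a shortened version of its message:

```python
from packaging import version

from diffusers.utils import deprecate


def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or_path, config):
    # Only legacy Stable Diffusion inpainting checkpoints (saved with diffusers <= 0.5.1) are affected.
    if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
        version.parse(config["_diffusers_version"]).base_version
    ) <= version.parse("0.5.1"):
        from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

        deprecation_message = (
            "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
            f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
            " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
            f" checkpoint instead or adapting your checkpoint {pretrained_model_name_or_path} to"
            " https://huggingface.co/runwayml/stable-diffusion-inpainting."
        )
        deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
```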
@@ -787,7 +770,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
         passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
         passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
-
         init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

         # define init kwargs and make sure that optional component modules are filtered out
@@ -847,22 +829,23 @@ def load_module(name, value):
         # 7. Load each module in the pipeline
         current_device_map = None
         for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
+            # 7.1 device_map shenanigans
             if final_device_map is not None and len(final_device_map) > 0:
                 component_device = final_device_map.get(name, None)
                 if component_device is not None:
                     current_device_map = {"": component_device}
                 else:
                     current_device_map = None

-            # 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+            # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
             class_name = class_name[4:] if class_name.startswith("Flax") else class_name

-            # 7.2 Define all importable classes
+            # 7.3 Define all importable classes
             is_pipeline_module = hasattr(pipelines, library_name)
             importable_classes = ALL_IMPORTABLE_CLASSES
             loaded_sub_model = None

-            # 7.3 Use passed sub model or load class_name from library_name
+            # 7.4 Use passed sub model or load class_name from library_name
             if name in passed_class_obj:
                 # if the model is in a pipeline module, then we load it from the pipeline
                 # check that passed_class_obj has correct parent class
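To make the `device_map` handling in step 7.1 concrete: `_get_final_device_map` yields a dict keyed by component name, and each matching sub-model is then loaded with a whole-module map of the form `{"": device}` (an empty-string key means "place the entire module on this device" in the `accelerate` convention). A hypothetical example of how `current_device_map` evolves across the loop:

```python
# Hypothetical final_device_map, for illustration only.
final_device_map = {"unet": "cuda:0", "text_encoder": "cpu"}

for name in ["unet", "text_encoder", "scheduler"]:
    component_device = final_device_map.get(name, None)
    # unet -> {"": "cuda:0"}, text_encoder -> {"": "cpu"}, scheduler -> None (no placement constraint)
    current_device_map = {"": component_device} if component_device is not None else None
    print(name, current_device_map)
```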
@@ -900,56 +883,17 @@ def load_module(name, value):

         init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

+        # 8. Handle connected pipelines.
         if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
-            modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
-            connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
-            load_kwargs = {
-                "cache_dir": cache_dir,
-                "force_download": force_download,
-                "proxies": proxies,
-                "local_files_only": local_files_only,
-                "token": token,
-                "revision": revision,
-                "torch_dtype": torch_dtype,
-                "custom_pipeline": custom_pipeline,
-                "custom_revision": custom_revision,
-                "provider": provider,
-                "sess_options": sess_options,
-                "device_map": device_map,
-                "max_memory": max_memory,
-                "offload_folder": offload_folder,
-                "offload_state_dict": offload_state_dict,
-                "low_cpu_mem_usage": low_cpu_mem_usage,
-                "variant": variant,
-                "use_safetensors": use_safetensors,
-            }
-
-            def get_connected_passed_kwargs(prefix):
-                connected_passed_class_obj = {
-                    k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
-                }
-                connected_passed_pipe_kwargs = {
-                    k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
-                }
-
-                connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
-                return connected_passed_kwargs
-
-            connected_pipes = {
-                prefix: DiffusionPipeline.from_pretrained(
-                    repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
-                )
-                for prefix, repo_id in connected_pipes.items()
-                if repo_id is not None
-            }
-
-            for prefix, connected_pipe in connected_pipes.items():
-                # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
-                init_kwargs.update(
-                    {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
-                )
+            init_kwargs = _update_init_kwargs_with_connected_pipeline(
+                init_kwargs=init_kwargs,
+                passed_pipe_kwargs=passed_pipe_kwargs,
+                passed_class_objs=passed_class_obj,
+                folder=cached_folder,
+                **kwargs_copied,
+            )

-        # 8. Potentially add passed objects if expected
+        # 9. Potentially add passed objects if expected
         missing_modules = set(expected_modules) - set(init_kwargs.keys())
         passed_modules = list(passed_class_obj.keys())
         optional_modules = pipeline_class._optional_components
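And a condensed sketch of what `_update_init_kwargs_with_connected_pipeline` might do, reconstructed from the block removed above. Note that the real helper is fed the copied user kwargs (`**kwargs_copied`) rather than the individual already-popped loading options, so the signature, the `CONNECTED_PIPES_KEYS` import location, and the internal names here are assumptions:

```python
import os

from huggingface_hub import ModelCard


def _update_init_kwargs_with_connected_pipeline(
    init_kwargs: dict, passed_pipe_kwargs: dict, passed_class_objs: dict, folder: str, **pipeline_loading_kwargs
) -> dict:
    from diffusers import DiffusionPipeline
    from diffusers.pipelines.pipeline_loading_utils import CONNECTED_PIPES_KEYS  # assumed location

    # The model card lists connected repos (e.g. a prior pipeline) under known keys.
    modelcard = ModelCard.load(os.path.join(folder, "README.md"))
    connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}

    def get_connected_passed_kwargs(prefix):
        # User-passed components/kwargs prefixed with e.g. "prior_" are forwarded without the prefix.
        connected_passed_class_obj = {
            k.replace(f"{prefix}_", ""): w for k, w in passed_class_objs.items() if k.split("_")[0] == prefix
        }
        connected_passed_pipe_kwargs = {
            k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
        }
        return {**connected_passed_class_obj, **connected_passed_pipe_kwargs}

    # In practice only loading options (cache_dir, torch_dtype, variant, ...) should be forwarded here.
    loaded = {
        prefix: DiffusionPipeline.from_pretrained(
            repo_id, **pipeline_loading_kwargs.copy(), **get_connected_passed_kwargs(prefix)
        )
        for prefix, repo_id in connected_pipes.items()
        if repo_id is not None
    }

    # Expose connected components as "<prefix>_<component_name>", e.g. "prior_text_encoder".
    for prefix, connected_pipe in loaded.items():
        init_kwargs.update(
            {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
        )
    return init_kwargs
```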