Skip to content

Commit b4be422

Browse files
Terikslinoytsaban
andauthored
Kolors additional pipelines, community contrib (#11372)
* Kolors additional pipelines, community contrib --------- Co-authored-by: Teriks <[email protected]> Co-authored-by: Linoy Tsaban <[email protected]>
1 parent a4f9c3c commit b4be422

8 files changed

+6517
-6
lines changed

examples/community/pipeline_controlnet_xl_kolors.py

+1,355
Large diffs are not rendered by default.

examples/community/pipeline_controlnet_xl_kolors_img2img.py

+1,557
Large diffs are not rendered by default.

examples/community/pipeline_controlnet_xl_kolors_inpaint.py

+1,871
Large diffs are not rendered by default.

examples/community/pipeline_kolors_inpainting.py

+1,728
Large diffs are not rendered by default.

src/diffusers/loaders/lora_conversion_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
433433
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
434434
if not is_sparse:
435435
# down_weight is copied to each split
436-
ait_sd.update({k: down_weight for k in ait_down_keys})
436+
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
437437

438438
# up_weight is split to each split
439439
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -923,7 +923,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
923923
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
924924

925925
# down_weight is copied to each split
926-
ait_sd.update({k: down_weight for k in ait_down_keys})
926+
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
927927

928928
# up_weight is split to each split
929929
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416

src/diffusers/pipelines/pipeline_flax_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def load_module(name, value):
469469
class_obj = import_flax_or_no_model(pipeline_module, class_name)
470470

471471
importable_classes = ALL_IMPORTABLE_CLASSES
472-
class_candidates = {c: class_obj for c in importable_classes.keys()}
472+
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
473473
else:
474474
# else we just import it from the library.
475475
library = importlib.import_module(library_name)

src/diffusers/pipelines/pipeline_loading_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -341,13 +341,13 @@ def get_class_obj_and_candidates(
341341
pipeline_module = getattr(pipelines, library_name)
342342

343343
class_obj = getattr(pipeline_module, class_name)
344-
class_candidates = {c: class_obj for c in importable_classes.keys()}
344+
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
345345
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
346346
# load custom component
347347
class_obj = get_class_from_dynamic_module(
348348
component_folder, module_file=library_name + ".py", class_name=class_name
349349
)
350-
class_candidates = {c: class_obj for c in importable_classes.keys()}
350+
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
351351
else:
352352
# else we just import it from the library.
353353
library = importlib.import_module(library_name)

tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def test_save_load_optional_components(self):
205205
# set all optional components to None and update pipeline config accordingly
206206
for optional_component in pipe._optional_components:
207207
setattr(pipe, optional_component, None)
208-
pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components})
208+
pipe.register_modules(**dict.fromkeys(pipe._optional_components))
209209

210210
inputs = self.get_dummy_inputs(torch_device)
211211
output = pipe(**inputs)[0]

0 commit comments

Comments
 (0)