Skip to content

Commit a4f9c3c

Browse files
authored
[Feature] Added Xlab Controlnet support (#11249)
update
1 parent 4b60f4b commit a4f9c3c

File tree

1 file changed

+29
-22
lines changed

1 file changed

+29
-22
lines changed

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

+29-22
Original file line numberDiff line numberDiff line change
@@ -800,17 +800,20 @@ def __call__(
800800
)
801801
height, width = control_image.shape[-2:]
802802

803-
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
804-
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
805-
806-
height_control_image, width_control_image = control_image.shape[2:]
807-
control_image = self._pack_latents(
808-
control_image,
809-
batch_size * num_images_per_prompt,
810-
num_channels_latents,
811-
height_control_image,
812-
width_control_image,
813-
)
803+
# xlab controlnet has a input_hint_block and instantx controlnet does not
804+
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
805+
if self.controlnet.input_hint_block is None:
806+
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
807+
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
808+
809+
height_control_image, width_control_image = control_image.shape[2:]
810+
control_image = self._pack_latents(
811+
control_image,
812+
batch_size * num_images_per_prompt,
813+
num_channels_latents,
814+
height_control_image,
815+
width_control_image,
816+
)
814817

815818
if control_mode is not None:
816819
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
@@ -819,7 +822,9 @@ def __call__(
819822
elif isinstance(self.controlnet, FluxMultiControlNetModel):
820823
control_images = []
821824

822-
for control_image_ in control_image:
825+
# xlab controlnet has a input_hint_block and instantx controlnet does not
826+
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
827+
for i, control_image_ in enumerate(control_image):
823828
control_image_ = self.prepare_image(
824829
image=control_image_,
825830
width=width,
@@ -831,17 +836,18 @@ def __call__(
831836
)
832837
height, width = control_image_.shape[-2:]
833838

834-
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
835-
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
839+
if self.controlnet.nets[0].input_hint_block is None:
840+
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
841+
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
836842

837-
height_control_image, width_control_image = control_image_.shape[2:]
838-
control_image_ = self._pack_latents(
839-
control_image_,
840-
batch_size * num_images_per_prompt,
841-
num_channels_latents,
842-
height_control_image,
843-
width_control_image,
844-
)
843+
height_control_image, width_control_image = control_image_.shape[2:]
844+
control_image_ = self._pack_latents(
845+
control_image_,
846+
batch_size * num_images_per_prompt,
847+
num_channels_latents,
848+
height_control_image,
849+
width_control_image,
850+
)
845851

846852
control_images.append(control_image_)
847853

@@ -955,6 +961,7 @@ def __call__(
955961
img_ids=latent_image_ids,
956962
joint_attention_kwargs=self.joint_attention_kwargs,
957963
return_dict=False,
964+
controlnet_blocks_repeat=controlnet_blocks_repeat,
958965
)[0]
959966

960967
latents_dtype = latents.dtype

0 commit comments

Comments
 (0)