@@ -800,17 +800,20 @@ def __call__(
800
800
)
801
801
height , width = control_image .shape [- 2 :]
802
802
803
- control_image = retrieve_latents (self .vae .encode (control_image ), generator = generator )
804
- control_image = (control_image - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
805
-
806
- height_control_image , width_control_image = control_image .shape [2 :]
807
- control_image = self ._pack_latents (
808
- control_image ,
809
- batch_size * num_images_per_prompt ,
810
- num_channels_latents ,
811
- height_control_image ,
812
- width_control_image ,
813
- )
803
+ # xlab controlnet has a input_hint_block and instantx controlnet does not
804
+ controlnet_blocks_repeat = False if self .controlnet .input_hint_block is None else True
805
+ if self .controlnet .input_hint_block is None :
806
+ control_image = retrieve_latents (self .vae .encode (control_image ), generator = generator )
807
+ control_image = (control_image - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
808
+
809
+ height_control_image , width_control_image = control_image .shape [2 :]
810
+ control_image = self ._pack_latents (
811
+ control_image ,
812
+ batch_size * num_images_per_prompt ,
813
+ num_channels_latents ,
814
+ height_control_image ,
815
+ width_control_image ,
816
+ )
814
817
815
818
if control_mode is not None :
816
819
control_mode = torch .tensor (control_mode ).to (device , dtype = torch .long )
@@ -819,7 +822,9 @@ def __call__(
819
822
elif isinstance (self .controlnet , FluxMultiControlNetModel ):
820
823
control_images = []
821
824
822
- for control_image_ in control_image :
825
+ # xlab controlnet has a input_hint_block and instantx controlnet does not
826
+ controlnet_blocks_repeat = False if self .controlnet .nets [0 ].input_hint_block is None else True
827
+ for i , control_image_ in enumerate (control_image ):
823
828
control_image_ = self .prepare_image (
824
829
image = control_image_ ,
825
830
width = width ,
@@ -831,17 +836,18 @@ def __call__(
831
836
)
832
837
height , width = control_image_ .shape [- 2 :]
833
838
834
- control_image_ = retrieve_latents (self .vae .encode (control_image_ ), generator = generator )
835
- control_image_ = (control_image_ - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
839
+ if self .controlnet .nets [0 ].input_hint_block is None :
840
+ control_image_ = retrieve_latents (self .vae .encode (control_image_ ), generator = generator )
841
+ control_image_ = (control_image_ - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
836
842
837
- height_control_image , width_control_image = control_image_ .shape [2 :]
838
- control_image_ = self ._pack_latents (
839
- control_image_ ,
840
- batch_size * num_images_per_prompt ,
841
- num_channels_latents ,
842
- height_control_image ,
843
- width_control_image ,
844
- )
843
+ height_control_image , width_control_image = control_image_ .shape [2 :]
844
+ control_image_ = self ._pack_latents (
845
+ control_image_ ,
846
+ batch_size * num_images_per_prompt ,
847
+ num_channels_latents ,
848
+ height_control_image ,
849
+ width_control_image ,
850
+ )
845
851
846
852
control_images .append (control_image_ )
847
853
@@ -955,6 +961,7 @@ def __call__(
955
961
img_ids = latent_image_ids ,
956
962
joint_attention_kwargs = self .joint_attention_kwargs ,
957
963
return_dict = False ,
964
+ controlnet_blocks_repeat = controlnet_blocks_repeat ,
958
965
)[0 ]
959
966
960
967
latents_dtype = latents .dtype
0 commit comments