 from ...schedulers import DPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
+    deprecate,
     is_bs4_available,
     is_ftfy_available,
     logging,
@@ -162,8 +163,10 @@ def encode_prompt(
         device: Optional[torch.device] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.FloatTensor] = None,
+        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         clean_caption: bool = False,
-        mask_feature: bool = True,
+        **kwargs,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -189,10 +192,11 @@ def encode_prompt(
                 string.
             clean_caption (bool, defaults to `False`):
                 If `True`, the function will preprocess and clean the provided caption before encoding.
-            mask_feature: (bool, defaults to `True`):
-                If `True`, the function will mask the text embeddings.
         """
-        embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
+
+        if "mask_feature" in kwargs:
+            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation, so it does not affect the end results. It will be removed in a future version."
+            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
 
         if device is None:
             device = self._execution_device
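For readers unfamiliar with the pattern: routing the removed argument through `**kwargs` keeps old call sites working, while `deprecate` from `diffusers.utils` emits a warning until the stated removal version. A minimal, self-contained sketch of the same pattern; the stub function is hypothetical, not part of the pipeline:

```python
from diffusers.utils import deprecate

def encode_prompt_stub(prompt, **kwargs):
    # Hypothetical stand-in for the real method: `mask_feature` is accepted
    # but never read again, so passing it only produces a warning.
    if "mask_feature" in kwargs:
        deprecate(
            "mask_feature",       # name of the deprecated argument
            "1.0.0",              # version at which it will be removed
            "The use of `mask_feature` is deprecated.",
            standard_warn=False,  # emit only the custom message
        )

encode_prompt_stub("a prompt", mask_feature=True)  # warns; nothing else changes
```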
@@ -229,13 +233,11 @@ def encode_prompt(
                     f" {max_length} tokens: {removed_text}"
                 )
 
-            attention_mask = text_inputs.attention_mask.to(device)
-            prompt_embeds_attention_mask = attention_mask
+            prompt_attention_mask = text_inputs.attention_mask
+            prompt_attention_mask = prompt_attention_mask.to(device)
 
-            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
             prompt_embeds = prompt_embeds[0]
-        else:
-            prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
 
         if self.text_encoder is not None:
             dtype = self.text_encoder.dtype
@@ -250,8 +252,8 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
-        prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -267,11 +269,11 @@ def encode_prompt(
                 add_special_tokens=True,
                 return_tensors="pt",
             )
-            attention_mask = uncond_input.attention_mask.to(device)
+            negative_prompt_attention_mask = uncond_input.attention_mask
+            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
 
             negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device),
-                attention_mask=attention_mask,
+                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
             )
             negative_prompt_embeds = negative_prompt_embeds[0]
 
@@ -284,23 +286,13 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
         else:
             negative_prompt_embeds = None
+            negative_prompt_attention_mask = None
 
-        # Perform additional masking.
-        if mask_feature and not embeds_initially_provided:
-            prompt_embeds = prompt_embeds.unsqueeze(1)
-            masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
-            masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
-            masked_negative_prompt_embeds = (
-                negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
-            )
-            return masked_prompt_embeds, masked_negative_prompt_embeds
-
-        return prompt_embeds, negative_prompt_embeds
+        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
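Note the breaking change in the return signature: `encode_prompt` now hands back a 4-tuple, so external callers must unpack the two attention masks as well. A minimal sketch of the new calling convention, assuming `pipe` is an already-loaded `PixArtAlphaPipeline`:

```python
# `pipe` is assumed to be a loaded PixArtAlphaPipeline instance.
# encode_prompt now returns the masks alongside the embeddings, in this order:
(
    prompt_embeds,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(
    "a small cactus with a happy face",
    True,  # do_classifier_free_guidance
)
```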
@@ -329,6 +321,8 @@ def check_inputs(
         callback_steps,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        prompt_attention_mask=None,
+        negative_prompt_attention_mask=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -365,13 +359,25 @@ def check_inputs(
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
 
+        if prompt_embeds is not None and prompt_attention_mask is None:
+            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
         if prompt_embeds is not None and negative_prompt_embeds is not None:
             if prompt_embeds.shape != negative_prompt_embeds.shape:
                 raise ValueError(
                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                     f" {negative_prompt_embeds.shape}."
                 )
+            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+                raise ValueError(
+                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+                    f" {negative_prompt_attention_mask.shape}."
+                )
 
     # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
     def _text_preprocessing(self, text, clean_caption=False):
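These guards make embeddings and masks travel as a pair: supplying pre-computed `prompt_embeds` without the matching mask now fails fast instead of misbehaving downstream. An illustrative sketch of the behavior (`pipe` is assumed loaded; the shapes are examples for the T5-XXL encoder, not requirements):

```python
import torch

# Illustrative shape: batch 1, 120 tokens, T5-XXL hidden size 4096.
prompt_embeds = torch.randn(1, 120, 4096)

try:
    pipe.check_inputs(
        prompt=None,
        height=1024,
        width=1024,
        negative_prompt=None,
        callback_steps=1,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,  # missing mask -> error
        negative_prompt_attention_mask=None,
    )
except ValueError as err:
    print(err)  # Must provide `prompt_attention_mask` when specifying `prompt_embeds`.
```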
@@ -579,14 +585,16 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         clean_caption: bool = True,
-        mask_feature: bool = True,
         use_resolution_binning: bool = True,
+        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -630,9 +638,12 @@ def __call__(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
                 provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -648,11 +659,10 @@ def __call__(
                 Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                 be installed. If the dependencies are not installed, the embeddings will be created from the raw
                 prompt.
-            mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
-            use_resolution_binning:
-                (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the
-                closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images,
-                they are resized back to the requested resolution. Useful for generating non-square images.
+            use_resolution_binning (`bool` defaults to `True`):
+                If set to `True`, the requested height and width are first mapped to the closest resolutions using
+                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+                the requested resolution. Useful for generating non-square images.
 
         Examples:
 
@@ -661,6 +671,9 @@ def __call__(
             If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
             returned where the first element is a list with the generated images
         """
+        if "mask_feature" in kwargs:
+            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation, so it does not affect the end results. It will be removed in a future version."
+            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
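From the user's side, old call sites that still pass `mask_feature` keep working: the value is swallowed by `**kwargs` and only a deprecation warning is emitted, with no effect on the generated image. A short usage sketch (`pipe` assumed loaded):

```python
# Still works after this change, but now warns that `mask_feature` is ignored:
image = pipe("a small cactus with a happy face", mask_feature=True).images[0]
```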
@@ -669,7 +682,15 @@ def __call__(
             height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
 
         self.check_inputs(
-            prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            callback_steps,
+            prompt_embeds,
+            negative_prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_attention_mask,
         )
 
         # 2. Default height and width to transformer
@@ -688,19 +709,26 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+        (
+            prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_prompt_attention_mask,
+        ) = self.encode_prompt(
             prompt,
             do_classifier_free_guidance,
             negative_prompt=negative_prompt,
             num_images_per_prompt=num_images_per_prompt,
             device=device,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             clean_caption=clean_caption,
-            mask_feature=mask_feature,
         )
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
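Since classifier-free guidance batches the unconditional and conditional streams into a single forward pass, the masks must be concatenated in the same `[uncond, cond]` order as the embeddings so that each half of the batch attends under its own mask. A standalone shape sketch with illustrative sizes:

```python
import torch

# Illustrative sizes: batch 1, max_length 120, T5-XXL hidden size 4096.
negative_prompt_embeds = torch.randn(1, 120, 4096)
prompt_embeds = torch.randn(1, 120, 4096)
negative_prompt_attention_mask = torch.ones(1, 120)
prompt_attention_mask = torch.ones(1, 120)

# Concatenate masks in the same [uncond, cond] order as the embeddings.
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

print(prompt_embeds.shape)          # torch.Size([2, 120, 4096])
print(prompt_attention_mask.shape)  # torch.Size([2, 120])
```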
@@ -758,6 +786,7 @@ def __call__(
                 noise_pred = self.transformer(
                     latent_model_input,
                     encoder_hidden_states=prompt_embeds,
+                    encoder_attention_mask=prompt_attention_mask,
                     timestep=current_timestep,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
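Taken together, the new arguments let a caller pre-compute embeddings once and reuse them across calls, with the masks forwarded to the transformer as `encoder_attention_mask`. A hedged end-to-end sketch; the checkpoint id is the usual PixArt-Alpha example and may differ in your setup:

```python
import torch
from diffusers import PixArtAlphaPipeline

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Encode once; the 4-tuple now includes both attention masks.
embeds, mask, neg_embeds, neg_mask = pipe.encode_prompt(
    "an astronaut riding a horse", True
)

# Reuse across calls: prompt/negative_prompt stay None, masks travel with embeds.
image = pipe(
    prompt=None,
    negative_prompt=None,
    prompt_embeds=embeds,
    prompt_attention_mask=mask,
    negative_prompt_embeds=neg_embeds,
    negative_prompt_attention_mask=neg_mask,
).images[0]
image.save("astronaut.png")
```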