@@ -27,6 +27,7 @@
 from ...schedulers import DPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
+    deprecate,
     is_bs4_available,
     is_ftfy_available,
     logging,
@@ -162,8 +163,10 @@ def encode_prompt(
         device: Optional[torch.device] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.FloatTensor] = None,
+        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         clean_caption: bool = False,
-        mask_feature: bool = True,
+        **kwargs,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -189,10 +192,11 @@ def encode_prompt(
                 string.
             clean_caption (bool, defaults to `False`):
                 If `True`, the function will preprocess and clean the provided caption before encoding.
-            mask_feature: (bool, defaults to `True`):
-                If `True`, the function will mask the text embeddings.
         """
-        embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
+
+        if "mask_feature" in kwargs:
+            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation, and this does not affect the end results. It will be removed in a future version."
+            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
 
         if device is None:
             device = self._execution_device
@@ -229,13 +233,11 @@ def encode_prompt(
                     f" {max_length} tokens: {removed_text}"
                 )
 
-            attention_mask = text_inputs.attention_mask.to(device)
-            prompt_embeds_attention_mask = attention_mask
+            prompt_attention_mask = text_inputs.attention_mask
+            prompt_attention_mask = prompt_attention_mask.to(device)
 
-            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
             prompt_embeds = prompt_embeds[0]
-        else:
-            prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
 
         if self.text_encoder is not None:
             dtype = self.text_encoder.dtype
@@ -250,8 +252,8 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
-        prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -267,11 +269,11 @@ def encode_prompt(
                 add_special_tokens=True,
                 return_tensors="pt",
             )
-            attention_mask = uncond_input.attention_mask.to(device)
+            negative_prompt_attention_mask = uncond_input.attention_mask
+            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
 
             negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device),
-                attention_mask=attention_mask,
+                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
             )
             negative_prompt_embeds = negative_prompt_embeds[0]
 
@@ -284,23 +286,13 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
         else:
             negative_prompt_embeds = None
+            negative_prompt_attention_mask = None
 
-        # Perform additional masking.
-        if mask_feature and not embeds_initially_provided:
-            prompt_embeds = prompt_embeds.unsqueeze(1)
-            masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
-            masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
-            masked_negative_prompt_embeds = (
-                negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
-            )
-            return masked_prompt_embeds, masked_negative_prompt_embeds
-
-        return prompt_embeds, negative_prompt_embeds
+        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
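Since `encode_prompt` now returns a four-tuple instead of a pair, downstream callers have to unpack the attention masks alongside the embeddings. A minimal sketch of the updated calling convention, assuming `pipe` is a loaded instance of this pipeline and the prompt is illustrative:

```python
# Sketch of the new return signature; the second positional argument is
# do_classifier_free_guidance, as in the __call__ site further below.
(
    prompt_embeds,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(
    "a small cactus with a happy face in the Sahara desert",
    True,
    negative_prompt="",
)
```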
@@ -329,6 +321,8 @@ def check_inputs(
         callback_steps,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        prompt_attention_mask=None,
+        negative_prompt_attention_mask=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -365,13 +359,25 @@ def check_inputs(
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
 
+        if prompt_embeds is not None and prompt_attention_mask is None:
+            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
         if prompt_embeds is not None and negative_prompt_embeds is not None:
             if prompt_embeds.shape != negative_prompt_embeds.shape:
                 raise ValueError(
                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                     f" {negative_prompt_embeds.shape}."
                 )
+            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+                raise ValueError(
+                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+                    f" {negative_prompt_attention_mask.shape}."
+                )
 
     # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
     def _text_preprocessing(self, text, clean_caption=False):
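With these checks, precomputed embeddings must always travel together with their masks. A hypothetical illustration of the new failure mode (the embedding shape is an assumption for a T5-style encoder, not taken from this diff):

```python
import torch

# Hypothetical: prompt_embeds without prompt_attention_mask now raises
# inside check_inputs before any computation happens.
try:
    pipe(
        prompt=None,
        prompt_embeds=torch.randn(1, 120, 4096),
        negative_prompt_embeds=torch.randn(1, 120, 4096),
    )
except ValueError as err:
    print(err)  # Must provide `prompt_attention_mask` when specifying `prompt_embeds`.
```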
@@ -579,14 +585,16 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         clean_caption: bool = True,
-        mask_feature: bool = True,
         use_resolution_binning: bool = True,
+        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -630,9 +638,12 @@ def __call__(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
                 provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -648,11 +659,10 @@ def __call__(
                 Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                 be installed. If the dependencies are not installed, the embeddings will be created from the raw
                 prompt.
-            mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
-            use_resolution_binning:
-                (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the
-                closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images,
-                they are resized back to the requested resolution. Useful for generating non-square images.
+            use_resolution_binning (`bool` defaults to `True`):
+                If set to `True`, the requested height and width are first mapped to the closest resolutions using
+                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+                the requested resolution. Useful for generating non-square images.
 
         Examples:
 
@@ -661,6 +671,9 @@ def __call__(
             If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
             returned where the first element is a list with the generated images
         """
+        if "mask_feature" in kwargs:
+            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation, and this does not affect the end results. It will be removed in a future version."
+            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
@@ -669,7 +682,15 @@ def __call__(
             height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
 
         self.check_inputs(
-            prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            callback_steps,
+            prompt_embeds,
+            negative_prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_attention_mask,
         )
 
         # 2. Default height and width to transformer
@@ -688,19 +709,26 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+        (
+            prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_prompt_attention_mask,
+        ) = self.encode_prompt(
             prompt,
             do_classifier_free_guidance,
             negative_prompt=negative_prompt,
             num_images_per_prompt=num_images_per_prompt,
             device=device,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             clean_caption=clean_caption,
-            mask_feature=mask_feature,
         )
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
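Under classifier-free guidance the unconditional and conditional streams run as one batch, so the masks have to be concatenated in the same order as the embeddings (negative first). A shape-level sketch with assumed sizes:

```python
import torch

# Assumed sizes for illustration: batch of 2 prompts, max token length 120.
prompt_attention_mask = torch.ones(2, 120, dtype=torch.long)
negative_prompt_attention_mask = torch.ones(2, 120, dtype=torch.long)

# Negative first, mirroring torch.cat([negative_prompt_embeds, prompt_embeds]).
cfg_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
print(cfg_mask.shape)  # torch.Size([4, 120])
```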
@@ -758,6 +786,7 @@ def __call__(
                 noise_pred = self.transformer(
                     latent_model_input,
                     encoder_hidden_states=prompt_embeds,
+                    encoder_attention_mask=prompt_attention_mask,
                     timestep=current_timestep,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
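End to end, a minimal usage sketch of the updated pipeline; the checkpoint ID and prompt are illustrative, and passing `mask_feature` now only triggers the deprecation warning instead of changing any computation:

```python
import torch
from diffusers import PixArtAlphaPipeline

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# `mask_feature` is swallowed by **kwargs: it is ignored and emits a
# deprecation warning slated for removal in 1.0.0.
image = pipe("a small cactus with a happy face", mask_feature=True).images[0]
image.save("cactus.png")
```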