
Commit 5d83951

[PixArt-Alpha] Fix PixArt-Alpha pipeline when number of images to generate is more than 1 (huggingface#5752)
* does this fix things?
* attention mask use
* attention mask order
* better masking.
* add: test
* remove mask_feature
* test
* debug
* fix: tests
* deprecate mask_feature
* add deprecation test
* add slow test
* add print statements to retrieve the assertion values.
* fix for the 1024 fast test
* fix test
* fix the remaining
* Apply suggestions from code review
* more debug

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 0f2bb9b commit 5d83951
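
For context, a minimal usage sketch of the scenario this commit fixes: asking the pipeline for more than one image per prompt. The checkpoint id is an assumption for illustration.

```python
import torch
from diffusers import PixArtAlphaPipeline

# Checkpoint id assumed for illustration.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Before this commit, the prompt attention mask was not repeated alongside the
# prompt embeddings, so batches with num_images_per_prompt > 1 were mishandled.
images = pipe(
    prompt="a small cactus with a happy face in the Sahara desert",
    num_images_per_prompt=4,
).images
```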

1 file changed: +66 -37 lines changed

pipelines/pixart_alpha/pipeline_pixart_alpha.py

@@ -27,6 +27,7 @@
 from ...schedulers import DPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
+    deprecate,
     is_bs4_available,
     is_ftfy_available,
     logging,
@@ -162,8 +163,10 @@ def encode_prompt(
         device: Optional[torch.device] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.FloatTensor] = None,
+        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         clean_caption: bool = False,
-        mask_feature: bool = True,
+        **kwargs,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -189,10 +192,11 @@ def encode_prompt(
                 string.
             clean_caption (bool, defaults to `False`):
                 If `True`, the function will preprocess and clean the provided caption before encoding.
-            mask_feature: (bool, defaults to `True`):
-                If `True`, the function will mask the text embeddings.
         """
-        embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
+
+        if "mask_feature" in kwargs:
+            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
+            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
 
         if device is None:
             device = self._execution_device
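
The hunk above routes the removed `mask_feature` argument through `**kwargs` and warns instead of failing. A generic sketch of that kwargs-based deprecation pattern (this is not diffusers' actual `deprecate` helper, just the idea):

```python
import warnings

def encode_prompt_like(prompt, **kwargs):
    # Accept the retired keyword so old call sites keep working, but warn.
    if "mask_feature" in kwargs:
        warnings.warn(
            "`mask_feature` is deprecated and no longer affects the result; "
            "it will be removed in a future version.",
            FutureWarning,
        )
    # ... encoding proceeds without ever reading `mask_feature` ...
    return prompt

encode_prompt_like("a prompt", mask_feature=True)  # still works, emits a FutureWarning
```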
@@ -229,13 +233,11 @@ def encode_prompt(
                     f" {max_length} tokens: {removed_text}"
                 )
 
-            attention_mask = text_inputs.attention_mask.to(device)
-            prompt_embeds_attention_mask = attention_mask
+            prompt_attention_mask = text_inputs.attention_mask
+            prompt_attention_mask = prompt_attention_mask.to(device)
 
-            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
             prompt_embeds = prompt_embeds[0]
-        else:
-            prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
 
         if self.text_encoder is not None:
             dtype = self.text_encoder.dtype
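
The hunk above keeps the tokenizer's `attention_mask` around as `prompt_attention_mask` so it can be returned and reused downstream. A minimal sketch of where that mask comes from, assuming a plain T5 tokenizer for illustration (the pipeline ships its own tokenizer):

```python
from transformers import AutoTokenizer

# Tokenizer name assumed for illustration.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
text_inputs = tokenizer(
    ["a small cactus with a happy face"],
    padding="max_length",
    max_length=120,
    truncation=True,
    return_tensors="pt",
)
# 1 marks real tokens, 0 marks padding; shape: (batch_size, max_length)
prompt_attention_mask = text_inputs.attention_mask
```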
@@ -250,8 +252,8 @@ def encode_prompt(
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
-        prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
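
This is the core of the fix: the attention mask is now reshaped and repeated per generated image just like the prompt embeddings, so the two stay batch-aligned when `num_images_per_prompt > 1`. A toy shape check of that logic:

```python
import torch

bs_embed, seq_len, dim, num_images_per_prompt = 2, 120, 8, 3

prompt_embeds = torch.randn(bs_embed, seq_len, dim)
prompt_attention_mask = torch.ones(bs_embed, seq_len, dtype=torch.long)

# Same duplication applied to embeddings and mask (mirrors the hunk above).
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

# Both now share the leading batch dimension bs_embed * num_images_per_prompt.
assert prompt_embeds.shape[0] == prompt_attention_mask.shape[0] == 6
```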
@@ -267,11 +269,11 @@ def encode_prompt(
                 add_special_tokens=True,
                 return_tensors="pt",
             )
-            attention_mask = uncond_input.attention_mask.to(device)
+            negative_prompt_attention_mask = uncond_input.attention_mask
+            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
 
             negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device),
-                attention_mask=attention_mask,
+                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
             )
             negative_prompt_embeds = negative_prompt_embeds[0]
 
@@ -284,23 +286,13 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
         else:
             negative_prompt_embeds = None
+            negative_prompt_attention_mask = None
 
-        # Perform additional masking.
-        if mask_feature and not embeds_initially_provided:
-            prompt_embeds = prompt_embeds.unsqueeze(1)
-            masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
-            masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
-            masked_negative_prompt_embeds = (
-                negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
-            )
-            return masked_prompt_embeds, masked_negative_prompt_embeds
-
-        return prompt_embeds, negative_prompt_embeds
+        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -329,6 +321,8 @@ def check_inputs(
         callback_steps,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        prompt_attention_mask=None,
+        negative_prompt_attention_mask=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -365,13 +359,25 @@ def check_inputs(
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
 
+        if prompt_embeds is not None and prompt_attention_mask is None:
+            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
         if prompt_embeds is not None and negative_prompt_embeds is not None:
             if prompt_embeds.shape != negative_prompt_embeds.shape:
                 raise ValueError(
                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                     f" {negative_prompt_embeds.shape}."
                 )
+            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+                raise ValueError(
+                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+                    f" {negative_prompt_attention_mask.shape}."
+                )
 
     # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
     def _text_preprocessing(self, text, clean_caption=False):
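
With the new `check_inputs` rules above, precomputed embeddings must be accompanied by their attention masks. A sketch of the intended calling pattern, assuming `pipe` is a loaded `PixArtAlphaPipeline` as in the first example:

```python
# Assumes `pipe` is a loaded PixArtAlphaPipeline (see the first sketch).
(
    prompt_embeds,
    prompt_attention_mask,
    negative_prompt_embeds,
    negative_prompt_attention_mask,
) = pipe.encode_prompt("a small cactus", do_classifier_free_guidance=True)

images = pipe(
    prompt=None,
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=2,
).images
```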
@@ -579,14 +585,16 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         clean_caption: bool = True,
-        mask_feature: bool = True,
         use_resolution_binning: bool = True,
+        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -630,9 +638,12 @@ def __call__(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
                 provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -648,11 +659,10 @@ def __call__(
                 Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                 be installed. If the dependencies are not installed, the embeddings will be created from the raw
                 prompt.
-            mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
-            use_resolution_binning:
-                (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the
-                closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images,
-                they are resized back to the requested resolution. Useful for generating non-square images.
+            use_resolution_binning (`bool` defaults to `True`):
+                If set to `True`, the requested height and width are first mapped to the closest resolutions using
+                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+                the requested resolution. Useful for generating non-square images.
 
         Examples:
 
@@ -661,6 +671,9 @@ def __call__(
             If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
             returned where the first element is a list with the generated images
         """
+        if "mask_feature" in kwargs:
+            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
+            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
@@ -669,7 +682,15 @@ def __call__(
             height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
 
         self.check_inputs(
-            prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            callback_steps,
+            prompt_embeds,
+            negative_prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_attention_mask,
         )
 
         # 2. Default height and width to transformer
@@ -688,19 +709,26 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+        (
+            prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_prompt_attention_mask,
+        ) = self.encode_prompt(
             prompt,
             do_classifier_free_guidance,
             negative_prompt=negative_prompt,
             num_images_per_prompt=num_images_per_prompt,
             device=device,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
             clean_caption=clean_caption,
-            mask_feature=mask_feature,
         )
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
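
In the classifier-free-guidance branch above, the negative and positive attention masks are concatenated along the batch dimension in the same order as the embeddings. A toy illustration of the resulting shapes:

```python
import torch

batch = 2          # batch_size * num_images_per_prompt
seq_len = 120

prompt_attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
negative_prompt_attention_mask = torch.zeros(batch, seq_len, dtype=torch.long)
negative_prompt_attention_mask[:, 0] = 1  # e.g. only one real token for an empty negative prompt

# Negative first, then positive -- same ordering as the embeddings above.
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
print(prompt_attention_mask.shape)  # torch.Size([4, 120])
```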
@@ -758,6 +786,7 @@ def __call__(
                 noise_pred = self.transformer(
                     latent_model_input,
                     encoder_hidden_states=prompt_embeds,
+                    encoder_attention_mask=prompt_attention_mask,
                     timestep=current_timestep,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
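
Finally, the mask reaches the transformer as `encoder_attention_mask`. A generic sketch of how such a 0/1 mask is commonly applied inside cross-attention, as a large negative additive bias on padded key positions; this illustrates the idea rather than the exact code inside the PixArt transformer:

```python
import torch

encoder_attention_mask = torch.tensor([[1, 1, 1, 0, 0]])            # (batch, key_len)
attn_scores = torch.randn(1, 4, encoder_attention_mask.shape[1])    # (batch, query_len, key_len)

# Convert the 0/1 mask into an additive bias: 0 for real tokens, a large negative for padding.
bias = (1 - encoder_attention_mask.float()) * -10000.0
attn_probs = torch.softmax(attn_scores + bias.unsqueeze(1), dim=-1)
# Padded key positions now receive (near-)zero attention weight.
```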
