|
16 | 16 | import inspect
|
17 | 17 | import re
|
18 | 18 | import urllib.parse as ul
|
| 19 | +import warnings |
19 | 20 | from typing import Callable, Dict, List, Optional, Tuple, Union
|
20 | 21 |
|
21 | 22 | import torch
|
|
41 | 42 | ASPECT_RATIO_1024_BIN,
|
42 | 43 | )
|
43 | 44 | from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
|
| 45 | +from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN |
44 | 46 | from .pag_utils import PAGMixin
|
45 | 47 |
|
46 | 48 |
|
@@ -639,7 +641,7 @@ def __call__(
|
639 | 641 | negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
640 | 642 | output_type: Optional[str] = "pil",
|
641 | 643 | return_dict: bool = True,
|
642 |
| - clean_caption: bool = True, |
| 644 | + clean_caption: bool = False, |
643 | 645 | use_resolution_binning: bool = True,
|
644 | 646 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
645 | 647 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
@@ -755,7 +757,9 @@ def __call__(
|
755 | 757 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
756 | 758 |
|
757 | 759 | if use_resolution_binning:
|
758 |
| - if self.transformer.config.sample_size == 64: |
| 760 | + if self.transformer.config.sample_size == 128: |
| 761 | + aspect_ratio_bin = ASPECT_RATIO_4096_BIN |
| 762 | + elif self.transformer.config.sample_size == 64: |
759 | 763 | aspect_ratio_bin = ASPECT_RATIO_2048_BIN
|
760 | 764 | elif self.transformer.config.sample_size == 32:
|
761 | 765 | aspect_ratio_bin = ASPECT_RATIO_1024_BIN
|
@@ -912,7 +916,14 @@ def __call__(
|
912 | 916 | image = latents
|
913 | 917 | else:
|
914 | 918 | latents = latents.to(self.vae.dtype)
|
915 |
| - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] |
| 919 | + try: |
| 920 | + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] |
| 921 | + except torch.cuda.OutOfMemoryError as e: |
| 922 | + warnings.warn( |
| 923 | + f"{e}. \n" |
| 924 | + f"Try to use VAE tiling for large images. For example: \n" |
| 925 | + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" |
| 926 | + ) |
916 | 927 | if use_resolution_binning:
|
917 | 928 | image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
|
918 | 929 |
|
|
0 commit comments