Skip to content

Commit 3d70777

Browse files
[Sana-4K] (#10537)
* [Sana 4K] add 4K support for Sana * [Sana-4K] fix SanaPAGPipeline * add VAE automatic tiling function * set clean_caption to False * add warning for VAE OOM * style --------- Co-authored-by: yiyixuxu <[email protected]>
1 parent 6b72784 commit 3d70777

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import inspect
1717
import re
1818
import urllib.parse as ul
19+
import warnings
1920
from typing import Callable, Dict, List, Optional, Tuple, Union
2021

2122
import torch
@@ -41,6 +42,7 @@
4142
ASPECT_RATIO_1024_BIN,
4243
)
4344
from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
45+
from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN
4446
from .pag_utils import PAGMixin
4547

4648

@@ -639,7 +641,7 @@ def __call__(
639641
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
640642
output_type: Optional[str] = "pil",
641643
return_dict: bool = True,
642-
clean_caption: bool = True,
644+
clean_caption: bool = False,
643645
use_resolution_binning: bool = True,
644646
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
645647
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -755,7 +757,9 @@ def __call__(
755757
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
756758

757759
if use_resolution_binning:
758-
if self.transformer.config.sample_size == 64:
760+
if self.transformer.config.sample_size == 128:
761+
aspect_ratio_bin = ASPECT_RATIO_4096_BIN
762+
elif self.transformer.config.sample_size == 64:
759763
aspect_ratio_bin = ASPECT_RATIO_2048_BIN
760764
elif self.transformer.config.sample_size == 32:
761765
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
@@ -912,7 +916,14 @@ def __call__(
912916
image = latents
913917
else:
914918
latents = latents.to(self.vae.dtype)
915-
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
919+
try:
920+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
921+
except torch.cuda.OutOfMemoryError as e:
922+
warnings.warn(
923+
f"{e}. \n"
924+
f"Try to use VAE tiling for large images. For example: \n"
925+
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
926+
)
916927
if use_resolution_binning:
917928
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
918929

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import inspect
1717
import re
1818
import urllib.parse as ul
19+
import warnings
1920
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2021

2122
import torch
@@ -953,7 +954,14 @@ def __call__(
953954
image = latents
954955
else:
955956
latents = latents.to(self.vae.dtype)
956-
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
957+
try:
958+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
959+
except torch.cuda.OutOfMemoryError as e:
960+
warnings.warn(
961+
f"{e}. \n"
962+
f"Try to use VAE tiling for large images. For example: \n"
963+
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
964+
)
957965
if use_resolution_binning:
958966
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
959967

0 commit comments

Comments (0)