Skip to content

Commit f1b46d6

Browse files
authored
Merge branch 'main' into layerwise-upcasting-hook
2 parents 7037133 + b0c8973 commit f1b46d6

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ def log_validation(
158158
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
159159
f" {args.validation_prompt}."
160160
)
161+
if args.enable_vae_tiling:
162+
pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024)
163+
161164
pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
162165
pipeline = pipeline.to(accelerator.device)
163166
pipeline.set_progress_bar_config(disable=True)
@@ -597,6 +600,7 @@ def parse_args(input_args=None):
597600
help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
598601
)
599602
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
603+
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
600604

601605
if input_args is not None:
602606
args = parser.parse_args(input_args)

0 commit comments

Comments
 (0)