@@ -451,10 +451,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
         # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
+        # if the panorama's height/width < window_size, num_blocks of height/width should be 1
         panorama_height /= 8
         panorama_width /= 8
-        num_blocks_height = (panorama_height - window_size) // stride + 1
-        num_blocks_width = (panorama_width - window_size) // stride + 1
+        num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
+        num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
         total_num_blocks = int(num_blocks_height * num_blocks_width)
         views = []
         for i in range(total_num_blocks):
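
For intuition, here is a standalone sketch of the grid math above (a hypothetical helper, using the same defaults as the PR):

    def num_blocks(latent_dim, window_size=64, stride=8):
        # mirrors num_blocks_height/num_blocks_width, including the new small-image guard
        return (latent_dim - window_size) // stride + 1 if latent_dim > window_size else 1

    assert num_blocks(2048 // 8) == 25  # default 2048-px width -> 25 windows per row
    assert num_blocks(512 // 8) == 1    # 512-px image: the guard yields a single window
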
@@ -474,6 +475,7 @@ def __call__(
         width: Optional[int] = 2048,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        view_batch_size: int = 1,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
@@ -508,6 +510,9 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            view_batch_size (`int`, *optional*, defaults to 1):
+                The batch size used to denoise split views in parallel. On GPUs with enough memory, a higher view
+                batch size can speed up generation at the cost of increased VRAM usage.
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -609,8 +614,11 @@ def __call__(
         )

         # 6. Define panorama grid and initialize views for synthesis.
+        # group the views into batches for batched denoising
         views = self.get_views(height, width)
-        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
+        views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
+        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)
+
         count = torch.zeros_like(latents)
         value = torch.zeros_like(latents)
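
To see the slicing concretely, an illustrative toy run of the batching expression above (coordinates made up):

    views = [(0, 64, 0, 64), (0, 64, 8, 72), (0, 64, 16, 80), (0, 64, 24, 88), (0, 64, 32, 96)]
    view_batch_size = 2
    views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
    # -> batches of sizes 2, 2, 1; the trailing batch can be short, which is
    #    why the denoising loop below reads vb_size = len(batch_view)
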
@@ -631,42 +639,55 @@ def __call__(
                 # denoised (latent) crops are then averaged to produce the final latent
                 # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
                 # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
-                for j, (h_start, h_end, w_start, w_end) in enumerate(views):
+                # denoise the views batch by batch
+                for j, batch_view in enumerate(views_batch):
+                    vb_size = len(batch_view)
                     # get the latents corresponding to the current view coordinates
-                    latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
+                    latents_for_view = torch.cat(
+                        [latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
+                    )

                     # rematch block's scheduler status
                     self.scheduler.__dict__.update(views_scheduler_status[j])

                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = (
-                        torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
+                        latents_for_view.repeat_interleave(2, dim=0)
+                        if do_classifier_free_guidance
+                        else latents_for_view
                     )
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

+                    # repeat prompt_embeds to match the number of views in the batch
+                    prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
+
                     # predict the noise residual
                     noise_pred = self.unet(
                         latent_model_input,
                         t,
-                        encoder_hidden_states=prompt_embeds,
+                        encoder_hidden_states=prompt_embeds_input,
                         cross_attention_kwargs=cross_attention_kwargs,
                     ).sample

                     # perform guidance
                     if do_classifier_free_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
                         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                     # compute the previous noisy sample x_t -> x_t-1
-                    latents_view_denoised = self.scheduler.step(
+                    latents_denoised_batch = self.scheduler.step(
                         noise_pred, t, latents_for_view, **extra_step_kwargs
                     ).prev_sample

                     # save views scheduler status after sample
                     views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)

-                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
-                    count[:, :, h_start:h_end, w_start:w_end] += 1
+                    # write each denoised view back to its position in the panorama
+                    for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
+                        latents_denoised_batch.chunk(vb_size), batch_view
+                    ):
+                        value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+                        count[:, :, h_start:h_end, w_start:w_end] += 1

                 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
                 latents = torch.where(count > 0, value / count, value)
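
For context, a minimal usage sketch of the new parameter (assuming it ships in `StableDiffusionPanoramaPipeline`; the checkpoint and sizes below are illustrative):

    import torch
    from diffusers import DDIMScheduler, StableDiffusionPanoramaPipeline

    model_id = "stabilityai/stable-diffusion-2-base"  # illustrative checkpoint
    scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    pipe = StableDiffusionPanoramaPipeline.from_pretrained(
        model_id, scheduler=scheduler, torch_dtype=torch.float16
    ).to("cuda")

    # view_batch_size > 1 denoises several latent windows per UNet call,
    # trading VRAM for speed; view_batch_size=1 keeps the old behavior.
    image = pipe("a photo of the dolomites", height=512, width=2048, view_batch_size=4).images[0]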