Skip to content

Commit 11b3002

Browse files
Isotr0pysayakpaul
andauthored
Support views batch for panorama (#3632)
* support views batch for panorama * add entry for the new argument * format entry for the new argument * add view_batch_size test * fix batch test and a boundary condition * add more docstrings * fix a typos * fix typos * add: entry to the doc about view_batch_size. * Revert "add: entry to the doc about view_batch_size." This reverts commit a36aeaa. * add a tip on . --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 10f4ecd commit 11b3002

File tree

3 files changed

+59
-12
lines changed

3 files changed

+59
-12
lines changed

docs/source/en/api/pipelines/panorama.mdx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ image = pipe(prompt).images[0]
5252
image.save("dolomites.png")
5353
```
5454

55+
<Tip>
56+
57+
While calling this pipeline, it's possible to specify the `view_batch_size` to have a >1 value.
58+
For some GPUs with high performance, higher a `view_batch_size`, can speedup the generation
59+
and increase the VRAM usage.
60+
61+
</Tip>
62+
5563
## StableDiffusionPanoramaPipeline
5664
[[autodoc]] StableDiffusionPanoramaPipeline
5765
- __call__

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -451,10 +451,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
451451

452452
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
453453
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
454+
# if panorama's height/width < window_size, num_blocks of height/width should return 1
454455
panorama_height /= 8
455456
panorama_width /= 8
456-
num_blocks_height = (panorama_height - window_size) // stride + 1
457-
num_blocks_width = (panorama_width - window_size) // stride + 1
457+
num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
458+
num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_height > window_size else 1
458459
total_num_blocks = int(num_blocks_height * num_blocks_width)
459460
views = []
460461
for i in range(total_num_blocks):
@@ -474,6 +475,7 @@ def __call__(
474475
width: Optional[int] = 2048,
475476
num_inference_steps: int = 50,
476477
guidance_scale: float = 7.5,
478+
view_batch_size: int = 1,
477479
negative_prompt: Optional[Union[str, List[str]]] = None,
478480
num_images_per_prompt: Optional[int] = 1,
479481
eta: float = 0.0,
@@ -508,6 +510,9 @@ def __call__(
508510
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
509511
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
510512
usually at the expense of lower image quality.
513+
view_batch_size (`int`, *optional*, defaults to 1):
514+
The batch size to denoise splited views. For some GPUs with high performance, higher view batch size
515+
can speedup the generation and increase the VRAM usage.
511516
negative_prompt (`str` or `List[str]`, *optional*):
512517
The prompt or prompts not to guide the image generation. If not defined, one has to pass
513518
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -609,8 +614,11 @@ def __call__(
609614
)
610615

611616
# 6. Define panorama grid and initialize views for synthesis.
617+
# prepare batch grid
612618
views = self.get_views(height, width)
613-
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
619+
views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
620+
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)
621+
614622
count = torch.zeros_like(latents)
615623
value = torch.zeros_like(latents)
616624

@@ -631,42 +639,55 @@ def __call__(
631639
# denoised (latent) crops are then averaged to produce the final latent
632640
# for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
633641
# MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
634-
for j, (h_start, h_end, w_start, w_end) in enumerate(views):
642+
# Batch views denoise
643+
for j, batch_view in enumerate(views_batch):
644+
vb_size = len(batch_view)
635645
# get the latents corresponding to the current view coordinates
636-
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
646+
latents_for_view = torch.cat(
647+
[latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
648+
)
637649

638650
# rematch block's scheduler status
639651
self.scheduler.__dict__.update(views_scheduler_status[j])
640652

641653
# expand the latents if we are doing classifier free guidance
642654
latent_model_input = (
643-
torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
655+
latents_for_view.repeat_interleave(2, dim=0)
656+
if do_classifier_free_guidance
657+
else latents_for_view
644658
)
645659
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
646660

661+
# repeat prompt_embeds for batch
662+
prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
663+
647664
# predict the noise residual
648665
noise_pred = self.unet(
649666
latent_model_input,
650667
t,
651-
encoder_hidden_states=prompt_embeds,
668+
encoder_hidden_states=prompt_embeds_input,
652669
cross_attention_kwargs=cross_attention_kwargs,
653670
).sample
654671

655672
# perform guidance
656673
if do_classifier_free_guidance:
657-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
674+
noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
658675
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
659676

660677
# compute the previous noisy sample x_t -> x_t-1
661-
latents_view_denoised = self.scheduler.step(
678+
latents_denoised_batch = self.scheduler.step(
662679
noise_pred, t, latents_for_view, **extra_step_kwargs
663680
).prev_sample
664681

665682
# save views scheduler status after sample
666683
views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
667684

668-
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
669-
count[:, :, h_start:h_end, w_start:w_end] += 1
685+
# extract value from batch
686+
for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
687+
latents_denoised_batch.chunk(vb_size), batch_view
688+
):
689+
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
690+
count[:, :, h_start:h_end, w_start:w_end] += 1
670691

671692
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
672693
latents = torch.where(count > 0, value / count, value)

tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def test_inference_batch_consistent(self):
131131

132132
# override to speed the overall test timing up.
133133
def test_inference_batch_single_identical(self):
134-
super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3e-3)
134+
super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3.25e-3)
135135

136136
def test_stable_diffusion_panorama_negative_prompt(self):
137137
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -152,6 +152,24 @@ def test_stable_diffusion_panorama_negative_prompt(self):
152152

153153
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
154154

155+
def test_stable_diffusion_panorama_views_batch(self):
156+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
157+
components = self.get_dummy_components()
158+
sd_pipe = StableDiffusionPanoramaPipeline(**components)
159+
sd_pipe = sd_pipe.to(device)
160+
sd_pipe.set_progress_bar_config(disable=None)
161+
162+
inputs = self.get_dummy_inputs(device)
163+
output = sd_pipe(**inputs, view_batch_size=2)
164+
image = output.images
165+
image_slice = image[0, -3:, -3:, -1]
166+
167+
assert image.shape == (1, 64, 64, 3)
168+
169+
expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757])
170+
171+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
172+
155173
def test_stable_diffusion_panorama_euler(self):
156174
device = "cpu" # ensure determinism for the device-dependent torch.Generator
157175
components = self.get_dummy_components()

0 commit comments

Comments
 (0)