@@ -698,6 +698,48 @@ def test_stable_diffusion_img2img(self):
698
698
assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
699
699
assert np .abs (image_from_tuple_slice .flatten () - expected_slice ).max () < 1e-2
700
700
701
+ def test_stable_diffusion_img2img_multiple_init_images (self ):
702
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
703
+ unet = self .dummy_cond_unet
704
+ scheduler = PNDMScheduler (skip_prk_steps = True )
705
+ vae = self .dummy_vae
706
+ bert = self .dummy_text_encoder
707
+ tokenizer = CLIPTokenizer .from_pretrained ("hf-internal-testing/tiny-random-clip" )
708
+
709
+ init_image = self .dummy_image .to (device ).repeat (2 , 1 , 1 , 1 )
710
+
711
+ # make sure here that pndm scheduler skips prk
712
+ sd_pipe = StableDiffusionImg2ImgPipeline (
713
+ unet = unet ,
714
+ scheduler = scheduler ,
715
+ vae = vae ,
716
+ text_encoder = bert ,
717
+ tokenizer = tokenizer ,
718
+ safety_checker = self .dummy_safety_checker ,
719
+ feature_extractor = self .dummy_extractor ,
720
+ )
721
+ sd_pipe = sd_pipe .to (device )
722
+ sd_pipe .set_progress_bar_config (disable = None )
723
+
724
+ prompt = 2 * ["A painting of a squirrel eating a burger" ]
725
+ generator = torch .Generator (device = device ).manual_seed (0 )
726
+ output = sd_pipe (
727
+ prompt ,
728
+ generator = generator ,
729
+ guidance_scale = 6.0 ,
730
+ num_inference_steps = 2 ,
731
+ output_type = "np" ,
732
+ init_image = init_image ,
733
+ )
734
+
735
+ image = output .images
736
+
737
+ image_slice = image [- 1 , - 3 :, - 3 :, - 1 ]
738
+
739
+ assert image .shape == (2 , 32 , 32 , 3 )
740
+ expected_slice = np .array ([0.5144 , 0.4447 , 0.4735 , 0.6676 , 0.5526 , 0.5454 , 0.645 , 0.5149 , 0.4689 ])
741
+ assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
742
+
701
743
def test_stable_diffusion_img2img_k_lms (self ):
702
744
device = "cpu" # ensure determinism for the device-dependent torch.Generator
703
745
unet = self .dummy_cond_unet
0 commit comments