diff --git a/README.md b/README.md index a27de1c5..d72779a3 100644 --- a/README.md +++ b/README.md @@ -396,7 +396,7 @@ decoder = Decoder( ).cuda() for unet_number in (1, 2): - loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much + loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much loss.backward() # do above for many steps @@ -861,25 +861,23 @@ unet1 = Unet( text_embed_dim = 512, cond_dim = 128, channels = 3, - dim_mults=(1, 2, 4, 8) + dim_mults=(1, 2, 4, 8), + cond_on_text_encodings = True, ).cuda() unet2 = Unet( dim = 16, image_embed_dim = 512, - text_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8, 16), - cond_on_text_encodings = True ).cuda() decoder = Decoder( unet = (unet1, unet2), image_sizes = (128, 256), clip = clip, - timesteps = 1000, - condition_on_text_encodings = True + timesteps = 1000 ).cuda() decoder_trainer = DecoderTrainer( @@ -904,8 +902,8 @@ for unet_number in (1, 2): # after much training # you can sample from the exponentially moving averaged unets as so -mock_image_embed = torch.randn(4, 512).cuda() -images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256) +mock_image_embed = torch.randn(32, 512).cuda() +images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (32, 3, 256, 256) ``` ### Diffusion Prior Training diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index c912cfbb..f1136f8d 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1831,7 +1831,7 @@ def cast_model_parameters( channels == self.channels and \ cond_on_image_embeds == self.cond_on_image_embeds and \ cond_on_text_encodings == 
self.cond_on_text_encodings and \ - cond_on_lowres_noise == self.cond_on_lowres_noise and \ + lowres_noise_cond == self.lowres_noise_cond and \ channels_out == self.channels_out: return self diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index 87ed3b5b..db44c71d 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -174,7 +174,7 @@ class DiffusionPriorTrainer(nn.Module): def __init__( self, diffusion_prior, - accelerator, + accelerator = None, use_ema = True, lr = 3e-4, wd = 1e-2, @@ -186,8 +186,12 @@ def __init__( ): super().__init__() assert isinstance(diffusion_prior, DiffusionPrior) - assert isinstance(accelerator, Accelerator) + ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) + accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs) + + if not exists(accelerator): + accelerator = Accelerator(**accelerator_kwargs) # assign some helpful member vars diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 3f262a63..923b9879 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.2.1' +__version__ = '1.2.2'