make sure entire readme runs without errors
lucidrains committed Jul 28, 2022
1 parent 36fb46a commit 8004633
Showing 4 changed files with 14 additions and 12 deletions.
README.md (14 changes: 6 additions & 8 deletions)
@@ -396,7 +396,7 @@ decoder = Decoder(
).cuda()

for unet_number in (1, 2):
-    loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
+    loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss.backward()

# do above for many steps
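The loop above assumes `images` and `text` already exist from earlier in the README. A minimal sketch of mock inputs that would satisfy it, with assumed shapes (a batch of 4 RGB images at an assumed 256 x 256 final resolution, and a CLIP-style token vocabulary of 49408); none of this is part of the diff:

```python
import torch

# hypothetical mock inputs for the training loop above (shapes are assumptions)
images = torch.randn(4, 3, 256, 256).cuda()        # batch of RGB images at the assumed final resolution
text = torch.randint(0, 49408, (4, 256)).cuda()    # batch of token ids for text conditioning
```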
@@ -861,25 +861,23 @@ unet1 = Unet(
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
-    dim_mults=(1, 2, 4, 8)
+    dim_mults=(1, 2, 4, 8),
+    cond_on_text_encodings = True,
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
    cond_on_text_encodings = True
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
-    timesteps = 1000,
-    condition_on_text_encodings = True
+    timesteps = 1000
).cuda()

decoder_trainer = DecoderTrainer(
@@ -904,8 +902,8 @@ for unet_number in (1, 2):
# after much training
# you can sample from the exponentially moving averaged unets as so

-mock_image_embed = torch.randn(4, 512).cuda()
-images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
+mock_image_embed = torch.randn(32, 512).cuda()
+images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
```
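The sampling change passes the image embedding by keyword and bumps the mock batch to 32, so the output shape would correspondingly be (32, 3, 256, 256). A short sketch continuing from the `decoder_trainer` built earlier in the README; the text shape is an assumption chosen to match the batch size:

```python
import torch

# hypothetical inputs, assuming the decoder_trainer defined above
mock_image_embed = torch.randn(32, 512).cuda()      # CLIP image embeddings, dim 512 to match image_embed_dim
text = torch.randint(0, 49408, (32, 256)).cuda()    # matching batch of token ids

images = decoder_trainer.sample(image_embed = mock_image_embed, text = text)  # assumed output: (32, 3, 256, 256)
```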

### Diffusion Prior Training
dalle2_pytorch/dalle2_pytorch.py (2 changes: 1 addition & 1 deletion)
@@ -1831,7 +1831,7 @@ def cast_model_parameters(
channels == self.channels and \
cond_on_image_embeds == self.cond_on_image_embeds and \
cond_on_text_encodings == self.cond_on_text_encodings and \
-cond_on_lowres_noise == self.cond_on_lowres_noise and \
+lowres_noise_cond == self.lowres_noise_cond and \
channels_out == self.channels_out:
return self

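The one-line fix aligns the comparison in `cast_model_parameters` with the attribute actually stored on the unet (`lowres_noise_cond`), so the early `return self` fires when every requested setting already matches. A rough illustration of that short-circuit pattern; the class and attribute names below are hypothetical and this is not the library's implementation:

```python
class ConfigurableUnet:
    def __init__(self, channels = 3, lowres_noise_cond = False):
        self.channels = channels
        self.lowres_noise_cond = lowres_noise_cond

    def cast_model_parameters(self, *, channels, lowres_noise_cond):
        # nothing to change: hand back the same module
        if channels == self.channels and \
           lowres_noise_cond == self.lowres_noise_cond:
            return self

        # otherwise rebuild with the requested settings
        return ConfigurableUnet(channels = channels, lowres_noise_cond = lowres_noise_cond)
```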
dalle2_pytorch/trainer.py (8 changes: 6 additions & 2 deletions)
@@ -174,7 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
def __init__(
    self,
    diffusion_prior,
-    accelerator,
+    accelerator = None,
    use_ema = True,
    lr = 3e-4,
    wd = 1e-2,
@@ -186,8 +186,12 @@ def __init__(
):
    super().__init__()
    assert isinstance(diffusion_prior, DiffusionPrior)
-    assert isinstance(accelerator, Accelerator)

    ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
+    accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
+
+    if not exists(accelerator):
+        accelerator = Accelerator(**accelerator_kwargs)

    # assign some helpful member vars

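With `accelerator` now defaulting to `None`, `DiffusionPriorTrainer` builds its own `Accelerator` from any `accelerator_`-prefixed kwargs when none is supplied. A hedged sketch of calling code under that change; the helper name, import paths and the `lr` value are assumptions, not part of the diff:

```python
from accelerate import Accelerator
from dalle2_pytorch import DiffusionPrior, DiffusionPriorTrainer

def make_trainer(diffusion_prior: DiffusionPrior) -> DiffusionPriorTrainer:
    # previously an Accelerator had to be passed in explicitly:
    #   DiffusionPriorTrainer(diffusion_prior, accelerator = Accelerator(), lr = 3e-4)
    # now it can be omitted; one is created internally from accelerator_* kwargs
    return DiffusionPriorTrainer(diffusion_prior, lr = 3e-4)
```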
dalle2_pytorch/version.py (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-__version__ = '1.2.1'
+__version__ = '1.2.2'
