Skip to content

Commit

Permalink
fix decoder needing separate conditional dropping probabilities for image embeddings and text encodings, thanks to @xiankgx !
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 30, 2022
1 parent 721a444 commit f19c99e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ decoder = Decoder(
unet = unet,
clip = clip,
timesteps = 100,
- cond_drop_prob = 0.2
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)
Expand Down Expand Up @@ -229,7 +230,8 @@ decoder = Decoder(
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
- cond_drop_prob = 0.2
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)
Expand Down Expand Up @@ -348,7 +350,8 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
- cond_drop_prob = 0.2,
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()

Expand Down Expand Up @@ -558,7 +561,8 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
- cond_drop_prob = 0.2,
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()

Expand Down Expand Up @@ -669,7 +673,8 @@ decoder = Decoder(
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100,
- cond_drop_prob = 0.2
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)
Expand Down
24 changes: 15 additions & 9 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,7 +1174,7 @@ def forward_with_cond_scale(
if cond_scale == 1:
return logits

- null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+ null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale

def forward(
Expand All @@ -1185,7 +1185,8 @@ def forward(
image_embed,
lowres_cond_img = None,
text_encodings = None,
- cond_drop_prob = 0.,
+ image_cond_drop_prob = 0.,
+ text_cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
Expand All @@ -1204,8 +1205,10 @@ def forward(

# conditional dropout

- keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
- keep_mask = rearrange(keep_mask, 'b -> b 1 1')
+ image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
+ text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
+
+ image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')

# mask out image embedding depending on condition dropout
# for classifier free guidance
Expand All @@ -1216,7 +1219,7 @@ def forward(
image_tokens = self.image_to_cond(image_embed)

image_tokens = torch.where(
- keep_mask,
+ image_keep_mask,
image_tokens,
self.null_image_embed
)
Expand All @@ -1228,7 +1231,7 @@ def forward(
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
- keep_mask,
+ text_keep_mask,
text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]]
)
Expand Down Expand Up @@ -1318,7 +1321,8 @@ def __init__(
clip,
vae = tuple(),
timesteps = 1000,
- cond_drop_prob = 0.2,
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5,
loss_type = 'l1',
beta_schedule = 'cosine',
predict_x_start = False,
Expand Down Expand Up @@ -1402,7 +1406,8 @@ def __init__(

# classifier free guidance

- self.cond_drop_prob = cond_drop_prob
+ self.image_cond_drop_prob = image_cond_drop_prob
+ self.text_cond_drop_prob = text_cond_drop_prob

def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
Expand Down Expand Up @@ -1484,7 +1489,8 @@ def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None,
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
- cond_drop_prob = self.cond_drop_prob
+ image_cond_drop_prob = self.image_cond_drop_prob,
+ text_cond_drop_prob = self.text_cond_drop_prob,
)

target = noise if not predict_x_start else x_start
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
- version = '0.0.73',
+ version = '0.0.74',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
Expand Down

0 comments on commit f19c99e

Please sign in to comment.