Skip to content

Commit f19c99e

Browse files
committed
fix decoder needing separate conditional dropping probabilities for image embeddings and text encodings, thanks to @xiankgx !
1 parent 721a444 commit f19c99e

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

README.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ decoder = Decoder(
110110
unet = unet,
111111
clip = clip,
112112
timesteps = 100,
113-
cond_drop_prob = 0.2
113+
image_cond_drop_prob = 0.1,
114+
text_cond_drop_prob = 0.5
114115
).cuda()
115116

116117
# mock images (get a lot of this)
@@ -229,7 +230,8 @@ decoder = Decoder(
229230
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
230231
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
231232
timesteps = 1000,
232-
cond_drop_prob = 0.2
233+
image_cond_drop_prob = 0.1,
234+
text_cond_drop_prob = 0.5
233235
).cuda()
234236

235237
# mock images (get a lot of this)
@@ -348,7 +350,8 @@ decoder = Decoder(
348350
image_sizes = (128, 256),
349351
clip = clip,
350352
timesteps = 100,
351-
cond_drop_prob = 0.2,
353+
image_cond_drop_prob = 0.1,
354+
text_cond_drop_prob = 0.5,
352355
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
353356
).cuda()
354357

@@ -558,7 +561,8 @@ decoder = Decoder(
558561
image_sizes = (128, 256),
559562
clip = clip,
560563
timesteps = 100,
561-
cond_drop_prob = 0.2,
564+
image_cond_drop_prob = 0.1,
565+
text_cond_drop_prob = 0.5,
562566
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
563567
).cuda()
564568

@@ -669,7 +673,8 @@ decoder = Decoder(
669673
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
670674
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
671675
timesteps = 100,
672-
cond_drop_prob = 0.2
676+
image_cond_drop_prob = 0.1,
677+
text_cond_drop_prob = 0.5
673678
).cuda()
674679

675680
# mock images (get a lot of this)

dalle2_pytorch/dalle2_pytorch.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ def forward_with_cond_scale(
11741174
if cond_scale == 1:
11751175
return logits
11761176

1177-
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
1177+
null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
11781178
return null_logits + (logits - null_logits) * cond_scale
11791179

11801180
def forward(
@@ -1185,7 +1185,8 @@ def forward(
11851185
image_embed,
11861186
lowres_cond_img = None,
11871187
text_encodings = None,
1188-
cond_drop_prob = 0.,
1188+
image_cond_drop_prob = 0.,
1189+
text_cond_drop_prob = 0.,
11891190
blur_sigma = None,
11901191
blur_kernel_size = None
11911192
):
@@ -1204,8 +1205,10 @@ def forward(
12041205

12051206
# conditional dropout
12061207

1207-
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
1208-
keep_mask = rearrange(keep_mask, 'b -> b 1 1')
1208+
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
1209+
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
1210+
1211+
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
12091212

12101213
# mask out image embedding depending on condition dropout
12111214
# for classifier free guidance
@@ -1216,7 +1219,7 @@ def forward(
12161219
image_tokens = self.image_to_cond(image_embed)
12171220

12181221
image_tokens = torch.where(
1219-
keep_mask,
1222+
image_keep_mask,
12201223
image_tokens,
12211224
self.null_image_embed
12221225
)
@@ -1228,7 +1231,7 @@ def forward(
12281231
if exists(text_encodings) and self.cond_on_text_encodings:
12291232
text_tokens = self.text_to_cond(text_encodings)
12301233
text_tokens = torch.where(
1231-
keep_mask,
1234+
text_keep_mask,
12321235
text_tokens,
12331236
self.null_text_embed[:, :text_tokens.shape[1]]
12341237
)
@@ -1318,7 +1321,8 @@ def __init__(
13181321
clip,
13191322
vae = tuple(),
13201323
timesteps = 1000,
1321-
cond_drop_prob = 0.2,
1324+
image_cond_drop_prob = 0.1,
1325+
text_cond_drop_prob = 0.5,
13221326
loss_type = 'l1',
13231327
beta_schedule = 'cosine',
13241328
predict_x_start = False,
@@ -1402,7 +1406,8 @@ def __init__(
14021406

14031407
# classifier free guidance
14041408

1405-
self.cond_drop_prob = cond_drop_prob
1409+
self.image_cond_drop_prob = image_cond_drop_prob
1410+
self.text_cond_drop_prob = text_cond_drop_prob
14061411

14071412
def get_unet(self, unet_number):
14081413
assert 0 < unet_number <= len(self.unets)
@@ -1484,7 +1489,8 @@ def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None,
14841489
image_embed = image_embed,
14851490
text_encodings = text_encodings,
14861491
lowres_cond_img = lowres_cond_img,
1487-
cond_drop_prob = self.cond_drop_prob
1492+
image_cond_drop_prob = self.image_cond_drop_prob,
1493+
text_cond_drop_prob = self.text_cond_drop_prob,
14881494
)
14891495

14901496
target = noise if not predict_x_start else x_start

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
'dream = dalle2_pytorch.cli:dream'
1111
],
1212
},
13-
version = '0.0.73',
13+
version = '0.0.74',
1414
license='MIT',
1515
description = 'DALL-E 2',
1616
author = 'Phil Wang',

0 commit comments

Comments
 (0)