Commit 0d1c07c

fix a bug with classifier free guidance, thanks to @xiankgx again!
1 parent a389f81
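For context: classifier free guidance trains the model with its conditioning randomly dropped, so that at sampling time the conditional and unconditional predictions can be blended. The bug fixed here inverted the dropout mask's semantics: a mask drawn to be True with probability `cond_drop_prob` was used as a "keep" mask, so with `cond_drop_prob = 0.2` the conditioning survived only ~20% of the time rather than ~80%. A minimal sketch of the sampling-time blend this dropout enables (illustrative names, not code from this repo):

```python
import torch

def guided_pred(model, x, cond, cond_scale = 3.):
    # classifier free guidance at sampling time: query the model with and
    # without conditioning, then push the result toward the conditional branch
    pred_cond = model(x, cond = cond)
    pred_uncond = model(x, cond = None)  # conditioning replaced by null embeddings
    return pred_uncond + (pred_cond - pred_uncond) * cond_scale
```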

2 files changed: +9 -9 lines

dalle2_pytorch/dalle2_pytorch.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -688,14 +688,14 @@ def forward(
 
         # classifier free guidance
 
-        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
-        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
+        keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
+        keep_mask = rearrange(keep_mask, 'b -> b 1')
 
-        mask &= cond_prob_mask
+        mask &= keep_mask
 
         # whether text embedding is masked or not depends on the classifier free guidance conditional masking
 
-        mask = torch.cat((mask, cond_prob_mask), dim = 1)
+        mask = torch.cat((mask, keep_mask), dim = 1)
 
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
```
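The renamed `keep_mask` also flips the probability: `prob_mask_like` produces a mask that is True with the given probability, and the downstream code treats True as "keep the conditioning", so the correct draw probability is `1 - cond_drop_prob`. A minimal sketch of that semantics, assuming the helper matches the usual definition elsewhere in the repo (it is not part of this diff):

```python
import torch

def prob_mask_like(shape, prob, device):
    # assumed definition: boolean mask, True with probability `prob` elementwise
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

batch, cond_drop_prob, device = 8, 0.2, 'cpu'

# after the fix: True ("keep conditioning") with probability 1 - cond_drop_prob
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)

# before the fix, the mask was drawn with probability cond_drop_prob itself,
# so conditioning was kept only ~20% of the time when the caller asked to
# drop it 20% of the time
```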
```diff
@@ -1208,8 +1208,8 @@ def forward(
 
         # conditional dropout
 
-        cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
-        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
+        keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
+        keep_mask = rearrange(keep_mask, 'b -> b 1 1')
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
```
```diff
@@ -1220,7 +1220,7 @@ def forward(
         image_tokens = self.image_to_cond(image_embed)
 
         image_tokens = torch.where(
-            cond_prob_mask,
+            keep_mask,
             image_tokens,
             self.null_image_embed
         )
```
```diff
@@ -1232,7 +1232,7 @@ def forward(
         if exists(text_encodings) and self.cond_on_text_encodings:
             text_tokens = self.text_to_cond(text_encodings)
             text_tokens = torch.where(
-                cond_prob_mask,
+                keep_mask,
                 text_tokens,
                 self.null_text_embed[:, :text_tokens.shape[1]]
             )
```
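The `torch.where` calls above substitute learned null embeddings wherever `keep_mask` is False, which is what gives the model an unconditional branch for classifier free guidance. A self-contained sketch of the pattern with illustrative shapes (`null_image_embed` is a learned parameter in the real model, zeros here):

```python
import torch
from einops import rearrange

batch_size, num_tokens, dim = 4, 1, 512
cond_drop_prob = 0.2

image_tokens = torch.randn(batch_size, num_tokens, dim)
null_image_embed = torch.zeros(num_tokens, dim)  # learned parameter in the real model

# True -> keep the real conditioning tokens, False -> fall back to the null embedding
keep_mask = torch.zeros(batch_size).uniform_(0, 1) < (1 - cond_drop_prob)
keep_mask = rearrange(keep_mask, 'b -> b 1 1')

image_tokens = torch.where(keep_mask, image_tokens, null_image_embed)  # broadcasts per batch element
```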

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,7 +10,7 @@
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.71',
+  version = '0.0.72',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
```
