Skip to content

Commit

Permalink
Fix a bug with classifier-free guidance — thanks to @xiankgx again!
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 30, 2022
1 parent a389f81 commit 0d1c07c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
16 changes: 8 additions & 8 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,14 +688,14 @@ def forward(

# classifier free guidance

cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1')

mask &= cond_prob_mask
mask &= keep_mask

# whether text embedding is masked or not depends on the classifier free guidance conditional masking

mask = torch.cat((mask, cond_prob_mask), dim = 1)
mask = torch.cat((mask, keep_mask), dim = 1)

# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
Expand Down Expand Up @@ -1208,8 +1208,8 @@ def forward(

# conditional dropout

cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1 1')

# mask out image embedding depending on condition dropout
# for classifier free guidance
Expand All @@ -1220,7 +1220,7 @@ def forward(
image_tokens = self.image_to_cond(image_embed)

image_tokens = torch.where(
cond_prob_mask,
keep_mask,
image_tokens,
self.null_image_embed
)
Expand All @@ -1232,7 +1232,7 @@ def forward(
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
keep_mask,
text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]]
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.71',
version = '0.0.72',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
Expand Down

0 comments on commit 0d1c07c

Please sign in to comment.