diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index e3d58018..586a730b 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -688,14 +688,14 @@ def forward( # classifier free guidance - cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device) - cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1') + keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device) + keep_mask = rearrange(keep_mask, 'b -> b 1') - mask &= cond_prob_mask + mask &= keep_mask # whether text embedding is masked or not depends on the classifier free guidance conditional masking - mask = torch.cat((mask, cond_prob_mask), dim = 1) + mask = torch.cat((mask, keep_mask), dim = 1) # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) # but let's just do it right @@ -1208,8 +1208,8 @@ def forward( # conditional dropout - cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device) - cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1') + keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device) + keep_mask = rearrange(keep_mask, 'b -> b 1 1') # mask out image embedding depending on condition dropout # for classifier free guidance @@ -1220,7 +1220,7 @@ def forward( image_tokens = self.image_to_cond(image_embed) image_tokens = torch.where( - cond_prob_mask, + keep_mask, image_tokens, self.null_image_embed ) @@ -1232,7 +1232,7 @@ def forward( if exists(text_encodings) and self.cond_on_text_encodings: text_tokens = self.text_to_cond(text_encodings) text_tokens = torch.where( - cond_prob_mask, + keep_mask, text_tokens, self.null_text_embed[:, :text_tokens.shape[1]] ) diff --git a/setup.py b/setup.py index c0c34b73..33196c22 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.71', + version = '0.0.72', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',