Skip to content

Commit

Permalink
always rederive the predicted noise from the clipped x0 for ddim + pr…
Browse files Browse the repository at this point in the history
…edict noise objective
  • Loading branch information
lucidrains committed Mar 5, 2023
1 parent cc58f75 commit 848e8a4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
10 changes: 2 additions & 8 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,10 +1334,7 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal

# predict noise

if self.predict_x_start or self.predict_v:
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
else:
pred_noise = pred
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)

if time_next < 0:
image_embed = x_start
Expand Down Expand Up @@ -2975,10 +2972,7 @@ def p_sample_loop_ddim(

# predict noise

if predict_x_start or predict_v:
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
else:
pred_noise = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)

c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
Expand Down
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.12.2'
__version__ = '1.12.3'

0 comments on commit 848e8a4

Please sign in to comment.