Skip to content

Commit 48207d6

Browse files
authored
[Scheduler] fix: EDM schedulers when using the exp sigma schedule. (huggingface#8385)
* fix: EulerEDM when using the exp sigma schedule. * fix-copies * remove print. * reduce friction * yiyi's suggestions
1 parent 2f6f426 commit 48207d6

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,13 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
243243

244244
self.num_inference_steps = num_inference_steps
245245

246-
ramp = np.linspace(0, 1, self.num_inference_steps)
246+
ramp = torch.linspace(0, 1, self.num_inference_steps)
247247
if self.config.sigma_schedule == "karras":
248248
sigmas = self._compute_karras_sigmas(ramp)
249249
elif self.config.sigma_schedule == "exponential":
250250
sigmas = self._compute_exponential_sigmas(ramp)
251251

252-
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
252+
sigmas = sigmas.to(dtype=torch.float32, device=device)
253253
self.timesteps = self.precondition_noise(sigmas)
254254

255255
if self.config.final_sigmas_type == "sigma_min":
@@ -283,7 +283,6 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.
283283
min_inv_rho = sigma_min ** (1 / rho)
284284
max_inv_rho = sigma_max ** (1 / rho)
285285
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
286-
287286
return sigmas
288287

289288
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from dataclasses import dataclass
1717
from typing import Optional, Tuple, Union
1818

19-
import numpy as np
2019
import torch
2120

2221
from ..configuration_utils import ConfigMixin, register_to_config
@@ -210,13 +209,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
210209
"""
211210
self.num_inference_steps = num_inference_steps
212211

213-
ramp = np.linspace(0, 1, self.num_inference_steps)
212+
ramp = torch.linspace(0, 1, self.num_inference_steps)
214213
if self.config.sigma_schedule == "karras":
215214
sigmas = self._compute_karras_sigmas(ramp)
216215
elif self.config.sigma_schedule == "exponential":
217216
sigmas = self._compute_exponential_sigmas(ramp)
218217

219-
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
218+
sigmas = sigmas.to(dtype=torch.float32, device=device)
220219
self.timesteps = self.precondition_noise(sigmas)
221220

222221
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -234,7 +233,6 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.
234233
min_inv_rho = sigma_min ** (1 / rho)
235234
max_inv_rho = sigma_max ** (1 / rho)
236235
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
237-
238236
return sigmas
239237

240238
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:

0 commit comments

Comments (0)