Skip to content

Commit b2bcc38

Browse files
authored
Add method to prepare extra step kwargs for scheduler in xFuserCogVideoXPipeline (#426)
1 parent 055b41f commit b2bcc38

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

xfuser/model_executor/pipelines/pipeline_cogvideox.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
import torch.distributed
6+
import inspect
67
from diffusers import CogVideoXPipeline
78
from diffusers.pipelines.cogvideo.pipeline_cogvideox import (
89
CogVideoXPipelineOutput,
@@ -404,6 +405,24 @@ def _init_sync_pipeline(
404405
)
405406
return latents, prompt_embeds, image_rotary_emb
406407

408+
409+
def prepare_extra_step_kwargs(self, generator, eta):
410+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
411+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
412+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
413+
# and should be between [0, 1]
414+
415+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.module.step).parameters.keys())
416+
extra_step_kwargs = {}
417+
if accepts_eta:
418+
extra_step_kwargs["eta"] = eta
419+
420+
# check if the scheduler accepts generator
421+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.module.step).parameters.keys())
422+
if accepts_generator:
423+
extra_step_kwargs["generator"] = generator
424+
return extra_step_kwargs
425+
407426
@property
408427
def interrupt(self):
409428
return self._interrupt

0 commit comments

Comments
 (0)