Skip to content

Commit

Permalink
Add method to prepare extra step kwargs for scheduler in xFuserCogVid…
Browse files Browse the repository at this point in the history
…eoXPipeline

This update introduces the `prepare_extra_step_kwargs` method, which prepares additional keyword arguments for the scheduler step based on its signature. It checks for the presence of 'eta' and 'generator' parameters to ensure compatibility with different schedulers. This enhancement improves the flexibility and usability of the pipeline.
  • Loading branch information
LazyBusyYang committed Jan 8, 2025
1 parent 1c441d9 commit ec182a3
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions xfuser/model_executor/pipelines/pipeline_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import torch.distributed
import inspect
from diffusers import CogVideoXPipeline
from diffusers.pipelines.cogvideo.pipeline_cogvideox import (
CogVideoXPipelineOutput,
Expand Down Expand Up @@ -404,6 +405,24 @@ def _init_sync_pipeline(
)
return latents, prompt_embeds, image_rotary_emb


def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]

accepts_eta = "eta" in set(inspect.signature(self.scheduler.module.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta

# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.module.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs

@property
def interrupt(self):
return self._interrupt
Expand Down

0 comments on commit ec182a3

Please sign in to comment.