Skip to content

Commit fe79489

Browse files
authored
allow tensors in several schedulers step() call (huggingface#8905)
1 parent 461efc5 commit fe79489

7 files changed

+8
-8
lines changed

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ def _init_step_index(self, timestep):
674674
def step(
675675
self,
676676
model_output: torch.Tensor,
677-
timestep: int,
677+
timestep: Union[int, torch.Tensor],
678678
sample: torch.Tensor,
679679
return_dict: bool = True,
680680
) -> Union[SchedulerOutput, Tuple]:
@@ -685,7 +685,7 @@ def step(
685685
Args:
686686
model_output (`torch.Tensor`):
687687
The direct output from learned diffusion model.
688-
timestep (`float`):
688+
timestep (`int`):
689689
The current discrete timestep in the diffusion chain.
690690
sample (`torch.Tensor`):
691691
A current instance of a sample created by the diffusion process.

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,7 @@ def _init_step_index(self, timestep):
920920
def step(
921921
self,
922922
model_output: torch.Tensor,
923-
timestep: int,
923+
timestep: Union[int, torch.Tensor],
924924
sample: torch.Tensor,
925925
generator=None,
926926
variance_noise: Optional[torch.Tensor] = None,

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ def _init_step_index(self, timestep):
787787
def step(
788788
self,
789789
model_output: torch.Tensor,
790-
timestep: int,
790+
timestep: Union[int, torch.Tensor],
791791
sample: torch.Tensor,
792792
generator=None,
793793
variance_noise: Optional[torch.Tensor] = None,

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,7 @@ def _init_step_index(self, timestep):
927927
def step(
928928
self,
929929
model_output: torch.Tensor,
930-
timestep: int,
930+
timestep: Union[int, torch.Tensor],
931931
sample: torch.Tensor,
932932
generator=None,
933933
return_dict: bool = True,

src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def _init_step_index(self, timestep):
594594
def step(
595595
self,
596596
model_output: torch.Tensor,
597-
timestep: int,
597+
timestep: Union[int, torch.Tensor],
598598
sample: torch.Tensor,
599599
generator=None,
600600
return_dict: bool = True,

src/diffusers/schedulers/scheduling_ipndm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def _init_step_index(self, timestep):
138138
def step(
139139
self,
140140
model_output: torch.Tensor,
141-
timestep: int,
141+
timestep: Union[int, torch.Tensor],
142142
sample: torch.Tensor,
143143
return_dict: bool = True,
144144
) -> Union[SchedulerOutput, Tuple]:

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ def _init_step_index(self, timestep):
822822
def step(
823823
self,
824824
model_output: torch.Tensor,
825-
timestep: int,
825+
timestep: Union[int, torch.Tensor],
826826
sample: torch.Tensor,
827827
return_dict: bool = True,
828828
) -> Union[SchedulerOutput, Tuple]:

0 commit comments

Comments
 (0)