Skip to content

Commit 9ff7243

Browse files
author
fancy45daddy
authored
add torch_xla support in pipeline_stable_audio.py (#10109)
Update pipeline_stable_audio.py
1 parent c1926ce commit 9ff7243

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py

+10
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,20 @@
2626
from ...models.embeddings import get_1d_rotary_pos_embed
2727
from ...schedulers import EDMDPMSolverMultistepScheduler
2828
from ...utils import (
29+
is_torch_xla_available,
2930
logging,
3031
replace_example_docstring,
3132
)
3233
from ...utils.torch_utils import randn_tensor
3334
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
3435
from .modeling_stable_audio import StableAudioProjectionModel
3536

37+
if is_torch_xla_available():
38+
import torch_xla.core.xla_model as xm
39+
40+
XLA_AVAILABLE = True
41+
else:
42+
XLA_AVAILABLE = False
3643

3744
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3845

@@ -725,6 +732,9 @@ def __call__(
725732
if callback is not None and i % callback_steps == 0:
726733
step_idx = i // getattr(self.scheduler, "order", 1)
727734
callback(step_idx, t, latents)
735+
736+
if XLA_AVAILABLE:
737+
xm.mark_step()
728738

729739
# 9. Post-processing
730740
if not output_type == "latent":

0 commit comments

Comments
 (0)