File tree 1 file changed +10
-0
lines changed
src/diffusers/pipelines/stable_audio
1 file changed +10
-0
lines changed Original file line number Diff line number Diff line change 26
26
from ...models .embeddings import get_1d_rotary_pos_embed
27
27
from ...schedulers import EDMDPMSolverMultistepScheduler
28
28
from ...utils import (
29
+ is_torch_xla_available ,
29
30
logging ,
30
31
replace_example_docstring ,
31
32
)
32
33
from ...utils .torch_utils import randn_tensor
33
34
from ..pipeline_utils import AudioPipelineOutput , DiffusionPipeline
34
35
from .modeling_stable_audio import StableAudioProjectionModel
35
36
37
+ if is_torch_xla_available ():
38
+ import torch_xla .core .xla_model as xm
39
+
40
+ XLA_AVAILABLE = True
41
+ else :
42
+ XLA_AVAILABLE = False
36
43
37
44
logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
38
45
@@ -725,6 +732,9 @@ def __call__(
725
732
if callback is not None and i % callback_steps == 0 :
726
733
step_idx = i // getattr (self .scheduler , "order" , 1 )
727
734
callback (step_idx , t , latents )
735
+
736
+ if XLA_AVAILABLE :
737
+ xm .mark_step ()
728
738
729
739
# 9. Post-processing
730
740
if not output_type == "latent" :
You can’t perform that action at this time.
0 commit comments