Skip to content

Commit 39dfb7a

Browse files
authored
Raise an error when trying to use SD Cascade Decoder with dtype bfloat16 and torch < 2.2 (huggingface#7244)
update
1 parent 1968356 commit 39dfb7a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from ...models import StableCascadeUNet
2121
from ...schedulers import DDPMWuerstchenScheduler
22-
from ...utils import logging, replace_example_docstring
22+
from ...utils import is_torch_version, logging, replace_example_docstring
2323
from ...utils.torch_utils import randn_tensor
2424
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
2525
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
@@ -361,6 +361,8 @@ def __call__(
361361
device = self._execution_device
362362
dtype = self.decoder.dtype
363363
self._guidance_scale = guidance_scale
364+
if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
365+
raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.")
364366

365367
# 1. Check inputs. Raise error if not correct
366368
self.check_inputs(

0 commit comments

Comments
 (0)