
Commit e8da75d

sayakpaul and beniz authored
[bitsandbytes] allow direct CUDA placement of pipelines loaded with bnb components (#9840)
* allow device placement when using bnb quantization.
* warning.
* tests
* fixes
* docs.
* require accelerate version.
* remove print.
* revert to()
* tests
* fixes
* fix: missing AutoencoderKL lora adapter (#9807)
* fix: missing AutoencoderKL lora adapter
* fix
---------
Co-authored-by: Sayak Paul <[email protected]>
* fixes
* fix condition test
* updates
* updates
* remove is_offloaded.
* fixes
* better
* empty
---------
Co-authored-by: Emmanuel Benazera <[email protected]>
1 parent 8a450c3 commit e8da75d
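
What the change enables, in user-facing terms: a pipeline whose components were loaded with bitsandbytes quantization can now be moved to CUDA with a plain `.to("cuda")` call, provided a recent enough `accelerate` is installed. The following is a minimal sketch mirroring the new tests; the checkpoint id is an assumption used only for illustration (any SD3-style checkpoint with a `transformer` subfolder would look the same), not something pinned by this commit.

import torch
from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel

# Hypothetical checkpoint id, for illustration only.
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"

# Quantize the transformer to 4-bit NF4 with bitsandbytes.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
transformer_4bit = SD3Transformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)

# With this commit (and accelerate >= 1.1.0), moving the whole pipeline to CUDA works.
pipe = DiffusionPipeline.from_pretrained(
    model_id,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a photo of a table", num_inference_steps=2).images[0]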

File tree

3 files changed: +91 -5 lines changed


src/diffusers/pipelines/pipeline_utils.py

+11 -5
@@ -66,7 +66,6 @@
 if is_torch_npu_available():
     import torch_npu  # noqa: F401

-
 from .pipeline_loading_utils import (
     ALL_IMPORTABLE_CLASSES,
     CONNECTED_PIPES_KEYS,
@@ -388,6 +387,7 @@ def to(self, *args, **kwargs):
         )

         device = device or device_arg
+        pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())

         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
         def module_is_sequentially_offloaded(module):
@@ -410,10 +410,16 @@ def module_is_offloaded(module):
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
-        if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
-            raise ValueError(
-                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
-            )
+        if device and torch.device(device).type == "cuda":
+            if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
+                raise ValueError(
+                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+                )
+            # PR: https://github.com/huggingface/accelerate/pull/3223/
+            elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
+                raise ValueError(
+                    "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
+                )

         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
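
For readers skimming the diff, the reworked guard above can be read in isolation as follows. This is a simplified sketch of the decision logic rather than the actual method body; the two boolean arguments stand in for the values that `to()` computes from `self.components`, and the error messages are shortened.

import torch
from diffusers.utils import is_accelerate_version

def _validate_cuda_move(device, pipeline_is_sequentially_offloaded, pipeline_has_bnb):
    # Simplified sketch of the new guard in DiffusionPipeline.to().
    if device and torch.device(device).type == "cuda":
        if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
            # Sequential CPU offload remains incompatible with a manual move to the GPU.
            raise ValueError("Remove `enable_sequential_cpu_offload` before calling `.to('cuda')`.")
        elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
            # bnb-quantized pipelines need the fix from huggingface/accelerate#3223.
            raise ValueError("Upgrade `accelerate` to move a bnb-quantized pipeline to CUDA.")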

tests/quantization/bnb/test_4bit.py

+44 -1
@@ -18,10 +18,11 @@
 import unittest

 import numpy as np
+import pytest
 import safetensors.torch

 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
-from diffusers.utils import logging
+from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     is_bitsandbytes_available,
@@ -47,6 +48,7 @@ def get_some_linear_layer(model):


 if is_transformers_available():
+    from transformers import BitsAndBytesConfig as BnbConfig
     from transformers import T5EncoderModel

 if is_torch_available():
@@ -483,6 +485,47 @@ def test_moving_to_cpu_throws_warning(self):

         assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out

+    @pytest.mark.xfail(
+        condition=is_accelerate_version("<=", "1.1.1"),
+        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
+        strict=True,
+    )
+    def test_pipeline_cuda_placement_works_with_nf4(self):
+        transformer_nf4_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        transformer_4bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=transformer_nf4_config,
+            torch_dtype=torch.float16,
+        )
+        text_encoder_3_nf4_config = BnbConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        text_encoder_3_4bit = T5EncoderModel.from_pretrained(
+            self.model_name,
+            subfolder="text_encoder_3",
+            quantization_config=text_encoder_3_nf4_config,
+            torch_dtype=torch.float16,
+        )
+        # CUDA device placement works.
+        pipeline_4bit = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            transformer=transformer_4bit,
+            text_encoder_3=text_encoder_3_4bit,
+            torch_dtype=torch.float16,
+        ).to("cuda")
+
+        # Check if inference works.
+        _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
+
+        del pipeline_4bit
+

 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitFluxTests(Base4bitTests):

tests/quantization/bnb/test_mixed_int8.py

+36
@@ -17,8 +17,10 @@
 import unittest

 import numpy as np
+import pytest

 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     is_bitsandbytes_available,
@@ -44,6 +46,7 @@ def get_some_linear_layer(model):


 if is_transformers_available():
+    from transformers import BitsAndBytesConfig as BnbConfig
     from transformers import T5EncoderModel

 if is_torch_available():
@@ -432,6 +435,39 @@ def test_generate_quality_dequantize(self):
             output_type="np",
         ).images

+    @pytest.mark.xfail(
+        condition=is_accelerate_version("<=", "1.1.1"),
+        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
+        strict=True,
+    )
+    def test_pipeline_cuda_placement_works_with_mixed_int8(self):
+        transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
+        transformer_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=transformer_8bit_config,
+            torch_dtype=torch.float16,
+        )
+        text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
+        text_encoder_3_8bit = T5EncoderModel.from_pretrained(
+            self.model_name,
+            subfolder="text_encoder_3",
+            quantization_config=text_encoder_3_8bit_config,
+            torch_dtype=torch.float16,
+        )
+        # CUDA device placement works.
+        pipeline_8bit = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            transformer=transformer_8bit,
+            text_encoder_3=text_encoder_3_8bit,
+            torch_dtype=torch.float16,
+        ).to("cuda")
+
+        # Check if inference works.
+        _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
+
+        del pipeline_8bit
+

 @require_transformers_version_greater("4.44.0")
 class SlowBnb8bitFluxTests(Base8bitTests):
