
Commit 2048646

[Quantization] enable multi-backend bitsandbytes

1 parent fbff43a commit 2048646

File tree

7 files changed: +422 −80 lines changed

src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py

Lines changed: 32 additions & 10 deletions
```diff
@@ -61,8 +61,6 @@ def __init__(self, quantization_config, **kwargs):
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
 
     def validate_environment(self, *args, **kwargs):
-        if not torch.cuda.is_available():
-            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
         if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
             raise ImportError(
                 "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
@@ -72,6 +70,12 @@ def validate_environment(self, *args, **kwargs):
                 "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
             )
 
+        from ...utils import is_bitsandbytes_multi_backend_available
+        from .utils import validate_bnb_backend_availability
+
+        bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available()
+        validate_bnb_backend_availability(raise_exception=True)
+
         if kwargs.get("from_flax", False):
             raise ValueError(
                 "Converting into 4-bit weights from flax weights is currently not supported, please make"
@@ -87,7 +91,9 @@ def validate_environment(self, *args, **kwargs):
             device_map_without_no_convert = {
                 key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
             }
-            if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
+            if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
+                pass
+            elif "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
                 raise ValueError(
                     "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
                     "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
@@ -240,10 +246,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
     # Commenting this for discussions on the PR.
     # def update_device_map(self, device_map):
     #     if device_map is None:
-    #         device_map = {"": torch.cuda.current_device()}
+    #         if torch.cuda.is_available():
+    #             device_map = {"": torch.cuda.current_device()}
+    #         elif is_torch_xpu_available():
+    #             device_map = {"": f"xpu:{torch.xpu.current_device()}"}
+    #         else:
+    #             device_map = {"": "cpu"}
     #         logger.info(
     #             "The device_map was not initialized. "
-    #             "Setting device_map to {'':torch.cuda.current_device()}. "
+    #             f"Setting device_map to {device_map}. "
     #             "If you want to use the model for inference, please set device_map ='auto' "
     #         )
     #     return device_map
@@ -344,8 +355,6 @@ def __init__(self, quantization_config, **kwargs):
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
 
     def validate_environment(self, *args, **kwargs):
-        if not torch.cuda.is_available():
-            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
         if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
             raise ImportError(
                 "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
@@ -355,6 +364,12 @@ def validate_environment(self, *args, **kwargs):
                 "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
             )
 
+        from ...utils import is_bitsandbytes_multi_backend_available
+        from .utils import validate_bnb_backend_availability
+
+        bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available()
+        validate_bnb_backend_availability(raise_exception=True)
+
         if kwargs.get("from_flax", False):
             raise ValueError(
                 "Converting into 8-bit weights from flax weights is currently not supported, please make"
@@ -370,7 +385,9 @@ def validate_environment(self, *args, **kwargs):
             device_map_without_no_convert = {
                 key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
             }
-            if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
+            if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
+                pass
+            elif "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
                 raise ValueError(
                     "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
                     "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
@@ -403,10 +420,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
     # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
     # def update_device_map(self, device_map):
     #     if device_map is None:
-    #         device_map = {"": torch.cuda.current_device()}
+    #         if torch.cuda.is_available():
+    #             device_map = {"": torch.cuda.current_device()}
+    #         elif is_torch_xpu_available():
+    #             device_map = {"": f"xpu:{torch.xpu.current_device()}"}
+    #         else:
+    #             device_map = {"": "cpu"}
     #         logger.info(
     #             "The device_map was not initialized. "
-    #             "Setting device_map to {'':torch.cuda.current_device()}. "
+    #             f"Setting device_map to {device_map}. "
     #             "If you want to use the model for inference, please set device_map ='auto' "
     #         )
     #     return device_map
```
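The substantive change in this file: the hard `torch.cuda.is_available()` gate is gone, and an all-CPU `device_map` is tolerated whenever a multi-backend bitsandbytes build is installed. A minimal sketch of what this unlocks — the checkpoint is illustrative, and it assumes a multi-backend bitsandbytes wheel plus a diffusers build containing this commit:

```python
# Hedged sketch: 4-bit loading on a CPU-only machine, which the old
# validate_environment rejected with "No GPU found". Assumes a
# multi-backend bitsandbytes install; the checkpoint is illustrative.
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")

model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    # set(device_map.values()) == {"cpu"} now passes validation when
    # is_bitsandbytes_multi_backend_available() returns True
    device_map={"": "cpu"},
    torch_dtype=torch.float32,
)
```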

src/diffusers/quantizers/bitsandbytes/utils.py

Lines changed: 96 additions & 5 deletions
```diff
@@ -16,11 +16,22 @@
 https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py
 """
 
+import importlib
 import inspect
 from inspect import signature
 from typing import Union
 
-from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
+from packaging import version
+
+from ...utils import (
+    get_available_devices,
+    is_accelerate_available,
+    is_bitsandbytes_available,
+    is_bitsandbytes_multi_backend_available,
+    is_ipex_available,
+    is_torch_available,
+    logging,
+)
 from ..quantization_config import QuantizationMethod
 
 
@@ -154,7 +165,7 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
 
 
 # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
-def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
     """
     Helper function to dequantize 4bit or 8bit bnb weights.
 
@@ -172,7 +183,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
         logger.warning_once(
             f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
         )
-        return output_tensor
+        return output_tensor.to(dtype)
 
     if state.SCB is None:
         state.SCB = weight.SCB
@@ -183,7 +194,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
     if state.CxB is None:
         state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
     out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t().to(dtype)
 
 
 def _create_accelerate_new_hook(old_hook):
@@ -205,6 +216,7 @@ def _create_accelerate_new_hook(old_hook):
 
 def _dequantize_and_replace(
     model,
+    dtype,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
@@ -244,7 +256,7 @@ def _dequantize_and_replace(
             else:
                 state = None
 
-            new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+            new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, dtype, state))
 
             if bias is not None:
                 new_module.bias = bias
@@ -263,6 +275,7 @@ def _dequantize_and_replace(
         if len(list(module.children())) > 0:
             _, has_been_replaced = _dequantize_and_replace(
                 module,
+                dtype,
                 modules_to_not_convert,
                 current_key_name,
                 quantization_config,
@@ -280,6 +293,7 @@ def dequantize_and_replace(
 ):
     model, has_been_replaced = _dequantize_and_replace(
         model,
+        model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
    )
```
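Threading `dtype` from `dequantize_and_replace` down through `_dequantize_and_replace` to `dequantize_bnb_weight` means dequantized weights land in the caller's dtype (ultimately `model.dtype`) rather than whatever the bnb kernels emit. A small round-trip sketch, assuming a CUDA-enabled bitsandbytes install; shapes and dtypes are arbitrary:

```python
# Sketch: quantize an fp16 weight to 4-bit NF4, then dequantize it into an
# explicitly requested dtype via the updated helper.
import bitsandbytes as bnb
import torch

from diffusers.quantizers.bitsandbytes.utils import dequantize_bnb_weight

w = torch.randn(64, 64, dtype=torch.float16)
w_4bit = bnb.nn.Params4bit(w, quant_type="nf4").cuda()  # quantizes on transfer

# The caller's dtype now wins: before this change the helper returned the
# kernel's output dtype (float16 here) unconditionally.
w_restored = dequantize_bnb_weight(w_4bit, torch.bfloat16)
assert w_restored.dtype == torch.bfloat16
```

The last hunk in this file adds the backend validators the quantizers now call: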
```diff
@@ -304,3 +318,80 @@ def _check_bnb_status(module) -> Union[bool, bool]:
         and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
     )
     return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb
+
+
+def _validate_bnb_multi_backend_availability(raise_exception):
+    import bitsandbytes as bnb
+
+    bnb_supported_devices = getattr(bnb, "supported_torch_devices", set())
+    available_devices = get_available_devices()
+
+    if available_devices == {"cpu"} and not is_ipex_available():
+        from importlib.util import find_spec
+
+        if find_spec("intel_extension_for_pytorch"):
+            logger.warning(
+                "You have Intel IPEX installed but if you're intending to use it for CPU, it might not have the right version. Be sure to double check that your PyTorch and IPEX installs are compatible."
+            )
+
+        available_devices.discard("cpu")  # Only Intel CPU is supported by BNB at the moment
+
+    if not available_devices.intersection(bnb_supported_devices):
+        if raise_exception:
+            bnb_supported_devices_with_info = set(  # noqa: C401
+                '"cpu" (needs an Intel CPU and intel_extension_for_pytorch installed and compatible with the PyTorch version)'
+                if device == "cpu"
+                else device
+                for device in bnb_supported_devices
+            )
+            err_msg = (
+                f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices_with_info}`. "
+                "Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
+            )
+
+            logger.error(err_msg)
+            raise RuntimeError(err_msg)
+
+        logger.warning("No supported devices found for bitsandbytes multi-backend.")
+        return False
+
+    logger.debug("Multi-backend validation successful.")
+    return True
+
+
+def _validate_bnb_cuda_backend_availability(raise_exception):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    if not torch.cuda.is_available():
+        log_msg = (
+            "CUDA is required but not available for bitsandbytes. Please consider installing the multi-platform enabled version of bitsandbytes, which is currently a work in progress. "
+            "Please check currently supported platforms and installation instructions at https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
+        )
+        if raise_exception:
+            logger.error(log_msg)
+            raise RuntimeError(log_msg)
+
+        logger.warning(log_msg)
+        return False
+
+    logger.debug("CUDA backend validation successful.")
+    return True
+
+
+def validate_bnb_backend_availability(raise_exception=False):
+    """
+    Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not.
+    """
+    if not is_bitsandbytes_available():
+        if importlib.util.find_spec("bitsandbytes") and version.parse(
+            importlib.metadata.version("bitsandbytes")
+        ) < version.parse("0.43.1"):
+            return _validate_bnb_cuda_backend_availability(raise_exception)
+        return False
+
+    if is_bitsandbytes_multi_backend_available():
+        return _validate_bnb_multi_backend_availability(raise_exception)
+    return _validate_bnb_cuda_backend_availability(raise_exception)
```
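`validate_bnb_backend_availability` is the entry point the quantizers call with `raise_exception=True`; with the default `raise_exception=False` it degrades to a boolean probe, so callers can branch instead of crash:

```python
# Probe whether this machine has any usable bitsandbytes backend,
# CUDA or otherwise, without raising.
from diffusers.quantizers.bitsandbytes.utils import validate_bnb_backend_availability

if validate_bnb_backend_availability():  # raise_exception defaults to False
    print("bitsandbytes has a usable backend on this machine")
else:
    # a warning explaining what is missing has already been logged
    print("no usable bnb backend; fall back to full-precision weights")
```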

src/diffusers/utils/__init__.py

Lines changed: 37 additions & 0 deletions
```diff
@@ -14,6 +14,8 @@
 
 
 import os
+from functools import lru_cache
+from typing import FrozenSet
 
 from packaging import version
 
@@ -63,6 +65,7 @@
     is_accelerate_available,
     is_accelerate_version,
     is_bitsandbytes_available,
+    is_bitsandbytes_multi_backend_available,
     is_bitsandbytes_version,
     is_bs4_available,
     is_flax_available,
@@ -73,6 +76,7 @@
     is_hf_hub_version,
     is_inflect_available,
     is_invisible_watermark_available,
+    is_ipex_available,
     is_k_diffusion_available,
     is_k_diffusion_version,
     is_librosa_available,
@@ -87,10 +91,15 @@
     is_tensorboard_available,
     is_timm_available,
     is_torch_available,
+    is_torch_cuda_available,
+    is_torch_mlu_available,
+    is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_npu_available,
     is_torch_version,
     is_torch_xla_available,
     is_torch_xla_version,
+    is_torch_xpu_available,
     is_torchao_available,
     is_torchsde_available,
     is_torchvision_available,
@@ -139,3 +148,31 @@ def check_min_version(min_version):
     error_message = f"This example requires a minimum version of {min_version},"
     error_message += f" but the version found is {__version__}.\n"
     raise ImportError(error_message)
+
+
+@lru_cache()
+def get_available_devices() -> FrozenSet[str]:
+    """
+    Returns a frozenset of devices available for the current PyTorch installation.
+    """
+    devices = {"cpu"}  # `cpu` is always supported as a device in PyTorch
+
+    if is_torch_cuda_available():
+        devices.add("cuda")
+
+    if is_torch_mps_available():
+        devices.add("mps")
+
+    if is_torch_xpu_available():
+        devices.add("xpu")
+
+    if is_torch_npu_available():
+        devices.add("npu")
+
+    if is_torch_mlu_available():
+        devices.add("mlu")
+
+    if is_torch_musa_available():
+        devices.add("musa")
+
+    return frozenset(devices)
```
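`get_available_devices` is what `_validate_bnb_multi_backend_availability` intersects with `bnb.supported_torch_devices`; thanks to `lru_cache`, the per-accelerator probes run once per process. A quick sketch:

```python
# The cached device probe behind the multi-backend validation.
from diffusers.utils import get_available_devices

devices = get_available_devices()  # frozenset, always contains "cpu"
print(sorted(devices))             # e.g. ['cpu', 'cuda'] on a one-GPU box
```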
