
config migration: fpx, gemlite, uintx #1697

Merged · 39 commits · Feb 14, 2025
6 changes: 3 additions & 3 deletions test/dtypes/test_uintx.py
@@ -150,7 +150,7 @@ def test_uintx_target_dtype(dtype):

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
uintx_weight_only(dtype)(linear)
quantize_(linear, uintx_weight_only(dtype))
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))


@@ -165,7 +165,7 @@ def test_uintx_target_dtype_compile(dtype):

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
uintx_weight_only(dtype)(linear)
quantize_(linear, uintx_weight_only(dtype))
linear = torch.compile(linear)
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

@@ -196,6 +196,6 @@ def test_uintx_model_size(dtype):
)
bf16_size = get_model_size_in_bytes(linear)
# make sure it runs
uintx_weight_only(dtype)(linear[0])
quantize_(linear[0], uintx_weight_only(dtype))
quantized_size = get_model_size_in_bytes(linear)
assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
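These test changes show the call-site migration in its simplest form: the workflow is no longer a callable applied directly to a module, but a config passed to quantize_(). A minimal sketch of the before/after, mirroring the test above (module shape and dtype are illustrative):

import torch
from torchao.quantization import quantize_, uintx_weight_only

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

# old style: the object returned by uintx_weight_only() was applied to the module
# uintx_weight_only(torch.uint4)(linear)

# new style: build a config and hand it to quantize_, which swaps the weight in place
quantize_(linear, uintx_weight_only(torch.uint4))
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))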
8 changes: 3 additions & 5 deletions test/hqq/test_hqq_affine.py
@@ -53,12 +53,10 @@ def _eval_hqq(dtype):
dummy_linear.weight.data = W
if dtype == torch.uint4:
config = int4_weight_only(group_size=max(block_size), use_hqq=True)
quantize_(dummy_linear, config)
q_tensor_hqq = dummy_linear.weight
else:
q_tensor_hqq = uintx_weight_only(
dtype, group_size=max(block_size), use_hqq=True
)(dummy_linear).weight
config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)
quantize_(dummy_linear, config)
q_tensor_hqq = dummy_linear.weight

quant_linear_layer = torch.nn.Linear(
W.shape[1], W.shape[0], bias=False, device=W.device
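The HQQ test follows the same pattern: both branches now construct a config and run it through quantize_, then read the quantized weight back off the module. A condensed sketch of the updated flow (layer shape and group size are placeholders, not taken from the test):

import torch
from torchao.quantization import quantize_, int4_weight_only, uintx_weight_only

dummy_linear = torch.nn.Linear(512, 512, bias=False, dtype=torch.bfloat16, device="cuda")

# torch.uint4 keeps routing through the int4 config with HQQ enabled
quantize_(dummy_linear, int4_weight_only(group_size=64, use_hqq=True))
# other uintx dtypes would instead use:
# quantize_(dummy_linear, uintx_weight_only(torch.uint8, group_size=64, use_hqq=True))

q_tensor_hqq = dummy_linear.weight  # tensor subclass installed by quantize_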
23 changes: 21 additions & 2 deletions test/quantization/test_quant_api.py
@@ -33,11 +33,14 @@
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
gemlite_uintx_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import (
@@ -55,6 +58,13 @@
unwrap_tensor_subclass,
)

try:
import gemlite # noqa: F401

has_gemlite = True
except ModuleNotFoundError:
has_gemlite = False


def dynamic_quant(model, example_inputs):
m = torch.export.export(model, example_inputs, strict=True).module()
@@ -804,6 +814,9 @@ def test_int4wo_cpu(self, dtype, x_dim):
int8_dynamic_activation_int8_weight(),
int8_dynamic_activation_int4_weight(),
int8_weight_only(),
fpx_weight_only(ebits=4, mbits=3),
gemlite_uintx_weight_only(),
uintx_weight_only(dtype=torch.uint4),
],
)
def test_workflow_e2e_numerics(self, config):
@@ -827,17 +840,23 @@ def test_workflow_e2e_numerics(self, config):
and is_sm_at_least_90()
):
return unittest.skip("only supported on CUDA capability 8.9, not greater")
elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite:
return unittest.skip("gemlite not available")

# scale has to be moved to cuda here because the parametrization init
# code happens before gating for cuda availability
if isinstance(config, float8_static_activation_float8_weight):
config.scale = config.scale.to("cuda")

dtype = torch.bfloat16
if isinstance(config, gemlite_uintx_weight_only):
dtype = torch.float16

# set up inputs
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
x = torch.randn(128, 128, device="cuda", dtype=dtype)
# TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
# is that expected?
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
m_q = copy.deepcopy(m_ref)

# quantize
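The e2e numerics test now also exercises the three migrated configs, with two accommodations visible in the diff: it skips gemlite when the package is not installed, and it switches the model/input dtype to float16 for gemlite because that kernel path only supports fp16 models. A hedged sketch of the same dtype handling outside the test harness (shapes and tolerances are illustrative):

import copy
import torch
from torchao.quantization import quantize_, gemlite_uintx_weight_only, fpx_weight_only

config = gemlite_uintx_weight_only()  # or e.g. fpx_weight_only(ebits=4, mbits=3)

# gemlite is fp16-only, so choose the dtype from the config type (works because the
# lowercase name is now an alias for the config class)
dtype = torch.float16 if isinstance(config, gemlite_uintx_weight_only) else torch.bfloat16

m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
m_q = copy.deepcopy(m_ref)
quantize_(m_q, config)

x = torch.randn(128, 128, device="cuda", dtype=dtype)
torch.testing.assert_close(m_q(x), m_ref(x), atol=0.1, rtol=0.1)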
6 changes: 6 additions & 0 deletions torchao/quantization/__init__.py
@@ -49,11 +49,14 @@
Float8DynamicActivationFloat8WeightConfig,
Float8StaticActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
FPXWeightOnlyConfig,
GemliteUIntXWeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
@@ -135,6 +138,9 @@
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
"Float8StaticActivationFloat8WeightConfig",
"UIntXWeightOnlyConfig",
"FPXWeightOnlyConfig",
"GemliteUIntXWeightOnlyConfig",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
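Because each old function name is kept as an alias for its new config class, both spellings stay importable and resolve to the same object, so existing call sites keep working. A quick illustration (assuming the lowercase names remain exported alongside the new classes, as the tests rely on):

from torchao.quantization import (
    FPXWeightOnlyConfig,
    GemliteUIntXWeightOnlyConfig,
    UIntXWeightOnlyConfig,
    fpx_weight_only,
    gemlite_uintx_weight_only,
    uintx_weight_only,
)

assert uintx_weight_only is UIntXWeightOnlyConfig
assert fpx_weight_only is FPXWeightOnlyConfig
assert gemlite_uintx_weight_only is GemliteUIntXWeightOnlyConfig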
189 changes: 123 additions & 66 deletions torchao/quantization/quant_api.py
@@ -729,12 +729,8 @@ def _int4_dynamic_activation_int4_weight_transform(
return module


def gemlite_uintx_weight_only(
group_size: Optional[int] = 64,
bit_width: int = 4,
packing_bitwidth: int = 32,
contiguous: Optional[bool] = None,
):
@dataclass
class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
"""
applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format.
This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric.
@@ -747,16 +743,39 @@ def gemlite_uintx_weight_only(
`contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
"""

group_size: Optional[int] = 64
bit_width: int = 4
packing_bitwidth: int = 32
contiguous: Optional[bool] = None


# for BC
gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig


@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
def _gemlite_uintx_weight_only_transform(
module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
):
group_size = config.group_size
bit_width = config.bit_width
packing_bitwidth = config.packing_bitwidth
contiguous = config.contiguous

weight = module.weight

from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

use_hqq = True if bit_width == 4 else False
apply_fn = lambda weight: to_affine_quantized_intx(
new_weight = to_affine_quantized_intx(
weight,
**get_gemlite_aqt_kwargs(
weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
),
)
return _get_linear_subclass_inserter(apply_fn)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


@dataclass
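The migration pattern above repeats for each workflow in this PR: the former function becomes a dataclass config, the old name is kept as a backward-compatibility alias, and the module mutation moves into a transform registered for that config type via register_quantize_module_handler. The sketch below is a simplified, hypothetical illustration of how such type-keyed dispatch can work; it is not torchao's actual implementation of quantize_ or the handler registry:

from dataclasses import dataclass
from typing import Callable, Dict, Type

import torch

_HANDLERS: Dict[Type, Callable] = {}  # config type -> module transform (hypothetical)

def register_handler(config_cls):
    # associate a transform function with a config class
    def decorator(fn):
        _HANDLERS[config_cls] = fn
        return fn
    return decorator

@dataclass
class MyWeightOnlyConfig:
    group_size: int = 64

@register_handler(MyWeightOnlyConfig)
def _my_weight_only_transform(module: torch.nn.Module, config: MyWeightOnlyConfig):
    # a real handler would replace module.weight with a quantized tensor subclass
    return module

def apply_config(model: torch.nn.Module, config) -> None:
    handler = _HANDLERS[type(config)]
    for child in model.modules():
        if isinstance(child, torch.nn.Linear):
            handler(child, config)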
@@ -1380,9 +1399,10 @@ def _float8_static_activation_float8_weight_transform(
return module


def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
@dataclass
class UIntXWeightOnlyConfig(AOBaseConfig):
"""
Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
x is the number of bits specified by `dtype`

Args:
Expand All @@ -1392,6 +1412,28 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
`pack_dim`: the dimension we use for packing, defaults to -1
`use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight
"""

dtype: torch.dtype
group_size: int = 64
pack_dim: int = -1
use_hqq: bool = False


# for BC
uintx_weight_only = UIntXWeightOnlyConfig


@register_quantize_module_handler(UIntXWeightOnlyConfig)
def _uintx_weight_only_transform(
module: torch.nn.Module, config: UIntXWeightOnlyConfig
):
dtype = config.dtype
group_size = config.group_size
pack_dim = config.pack_dim
use_hqq = config.use_hqq

weight = module.weight

from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS

SUPPORTED_DTYPES = {
@@ -1406,49 +1448,50 @@
}
assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"

def apply_uintx_weight_only_quant(weight, dtype):
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)

if use_hqq:
if dtype == torch.uint4:
logger.warn(
"Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
)
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
dtype = torch.uint8
eps = None
zero_point_dtype = None
zero_point_domain = ZeroPointDomain.FLOAT
preserve_zero = False
_layout = PlainLayout()
else:
quant_min, quant_max = None, None
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT
preserve_zero = True
_layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)

return to_affine_quantized_intx(
weight,
mapping_type,
block_size,
dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
zero_point_dtype=zero_point_dtype,
zero_point_domain=zero_point_domain,
preserve_zero=preserve_zero,
_layout=_layout,
use_hqq=use_hqq,
)
if use_hqq:
if dtype == torch.uint4:
logger.warn(
"Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
)
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
dtype = torch.uint8
eps = None
zero_point_dtype = None
zero_point_domain = ZeroPointDomain.FLOAT
preserve_zero = False
_layout = PlainLayout()
else:
quant_min, quant_max = None, None
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT
preserve_zero = True
_layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)

return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)
new_weight = to_affine_quantized_intx(
weight,
mapping_type,
block_size,
dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
zero_point_dtype=zero_point_dtype,
zero_point_domain=zero_point_domain,
preserve_zero=preserve_zero,
_layout=_layout,
use_hqq=use_hqq,
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module
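Call sites for the migrated uintx workflow are unchanged thanks to the alias; only the mechanics moved into the registered transform. A minimal usage sketch (dtype and group size are illustrative):

import torch
from torchao.quantization import quantize_, UIntXWeightOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().to(torch.bfloat16)

# equivalent to quantize_(model, uintx_weight_only(torch.uint4)) via the BC alias
quantize_(model, UIntXWeightOnlyConfig(dtype=torch.uint4, group_size=64))

out = model(torch.randn(1, 128, device="cuda", dtype=torch.bfloat16))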


def fpx_weight_only(ebits: int, mbits: int):
@dataclass
class FPXWeightOnlyConfig(AOBaseConfig):
"""Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
e.g. fp6_e3_m2, fp6_e2_m3, ...
The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112
@@ -1459,26 +1502,40 @@ def fpx_weight_only(ebits: int, mbits: int):
in the future
"""

def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.floatx import FloatxTensorCoreLayout
ebits: int
mbits: int

assert (
weight.dim() == 2
), f"floatx only works for 2-d Tensor, got: {weight.dim()}"
out_dim, in_dim = weight.shape
if (in_dim % 64 != 0) or (out_dim % 256 != 0):
logger.info(
f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
"expected in_dim % 64 == 0 and out_dim % 256 == 0"
)
return weight

_layout = FloatxTensorCoreLayout(ebits, mbits)
return to_affine_quantized_fpx(weight, _layout)
# for BC
fpx_weight_only = FPXWeightOnlyConfig


@register_quantize_module_handler(FPXWeightOnlyConfig)
def _fpx_weight_only_transform(
module: torch.nn.Module, config: FPXWeightOnlyConfig
) -> torch.nn.Module:
ebits = config.ebits
mbits = config.mbits
weight = module.weight

from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.floatx import FloatxTensorCoreLayout

return _get_linear_subclass_inserter(apply_quant_llm)
assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}"
out_dim, in_dim = weight.shape
if (in_dim % 64 != 0) or (out_dim % 256 != 0):
logger.info(
f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
"expected in_dim % 64 == 0 and out_dim % 256 == 0"
)
return module

_layout = FloatxTensorCoreLayout(ebits, mbits)
new_weight = to_affine_quantized_fpx(weight, _layout)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


if TORCH_VERSION_AT_LEAST_2_5:
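For fpx, the transform keeps the prior behavior of skipping layers whose shapes the fp6-llm kernel cannot handle (in_dim must be divisible by 64 and out_dim by 256), now by simply returning the unmodified module. A hedged usage sketch, with ebits/mbits chosen for fp6_e3_m2 and a layer shape that satisfies the constraint (the fp16 model dtype and CUDA requirement are assumptions about the kernel, not stated in this diff):

import torch
from torchao.quantization import quantize_, fpx_weight_only

# 256x64 weight: out_dim % 256 == 0 and in_dim % 64 == 0, so the fpx kernel applies
model = torch.nn.Sequential(torch.nn.Linear(64, 256)).cuda().half()

quantize_(model, fpx_weight_only(ebits=3, mbits=2))  # fp6_e3_m2

out = model(torch.randn(1, 64, device="cuda", dtype=torch.float16))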