Commit e68735f

config migration: fpx, gemlite, uintx
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: f48d9f19b754bcc193a2b0522d288bfff1b83089
ghstack-comment-id: 2649778077
Pull Request resolved: #1697

1 parent c3af5c0 commit e68735f
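
The migration follows the same pattern for all three workflows: the old factory function becomes a dataclass config deriving from AOBaseConfig, the module rewrite moves into a handler registered for that config type, and the old name is kept as an alias for backward compatibility. A generic sketch of that shape is below; AOBaseConfig, register_quantize_module_handler, ExampleWeightOnlyConfig, example_weight_only and _CONFIG_HANDLERS here are illustrative stand-ins, not the real torchao definitions.

from dataclasses import dataclass

import torch

# Illustrative stand-ins for torchao's AOBaseConfig and handler registry;
# this only mirrors the shape of the API introduced in this stack.
class AOBaseConfig:
    pass

_CONFIG_HANDLERS = {}

def register_quantize_module_handler(config_cls):
    def decorator(handler):
        _CONFIG_HANDLERS[config_cls] = handler
        return handler
    return decorator

@dataclass
class ExampleWeightOnlyConfig(AOBaseConfig):
    group_size: int = 64

# for BC: old call sites written as example_weight_only(group_size=32)
# now construct a config instead of a module-transforming callable
example_weight_only = ExampleWeightOnlyConfig

@register_quantize_module_handler(ExampleWeightOnlyConfig)
def _example_weight_only_transform(module: torch.nn.Module, config: ExampleWeightOnlyConfig):
    # a real handler would quantize module.weight here using the config's fields
    return module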

File tree

3 files changed (+145, -68 lines)


test/dtypes/test_affine_quantized.py

+1
@@ -218,6 +218,7 @@ def test_flatten_unflatten(self, device, dtype):
             linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
             if isinstance(apply_quant, AOBaseConfig):
                 quantize_(linear, apply_quant)
+                ql = linear
             else:
                 # TODO(#1690): delete this once config migration is done
                 ql = apply_quant(linear)
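
The new branch reflects the config-style API: quantize_ consumes an AOBaseConfig and rewrites the module in place, so the quantized module is the same object that was passed in rather than a returned copy. A minimal sketch of that path, assuming a CUDA device and that quantize_ and uintx_weight_only are importable from torchao.quantization as in the updated tests:

import torch
from torchao.quantization import quantize_, uintx_weight_only

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

# config path: uintx_weight_only(...) now returns a config object and
# quantize_ mutates `linear` in place, so the quantized module is `linear` itself
quantize_(linear, uintx_weight_only(dtype=torch.uint4))
ql = linear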

test/quantization/test_quant_api.py

+21 -2

@@ -33,11 +33,14 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
     float8_weight_only,
+    fpx_weight_only,
+    gemlite_uintx_weight_only,
     int4_dynamic_activation_int4_weight,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    uintx_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
 from torchao.quantization.subclass import (
@@ -55,6 +58,13 @@
     unwrap_tensor_subclass,
 )

+try:
+    import gemlite  # noqa: F401
+
+    has_gemlite = True
+except ModuleNotFoundError:
+    has_gemlite = False
+

 def dynamic_quant(model, example_inputs):
     m = torch.export.export(model, example_inputs, strict=True).module()
@@ -804,6 +814,9 @@ def test_int4wo_cpu(self, dtype, x_dim):
             int8_dynamic_activation_int8_weight(),
             int8_dynamic_activation_int4_weight(),
             int8_weight_only(),
+            fpx_weight_only(ebits=4, mbits=3),
+            gemlite_uintx_weight_only(),
+            uintx_weight_only(dtype=torch.uint4),
         ],
     )
     def test_workflow_e2e_numerics(self, config):
@@ -827,17 +840,23 @@ def test_workflow_e2e_numerics(self, config):
            and is_sm_at_least_90()
        ):
            return unittest.skip("only supported on CUDA capability 8.9, not greater")
+        elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite:
+            return unittest.skip("gemlite not available")

         # scale has to be moved to cuda here because the parametrization init
         # code happens before gating for cuda availability
         if isinstance(config, float8_static_activation_float8_weight):
             config.scale = config.scale.to("cuda")

+        dtype = torch.bfloat16
+        if isinstance(config, gemlite_uintx_weight_only):
+            dtype = torch.float16
+
         # set up inputs
-        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(128, 128, device="cuda", dtype=dtype)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
-        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
         m_q = copy.deepcopy(m_ref)

         # quantize
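
The new parametrizations route the fpx, gemlite, and uintx configs through the same end-to-end numerics check; only the gemlite path switches the model and input to float16, since its kernels are fp16-only, while the other configs stay in bfloat16. A rough sketch of the flow being exercised, assuming a CUDA device and that gemlite is installed (the final comparison in the real test is elided here and only hinted at):

import copy
import torch
from torchao.quantization import quantize_, gemlite_uintx_weight_only

dtype = torch.float16  # gemlite path; other configs keep bfloat16
x = torch.randn(128, 128, device="cuda", dtype=dtype)
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
m_q = copy.deepcopy(m_ref)

# quantize the copy in place via the config
quantize_(m_q, gemlite_uintx_weight_only())

with torch.no_grad():
    y_ref = m_ref(x)
    y_q = m_q(x)
# the test then checks that y_q stays numerically close to y_ref (e.g. via an error/SQNR threshold)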

torchao/quantization/quant_api.py

+123 -66

@@ -729,12 +729,8 @@ def _int4_dynamic_activation_int4_weight_transform(
     return module


-def gemlite_uintx_weight_only(
-    group_size: Optional[int] = 64,
-    bit_width: int = 4,
-    packing_bitwidth: int = 32,
-    contiguous: Optional[bool] = None,
-):
+@dataclass
+class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
     """
     applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format.
     This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric.
@@ -747,16 +743,39 @@ def gemlite_uintx_weight_only(
     `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
     """

+    group_size: Optional[int] = 64
+    bit_width: int = 4
+    packing_bitwidth: int = 32
+    contiguous: Optional[bool] = None
+
+
+# for BC
+gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
+
+
+@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
+def _gemlite_uintx_weight_only_transform(
+    module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
+):
+    group_size = config.group_size
+    bit_width = config.bit_width
+    packing_bitwidth = config.packing_bitwidth
+    contiguous = config.contiguous
+
+    weight = module.weight
+
     from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

     use_hqq = True if bit_width == 4 else False
-    apply_fn = lambda weight: to_affine_quantized_intx(
+    new_weight = to_affine_quantized_intx(
         weight,
         **get_gemlite_aqt_kwargs(
             weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
         ),
     )
-    return _get_linear_subclass_inserter(apply_fn)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 @dataclass
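
Because the old factory name is kept as an alias for the new dataclass, existing call sites keep working and now produce a config instance instead of a callable. A small sketch of what that buys, using only the class and fields introduced in this diff (importable from torchao.quantization.quant_api, where they are defined):

from torchao.quantization.quant_api import (
    GemliteUIntXWeightOnlyConfig,
    gemlite_uintx_weight_only,
)

# the BC alias constructs the new config; defaults mirror the old signature
cfg = gemlite_uintx_weight_only(group_size=64, bit_width=4)
assert isinstance(cfg, GemliteUIntXWeightOnlyConfig)
assert cfg.packing_bitwidth == 32 and cfg.contiguous is None
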
@@ -1379,9 +1398,10 @@ def _float8_static_activation_float8_weight_transform(
     return module


-def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
+@dataclass
+class UIntXWeightOnlyConfig(AOBaseConfig):
     """
-    Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
+    Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
     x is the number of bits specified by `dtype`

     Args:
@@ -1391,6 +1411,28 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
         `pack_dim`: the dimension we use for packing, defaults to -1
         `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight
     """
+
+    dtype: torch.dtype
+    group_size: int = 64
+    pack_dim: int = -1
+    use_hqq: bool = False
+
+
+# for BC
+uintx_weight_only = UIntXWeightOnlyConfig
+
+
+@register_quantize_module_handler(UIntXWeightOnlyConfig)
+def _uintx_weight_only_transform(
+    module: torch.nn.Module, config: UIntXWeightOnlyConfig
+):
+    dtype = config.dtype
+    group_size = config.group_size
+    pack_dim = config.pack_dim
+    use_hqq = config.use_hqq
+
+    weight = module.weight
+
     from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS

     SUPPORTED_DTYPES = {
@@ -1405,49 +1447,50 @@
     }
     assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"

-    def apply_uintx_weight_only_quant(weight, dtype):
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-
-        if use_hqq:
-            if dtype == torch.uint4:
-                logger.warn(
-                    "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
-                )
-            quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
-            dtype = torch.uint8
-            eps = None
-            zero_point_dtype = None
-            zero_point_domain = ZeroPointDomain.FLOAT
-            preserve_zero = False
-            _layout = PlainLayout()
-        else:
-            quant_min, quant_max = None, None
-            eps = torch.finfo(torch.float32).eps
-            zero_point_dtype = torch.int32
-            zero_point_domain = ZeroPointDomain.INT
-            preserve_zero = True
-            _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)

-        return to_affine_quantized_intx(
-            weight,
-            mapping_type,
-            block_size,
-            dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-            zero_point_dtype=zero_point_dtype,
-            zero_point_domain=zero_point_domain,
-            preserve_zero=preserve_zero,
-            _layout=_layout,
-            use_hqq=use_hqq,
-        )
+    if use_hqq:
+        if dtype == torch.uint4:
+            logger.warn(
+                "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
+            )
+        quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+        dtype = torch.uint8
+        eps = None
+        zero_point_dtype = None
+        zero_point_domain = ZeroPointDomain.FLOAT
+        preserve_zero = False
+        _layout = PlainLayout()
+    else:
+        quant_min, quant_max = None, None
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int32
+        zero_point_domain = ZeroPointDomain.INT
+        preserve_zero = True
+        _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)

-    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)
+    new_weight = to_affine_quantized_intx(
+        weight,
+        mapping_type,
+        block_size,
+        dtype,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        eps=eps,
+        zero_point_dtype=zero_point_dtype,
+        zero_point_domain=zero_point_domain,
+        preserve_zero=preserve_zero,
+        _layout=_layout,
+        use_hqq=use_hqq,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


-def fpx_weight_only(ebits: int, mbits: int):
+@dataclass
+class FPXWeightOnlyConfig(AOBaseConfig):
     """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
     e.g. fp6_e3_m2, fp6_e2_m3, ...
     The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112
@@ -1458,26 +1501,40 @@ def fpx_weight_only(ebits: int, mbits: int):
     in the future
     """

-    def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
-        from torchao.dtypes import to_affine_quantized_fpx
-        from torchao.dtypes.floatx import FloatxTensorCoreLayout
+    ebits: int
+    mbits: int

-        assert (
-            weight.dim() == 2
-        ), f"floatx only works for 2-d Tensor, got: {weight.dim()}"
-        out_dim, in_dim = weight.shape
-        if (in_dim % 64 != 0) or (out_dim % 256 != 0):
-            logger.info(
-                f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
-                f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
-                "expected in_dim % 64 == 0 and out_dim % 256 == 0"
-            )
-            return weight

-        _layout = FloatxTensorCoreLayout(ebits, mbits)
-        return to_affine_quantized_fpx(weight, _layout)
+# for BC
+fpx_weight_only = FPXWeightOnlyConfig
+
+
+@register_quantize_module_handler(FPXWeightOnlyConfig)
+def _fpx_weight_only_transform(
+    module: torch.nn.Module, config: FPXWeightOnlyConfig
+) -> torch.nn.Module:
+    ebits = config.ebits
+    mbits = config.mbits
+    weight = module.weight
+
+    from torchao.dtypes import to_affine_quantized_fpx
+    from torchao.dtypes.floatx import FloatxTensorCoreLayout

-    return _get_linear_subclass_inserter(apply_quant_llm)
+    assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}"
+    out_dim, in_dim = weight.shape
+    if (in_dim % 64 != 0) or (out_dim % 256 != 0):
+        logger.info(
+            f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
+            f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
+            "expected in_dim % 64 == 0 and out_dim % 256 == 0"
+        )
+        return module
+
+    _layout = FloatxTensorCoreLayout(ebits, mbits)
+    new_weight = to_affine_quantized_fpx(weight, _layout)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 if TORCH_VERSION_AT_LEAST_2_5:
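
Both migrated workflows keep their old entry points as aliases for the new config classes, so user code changes only in what the call returns. A minimal usage sketch for the uintx and fpx paths, assuming a CUDA device and that quantize_, uintx_weight_only, and fpx_weight_only are importable from torchao.quantization as in the updated tests:

import torch
from torchao.quantization import quantize_, fpx_weight_only, uintx_weight_only

# uintx: the alias builds a UIntXWeightOnlyConfig; quantize_ dispatches to
# _uintx_weight_only_transform and swaps the weight in place
m = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().to(torch.bfloat16)
quantize_(m, uintx_weight_only(dtype=torch.uint4, group_size=64))

# fpx: fp6_e3_m2 weight-only quantization; the transform quietly returns the
# module unchanged when in_dim % 64 != 0 or out_dim % 256 != 0
m2 = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().half()
quantize_(m2, fpx_weight_only(ebits=3, mbits=2))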
