
Commit 413689d

config migration: fpx, gemlite, uintx (#1697)
* Update [ghstack-poisoned]
1 parent 6fe41c2 commit 413689d

File tree: 5 files changed (+156, -76 lines)

test/dtypes/test_uintx.py (+3 -3)
@@ -150,7 +150,7 @@ def test_uintx_target_dtype(dtype):

     linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
     # make sure it runs
-    uintx_weight_only(dtype)(linear)
+    quantize_(linear, uintx_weight_only(dtype))
     linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))


@@ -165,7 +165,7 @@ def test_uintx_target_dtype_compile(dtype):

     linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
     # make sure it runs
-    uintx_weight_only(dtype)(linear)
+    quantize_(linear, uintx_weight_only(dtype))
     linear = torch.compile(linear)
     linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

@@ -196,6 +196,6 @@ def test_uintx_model_size(dtype):
     )
     bf16_size = get_model_size_in_bytes(linear)
     # make sure it runs
-    uintx_weight_only(dtype)(linear[0])
+    quantize_(linear[0], uintx_weight_only(dtype))
     quantized_size = get_model_size_in_bytes(linear)
     assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
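
The change above is the whole migration from the caller's point of view: the workflow function no longer acts on a module directly, its return value is handed to quantize_. A minimal sketch of the new call pattern, assuming a CUDA device:

    import torch
    from torchao.quantization import quantize_, uintx_weight_only

    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

    # Before this commit the workflow returned a module transform that was
    # applied to the layer directly:
    #     uintx_weight_only(torch.uint4)(linear)
    # Now it returns a config object that quantize_() consumes, replacing the
    # layer's weight in place.
    quantize_(linear, uintx_weight_only(torch.uint4))
    linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

The old lowercase spelling keeps working because it is now an alias for UIntXWeightOnlyConfig (see the quant_api.py diff below).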

test/hqq/test_hqq_affine.py (+3 -5)
@@ -53,12 +53,10 @@ def _eval_hqq(dtype):
     dummy_linear.weight.data = W
     if dtype == torch.uint4:
         config = int4_weight_only(group_size=max(block_size), use_hqq=True)
-        quantize_(dummy_linear, config)
-        q_tensor_hqq = dummy_linear.weight
     else:
-        q_tensor_hqq = uintx_weight_only(
-            dtype, group_size=max(block_size), use_hqq=True
-        )(dummy_linear).weight
+        config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)
+    quantize_(dummy_linear, config)
+    q_tensor_hqq = dummy_linear.weight

     quant_linear_layer = torch.nn.Linear(
         W.shape[1], W.shape[0], bias=False, device=W.device
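
For reference, a condensed sketch of the pattern the updated test follows; W, block_size, and dtype here are hypothetical stand-ins for the test's fixtures, and a CUDA device is assumed:

    import torch
    from torchao.quantization import int4_weight_only, quantize_, uintx_weight_only

    W = torch.randn(256, 128, dtype=torch.float16, device="cuda")  # stand-in weight
    block_size = (1, 64)
    dtype = torch.uint3

    dummy_linear = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
    dummy_linear.weight.data = W

    # Both branches now build a config and share a single quantize_() call, so
    # the quantized tensor is always read back from dummy_linear.weight.
    if dtype == torch.uint4:
        config = int4_weight_only(group_size=max(block_size), use_hqq=True)
    else:
        config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)
    quantize_(dummy_linear, config)
    q_tensor_hqq = dummy_linear.weight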

test/quantization/test_quant_api.py (+21 -2)
@@ -33,11 +33,14 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
     float8_weight_only,
+    fpx_weight_only,
+    gemlite_uintx_weight_only,
     int4_dynamic_activation_int4_weight,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    uintx_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
 from torchao.quantization.subclass import (
@@ -55,6 +58,13 @@
     unwrap_tensor_subclass,
 )

+try:
+    import gemlite  # noqa: F401
+
+    has_gemlite = True
+except ModuleNotFoundError:
+    has_gemlite = False
+

 def dynamic_quant(model, example_inputs):
     m = torch.export.export(model, example_inputs, strict=True).module()
@@ -804,6 +814,9 @@ def test_int4wo_cpu(self, dtype, x_dim):
             int8_dynamic_activation_int8_weight(),
             int8_dynamic_activation_int4_weight(),
             int8_weight_only(),
+            fpx_weight_only(ebits=4, mbits=3),
+            gemlite_uintx_weight_only(),
+            uintx_weight_only(dtype=torch.uint4),
         ],
     )
     def test_workflow_e2e_numerics(self, config):
@@ -827,17 +840,23 @@ def test_workflow_e2e_numerics(self, config):
             and is_sm_at_least_90()
         ):
             return unittest.skip("only supported on CUDA capability 8.9, not greater")
+        elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite:
+            return unittest.skip("gemlite not available")

         # scale has to be moved to cuda here because the parametrization init
         # code happens before gating for cuda availability
         if isinstance(config, float8_static_activation_float8_weight):
             config.scale = config.scale.to("cuda")

+        dtype = torch.bfloat16
+        if isinstance(config, gemlite_uintx_weight_only):
+            dtype = torch.float16
+
         # set up inputs
-        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(128, 128, device="cuda", dtype=dtype)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
-        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
         m_q = copy.deepcopy(m_ref)

         # quantize
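
Outside the test harness, the three newly covered workflows are driven the same way. A hedged sketch mirroring the parametrization above, assuming a CUDA device and, for the gemlite entry, that the optional gemlite package is installed (gemlite only works for fp16 models, hence the dtype switch the test also makes):

    import copy

    import torch
    from torchao.quantization import (
        fpx_weight_only,
        gemlite_uintx_weight_only,
        quantize_,
        uintx_weight_only,
    )

    for config, dtype in [
        (fpx_weight_only(ebits=4, mbits=3), torch.bfloat16),
        (uintx_weight_only(dtype=torch.uint4), torch.bfloat16),
        (gemlite_uintx_weight_only(), torch.float16),  # assumes gemlite is installed
    ]:
        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
        m_q = copy.deepcopy(m_ref)
        quantize_(m_q, config)

        x = torch.randn(128, 128, device="cuda", dtype=dtype)
        with torch.no_grad():
            y_ref, y_q = m_ref(x), m_q(x)  # compare numerics as needed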

torchao/quantization/__init__.py (+6 -0)
@@ -49,11 +49,14 @@
     Float8DynamicActivationFloat8WeightConfig,
     Float8StaticActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    FPXWeightOnlyConfig,
+    GemliteUIntXWeightOnlyConfig,
     Int4DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
+    UIntXWeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
     float8_weight_only,
@@ -135,6 +138,9 @@
     "Float8WeightOnlyConfig",
     "Float8DynamicActivationFloat8WeightConfig",
     "Float8StaticActivationFloat8WeightConfig",
+    "UIntXWeightOnlyConfig",
+    "FPXWeightOnlyConfig",
+    "GemliteUIntXWeightOnlyConfig",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",

torchao/quantization/quant_api.py (+123 -66)
@@ -729,12 +729,8 @@ def _int4_dynamic_activation_int4_weight_transform(
     return module


-def gemlite_uintx_weight_only(
-    group_size: Optional[int] = 64,
-    bit_width: int = 4,
-    packing_bitwidth: int = 32,
-    contiguous: Optional[bool] = None,
-):
+@dataclass
+class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
     """
     applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format.
     This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric.
@@ -747,16 +743,39 @@ def gemlite_uintx_weight_only(
     `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
     """

+    group_size: Optional[int] = 64
+    bit_width: int = 4
+    packing_bitwidth: int = 32
+    contiguous: Optional[bool] = None
+
+
+# for BC
+gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
+
+
+@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
+def _gemlite_uintx_weight_only_transform(
+    module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
+):
+    group_size = config.group_size
+    bit_width = config.bit_width
+    packing_bitwidth = config.packing_bitwidth
+    contiguous = config.contiguous
+
+    weight = module.weight
+
     from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

     use_hqq = True if bit_width == 4 else False
-    apply_fn = lambda weight: to_affine_quantized_intx(
+    new_weight = to_affine_quantized_intx(
         weight,
         **get_gemlite_aqt_kwargs(
             weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
         ),
     )
-    return _get_linear_subclass_inserter(apply_fn)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 @dataclass
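
The two hunks above are the template the rest of this commit follows: the workflow function becomes a dataclass config, the lowercase name survives as a BC alias, and the body moves into a transform registered for that config type. A hedged usage sketch of the gemlite path (assumes gemlite is installed, a CUDA device, and an fp16 model, per the docstring above):

    import torch
    from torchao.quantization import GemliteUIntXWeightOnlyConfig, quantize_

    model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().half()

    # quantize_() looks up the handler registered for this config type
    # (_gemlite_uintx_weight_only_transform above) and applies it to each
    # matching linear layer, swapping in the packed gemlite weight.
    quantize_(model, GemliteUIntXWeightOnlyConfig(group_size=64, bit_width=4))
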
@@ -1380,9 +1399,10 @@ def _float8_static_activation_float8_weight_transform(
     return module


-def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
+@dataclass
+class UIntXWeightOnlyConfig(AOBaseConfig):
     """
-    Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
+    Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
     x is the number of bits specified by `dtype`

     Args:
@@ -1392,6 +1412,28 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
     `pack_dim`: the dimension we use for packing, defaults to -1
     `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight
     """
+
+    dtype: torch.dtype
+    group_size: int = 64
+    pack_dim: int = -1
+    use_hqq: bool = False
+
+
+# for BC
+uintx_weight_only = UIntXWeightOnlyConfig
+
+
+@register_quantize_module_handler(UIntXWeightOnlyConfig)
+def _uintx_weight_only_transform(
+    module: torch.nn.Module, config: UIntXWeightOnlyConfig
+):
+    dtype = config.dtype
+    group_size = config.group_size
+    pack_dim = config.pack_dim
+    use_hqq = config.use_hqq
+
+    weight = module.weight
+
     from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS

     SUPPORTED_DTYPES = {
@@ -1406,49 +1448,50 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
     }
     assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"

-    def apply_uintx_weight_only_quant(weight, dtype):
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-
-        if use_hqq:
-            if dtype == torch.uint4:
-                logger.warn(
-                    "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
-                )
-            quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
-            dtype = torch.uint8
-            eps = None
-            zero_point_dtype = None
-            zero_point_domain = ZeroPointDomain.FLOAT
-            preserve_zero = False
-            _layout = PlainLayout()
-        else:
-            quant_min, quant_max = None, None
-            eps = torch.finfo(torch.float32).eps
-            zero_point_dtype = torch.int32
-            zero_point_domain = ZeroPointDomain.INT
-            preserve_zero = True
-            _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)

-        return to_affine_quantized_intx(
-            weight,
-            mapping_type,
-            block_size,
-            dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-            zero_point_dtype=zero_point_dtype,
-            zero_point_domain=zero_point_domain,
-            preserve_zero=preserve_zero,
-            _layout=_layout,
-            use_hqq=use_hqq,
-        )
+    if use_hqq:
+        if dtype == torch.uint4:
+            logger.warn(
+                "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
+            )
+        quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+        dtype = torch.uint8
+        eps = None
+        zero_point_dtype = None
+        zero_point_domain = ZeroPointDomain.FLOAT
+        preserve_zero = False
+        _layout = PlainLayout()
+    else:
+        quant_min, quant_max = None, None
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int32
+        zero_point_domain = ZeroPointDomain.INT
+        preserve_zero = True
+        _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)

-    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)
+    new_weight = to_affine_quantized_intx(
+        weight,
+        mapping_type,
+        block_size,
+        dtype,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        eps=eps,
+        zero_point_dtype=zero_point_dtype,
+        zero_point_domain=zero_point_domain,
+        preserve_zero=preserve_zero,
+        _layout=_layout,
+        use_hqq=use_hqq,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


-def fpx_weight_only(ebits: int, mbits: int):
+@dataclass
+class FPXWeightOnlyConfig(AOBaseConfig):
     """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
     e.g. fp6_e3_m2, fp6_e2_m3, ...
     The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112
@@ -1459,26 +1502,40 @@ def fpx_weight_only(ebits: int, mbits: int):
     in the future
     """

-    def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
-        from torchao.dtypes import to_affine_quantized_fpx
-        from torchao.dtypes.floatx import FloatxTensorCoreLayout
+    ebits: int
+    mbits: int

-        assert (
-            weight.dim() == 2
-        ), f"floatx only works for 2-d Tensor, got: {weight.dim()}"
-        out_dim, in_dim = weight.shape
-        if (in_dim % 64 != 0) or (out_dim % 256 != 0):
-            logger.info(
-                f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
-                f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
-                "expected in_dim % 64 == 0 and out_dim % 256 == 0"
-            )
-            return weight

-        _layout = FloatxTensorCoreLayout(ebits, mbits)
-        return to_affine_quantized_fpx(weight, _layout)
+# for BC
+fpx_weight_only = FPXWeightOnlyConfig
+
+
+@register_quantize_module_handler(FPXWeightOnlyConfig)
+def _fpx_weight_only_transform(
+    module: torch.nn.Module, config: FPXWeightOnlyConfig
+) -> torch.nn.Module:
+    ebits = config.ebits
+    mbits = config.mbits
+    weight = module.weight
+
+    from torchao.dtypes import to_affine_quantized_fpx
+    from torchao.dtypes.floatx import FloatxTensorCoreLayout

-    return _get_linear_subclass_inserter(apply_quant_llm)
+    assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}"
+    out_dim, in_dim = weight.shape
+    if (in_dim % 64 != 0) or (out_dim % 256 != 0):
+        logger.info(
+            f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
+            f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
+            "expected in_dim % 64 == 0 and out_dim % 256 == 0"
+        )
+        return module
+
+    _layout = FloatxTensorCoreLayout(ebits, mbits)
+    new_weight = to_affine_quantized_fpx(weight, _layout)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 if TORCH_VERSION_AT_LEAST_2_5:
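
The fpx handler keeps the earlier shape gating, except that an incompatible layer now comes back as the unmodified module rather than the unmodified weight. A hedged usage sketch, assuming a CUDA device and a weight sized for the fp6-llm kernel (in_dim divisible by 64, out_dim divisible by 256, as checked above):

    import torch
    from torchao.quantization import FPXWeightOnlyConfig, quantize_

    # A (256, 64) weight passes the kernel's shape check; incompatible shapes
    # are logged and left unquantized.
    linear = torch.nn.Linear(64, 256, dtype=torch.half, device="cuda")
    quantize_(linear, FPXWeightOnlyConfig(ebits=3, mbits=2))  # fp6_e3_m2
    linear(torch.randn(1, 64, dtype=torch.half, device="cuda"))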
