@@ -729,12 +729,8 @@ def _int4_dynamic_activation_int4_weight_transform(
    return module


-def gemlite_uintx_weight_only(
-    group_size: Optional[int] = 64,
-    bit_width: int = 4,
-    packing_bitwidth: int = 32,
-    contiguous: Optional[bool] = None,
-):
+@dataclass
+class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
    """
    applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format.
    This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric.
@@ -747,16 +743,39 @@ def gemlite_uintx_weight_only(
    `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
    """

+    group_size: Optional[int] = 64
+    bit_width: int = 4
+    packing_bitwidth: int = 32
+    contiguous: Optional[bool] = None
+
+
+# for BC
+gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
+
+
+@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
+def _gemlite_uintx_weight_only_transform(
+    module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
+):
+    group_size = config.group_size
+    bit_width = config.bit_width
+    packing_bitwidth = config.packing_bitwidth
+    contiguous = config.contiguous
+
+    weight = module.weight
+
    from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

    use_hqq = True if bit_width == 4 else False
-    apply_fn = lambda weight: to_affine_quantized_intx(
+    new_weight = to_affine_quantized_intx(
        weight,
        **get_gemlite_aqt_kwargs(
            weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
        ),
    )
-    return _get_linear_subclass_inserter(apply_fn)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


@dataclass
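For orientation (not part of this diff): a minimal usage sketch of the refactored gemlite path, assuming quantize_ dispatches AOBaseConfig instances to the handler registered via register_quantize_module_handler, and that the gemlite package and its CUDA/fp16 requirements are available. The "# for BC" alias keeps existing gemlite_uintx_weight_only(...) call sites working, since that expression now constructs the same dataclass.

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig

# gemlite only supports fp16 weights, so cast the model first
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()
# new style: pass the config instance; quantize_ is assumed to route it to
# _gemlite_uintx_weight_only_transform through the handler registry
quantize_(model, GemliteUIntXWeightOnlyConfig(group_size=64, bit_width=4))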
@@ -1379,9 +1398,10 @@ def _float8_static_activation_float8_weight_transform(
    return module


-def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
+@dataclass
+class UIntXWeightOnlyConfig(AOBaseConfig):
    """
-    Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
+    Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
    x is the number of bits specified by `dtype`

    Args:
@@ -1391,6 +1411,28 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
        `pack_dim`: the dimension we use for packing, defaults to -1
        `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight
    """
+
+    dtype: torch.dtype
+    group_size: int = 64
+    pack_dim: int = -1
+    use_hqq: bool = False
+
+
+# for BC
+uintx_weight_only = UIntXWeightOnlyConfig
+
+
+@register_quantize_module_handler(UIntXWeightOnlyConfig)
+def _uintx_weight_only_transform(
+    module: torch.nn.Module, config: UIntXWeightOnlyConfig
+):
+    dtype = config.dtype
+    group_size = config.group_size
+    pack_dim = config.pack_dim
+    use_hqq = config.use_hqq
+
+    weight = module.weight
+
    from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS

    SUPPORTED_DTYPES = {
@@ -1405,49 +1447,50 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
    }
    assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"

-    def apply_uintx_weight_only_quant(weight, dtype):
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-
-        if use_hqq:
-            if dtype == torch.uint4:
-                logger.warn(
-                    "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
-                )
-            quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
-            dtype = torch.uint8
-            eps = None
-            zero_point_dtype = None
-            zero_point_domain = ZeroPointDomain.FLOAT
-            preserve_zero = False
-            _layout = PlainLayout()
-        else:
-            quant_min, quant_max = None, None
-            eps = torch.finfo(torch.float32).eps
-            zero_point_dtype = torch.int32
-            zero_point_domain = ZeroPointDomain.INT
-            preserve_zero = True
-            _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)

-        return to_affine_quantized_intx(
-            weight,
-            mapping_type,
-            block_size,
-            dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-            zero_point_dtype=zero_point_dtype,
-            zero_point_domain=zero_point_domain,
-            preserve_zero=preserve_zero,
-            _layout=_layout,
-            use_hqq=use_hqq,
-        )
+    if use_hqq:
+        if dtype == torch.uint4:
+            logger.warn(
+                "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
+            )
+        quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+        dtype = torch.uint8
+        eps = None
+        zero_point_dtype = None
+        zero_point_domain = ZeroPointDomain.FLOAT
+        preserve_zero = False
+        _layout = PlainLayout()
+    else:
+        quant_min, quant_max = None, None
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int32
+        zero_point_domain = ZeroPointDomain.INT
+        preserve_zero = True
+        _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)

-    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)
+    new_weight = to_affine_quantized_intx(
+        weight,
+        mapping_type,
+        block_size,
+        dtype,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        eps=eps,
+        zero_point_dtype=zero_point_dtype,
+        zero_point_domain=zero_point_domain,
+        preserve_zero=preserve_zero,
+        _layout=_layout,
+        use_hqq=use_hqq,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


-def fpx_weight_only(ebits: int, mbits: int):
+@dataclass
+class FPXWeightOnlyConfig(AOBaseConfig):
    """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
    e.g. fp6_e3_m2, fp6_e2_m3, ...
    The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112
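For orientation (not part of this diff): a usage sketch of UIntXWeightOnlyConfig, under the same assumption that quantize_ routes AOBaseConfig instances through the registered handler. dtype is the only field without a default, so it must always be supplied; the BC alias means older uintx_weight_only(torch.uint4, ...) call sites build the identical config.

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import UIntXWeightOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256))
# asymmetric per-group uint4 weight-only quantization, groups of 64 along the input dim
quantize_(model, UIntXWeightOnlyConfig(dtype=torch.uint4, group_size=64))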
@@ -1458,26 +1501,40 @@ def fpx_weight_only(ebits: int, mbits: int):
    in the future
    """

-    def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
-        from torchao.dtypes import to_affine_quantized_fpx
-        from torchao.dtypes.floatx import FloatxTensorCoreLayout
+    ebits: int
+    mbits: int

-        assert (
-            weight.dim() == 2
-        ), f"floatx only works for 2-d Tensor, got: {weight.dim()}"
-        out_dim, in_dim = weight.shape
-        if (in_dim % 64 != 0) or (out_dim % 256 != 0):
-            logger.info(
-                f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
-                f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
-                "expected in_dim % 64 == 0 and out_dim % 256 == 0"
-            )
-            return weight

-        _layout = FloatxTensorCoreLayout(ebits, mbits)
-        return to_affine_quantized_fpx(weight, _layout)
+# for BC
+fpx_weight_only = FPXWeightOnlyConfig
+
+
+@register_quantize_module_handler(FPXWeightOnlyConfig)
+def _fpx_weight_only_transform(
+    module: torch.nn.Module, config: FPXWeightOnlyConfig
+) -> torch.nn.Module:
+    ebits = config.ebits
+    mbits = config.mbits
+    weight = module.weight
+
+    from torchao.dtypes import to_affine_quantized_fpx
+    from torchao.dtypes.floatx import FloatxTensorCoreLayout

-    return _get_linear_subclass_inserter(apply_quant_llm)
+    assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}"
+    out_dim, in_dim = weight.shape
+    if (in_dim % 64 != 0) or (out_dim % 256 != 0):
+        logger.info(
+            f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
+            f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
+            "expected in_dim % 64 == 0 and out_dim % 256 == 0"
+        )
+        return module
+
+    _layout = FloatxTensorCoreLayout(ebits, mbits)
+    new_weight = to_affine_quantized_fpx(weight, _layout)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


if TORCH_VERSION_AT_LEAST_2_5:
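For orientation (not part of this diff): a usage sketch of FPXWeightOnlyConfig, again assuming quantize_ dispatches configs through the registered handler and that a CUDA fp16 model is used for the fp6-llm kernel. Layers whose shapes fail the in_dim % 64 / out_dim % 256 check are simply left unquantized by _fpx_weight_only_transform, so the call is safe on models with mixed layer sizes.

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import FPXWeightOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()
# fp6 with 3 exponent bits and 2 mantissa bits (fp6_e3_m2)
quantize_(model, FPXWeightOnlyConfig(ebits=3, mbits=2))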