@@ -729,12 +729,8 @@ def _int4_dynamic_activation_int4_weight_transform(
     return module


-def gemlite_uintx_weight_only(
-    group_size: Optional[int] = 64,
-    bit_width: int = 4,
-    packing_bitwidth: int = 32,
-    contiguous: Optional[bool] = None,
-):
+@dataclass
+class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
     """
     applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format.
     This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric.
@@ -747,16 +743,39 @@ def gemlite_uintx_weight_only(
         `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
     """

+    group_size: Optional[int] = 64
+    bit_width: int = 4
+    packing_bitwidth: int = 32
+    contiguous: Optional[bool] = None
+
+
+# for BC
+gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
+
+
+@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
+def _gemlite_uintx_weight_only_transform(
+    module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
+):
+    group_size = config.group_size
+    bit_width = config.bit_width
+    packing_bitwidth = config.packing_bitwidth
+    contiguous = config.contiguous
+
+    weight = module.weight
+
     from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

     use_hqq = True if bit_width == 4 else False
-    apply_fn = lambda weight: to_affine_quantized_intx(
+    new_weight = to_affine_quantized_intx(
         weight,
         **get_gemlite_aqt_kwargs(
             weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
         ),
     )
-    return _get_linear_subclass_inserter(apply_fn)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 @dataclass
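For readers following the refactor, a hedged usage sketch of the new config object. The model construction, import path, and `quantize_` call below are illustrative assumptions based on torchao's config-dispatch API, not part of this diff; the BC alias means existing `gemlite_uintx_weight_only(...)` call sites keep working because calling it now just constructs the config.

    import torch
    from torchao.quantization import (
        quantize_,
        GemliteUIntXWeightOnlyConfig,
        gemlite_uintx_weight_only,
    )

    # Gemlite kernels target fp16 models (see the docstring above), so the toy
    # model is cast to half precision; a CUDA device and the gemlite package
    # are assumed to be available.
    model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()

    # New style: pass the config dataclass directly.
    quantize_(model, GemliteUIntXWeightOnlyConfig(group_size=64, bit_width=4))

    # Old style still works through the BC alias; it now builds the same config.
    quantize_(model, gemlite_uintx_weight_only(group_size=64, bit_width=4))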
@@ -1380,9 +1399,10 @@ def _float8_static_activation_float8_weight_transform(
     return module


-def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
+@dataclass
+class UIntXWeightOnlyConfig(AOBaseConfig):
     """
-    Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
+    Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
     x is the number of bits specified by `dtype`

     Args:
@@ -1392,6 +1412,28 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
         `pack_dim`: the dimension we use for packing, defaults to -1
         `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight
     """
+
+    dtype: torch.dtype
+    group_size: int = 64
+    pack_dim: int = -1
+    use_hqq: bool = False
+
+
+# for BC
+uintx_weight_only = UIntXWeightOnlyConfig
+
+
+@register_quantize_module_handler(UIntXWeightOnlyConfig)
+def _uintx_weight_only_transform(
+    module: torch.nn.Module, config: UIntXWeightOnlyConfig
+):
+    dtype = config.dtype
+    group_size = config.group_size
+    pack_dim = config.pack_dim
+    use_hqq = config.use_hqq
+
+    weight = module.weight
+
     from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS

     SUPPORTED_DTYPES = {
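Analogous to the gemlite change, a minimal usage sketch for the new uintx config; the model and import path are illustrative assumptions, and `dtype` is the only field without a default.

    import torch
    from torchao.quantization import quantize_, UIntXWeightOnlyConfig

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024))

    # Per the docstring, dtype is a sub-byte dtype such as torch.uint1..torch.uint7;
    # the remaining fields keep their dataclass defaults
    # (group_size=64, pack_dim=-1, use_hqq=False).
    quantize_(model, UIntXWeightOnlyConfig(dtype=torch.uint4, group_size=64))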
@@ -1406,49 +1448,50 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
     }
     assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"

-    def apply_uintx_weight_only_quant(weight, dtype):
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-
-        if use_hqq:
-            if dtype == torch.uint4:
-                logger.warn(
-                    "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
-                )
-            quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
-            dtype = torch.uint8
-            eps = None
-            zero_point_dtype = None
-            zero_point_domain = ZeroPointDomain.FLOAT
-            preserve_zero = False
-            _layout = PlainLayout()
-        else:
-            quant_min, quant_max = None, None
-            eps = torch.finfo(torch.float32).eps
-            zero_point_dtype = torch.int32
-            zero_point_domain = ZeroPointDomain.INT
-            preserve_zero = True
-            _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)

-        return to_affine_quantized_intx(
-            weight,
-            mapping_type,
-            block_size,
-            dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-            zero_point_dtype=zero_point_dtype,
-            zero_point_domain=zero_point_domain,
-            preserve_zero=preserve_zero,
-            _layout=_layout,
-            use_hqq=use_hqq,
-        )
+    if use_hqq:
+        if dtype == torch.uint4:
+            logger.warn(
+                "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance"
+            )
+        quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+        dtype = torch.uint8
+        eps = None
+        zero_point_dtype = None
+        zero_point_domain = ZeroPointDomain.FLOAT
+        preserve_zero = False
+        _layout = PlainLayout()
+    else:
+        quant_min, quant_max = None, None
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int32
+        zero_point_domain = ZeroPointDomain.INT
+        preserve_zero = True
+        _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)

-    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)
+    new_weight = to_affine_quantized_intx(
+        weight,
+        mapping_type,
+        block_size,
+        dtype,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        eps=eps,
+        zero_point_dtype=zero_point_dtype,
+        zero_point_domain=zero_point_domain,
+        preserve_zero=preserve_zero,
+        _layout=_layout,
+        use_hqq=use_hqq,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


-def fpx_weight_only(ebits: int, mbits: int):
+@dataclass
+class FPXWeightOnlyConfig(AOBaseConfig):
     """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
     e.g. fp6_e3_m2, fp6_e2_m3, ...
     The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112
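All three migrations in this diff follow the same pattern: the former top-level function becomes an `AOBaseConfig` dataclass, and the quantization logic moves into a module handler registered for that config type, which swaps `module.weight` in place and returns the module. Below is a conceptual sketch of that dispatch with hypothetical names; it is not torchao's actual implementation of `register_quantize_module_handler`.

    from typing import Callable, Dict

    import torch

    # Hypothetical registry mapping a config class to the transform that applies it.
    _HANDLERS: Dict[type, Callable] = {}

    def register_handler(config_cls: type) -> Callable:
        """Associate a config class with the function that applies it to a module."""
        def decorator(fn: Callable) -> Callable:
            _HANDLERS[config_cls] = fn
            return fn
        return decorator

    def apply_config(module: torch.nn.Module, config) -> torch.nn.Module:
        """Look up the handler by the config's concrete type and run it in place."""
        return _HANDLERS[type(config)](module, config)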
@@ -1459,26 +1502,40 @@ def fpx_weight_only(ebits: int, mbits: int):
     in the future
     """

-    def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
-        from torchao.dtypes import to_affine_quantized_fpx
-        from torchao.dtypes.floatx import FloatxTensorCoreLayout
+    ebits: int
+    mbits: int

-        assert (
-            weight.dim() == 2
-        ), f"floatx only works for 2-d Tensor, got: {weight.dim()}"
-        out_dim, in_dim = weight.shape
-        if (in_dim % 64 != 0) or (out_dim % 256 != 0):
-            logger.info(
-                f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
-                f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
-                "expected in_dim % 64 == 0 and out_dim % 256 == 0"
-            )
-            return weight

-        _layout = FloatxTensorCoreLayout(ebits, mbits)
-        return to_affine_quantized_fpx(weight, _layout)
+# for BC
+fpx_weight_only = FPXWeightOnlyConfig
+
+
+@register_quantize_module_handler(FPXWeightOnlyConfig)
+def _fpx_weight_only_transform(
+    module: torch.nn.Module, config: FPXWeightOnlyConfig
+) -> torch.nn.Module:
+    ebits = config.ebits
+    mbits = config.mbits
+    weight = module.weight
+
+    from torchao.dtypes import to_affine_quantized_fpx
+    from torchao.dtypes.floatx import FloatxTensorCoreLayout

-    return _get_linear_subclass_inserter(apply_quant_llm)
+    assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}"
+    out_dim, in_dim = weight.shape
+    if (in_dim % 64 != 0) or (out_dim % 256 != 0):
+        logger.info(
+            f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because "
+            f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} "
+            "expected in_dim % 64 == 0 and out_dim % 256 == 0"
+        )
+        return module
+
+    _layout = FloatxTensorCoreLayout(ebits, mbits)
+    new_weight = to_affine_quantized_fpx(weight, _layout)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 if TORCH_VERSION_AT_LEAST_2_5:
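Finally, a hedged sketch for the fpx config; the model, shapes, and import path are illustrative assumptions. As the transform above shows, layers whose shapes fail the kernel constraint are skipped and returned unquantized rather than raising.

    import torch
    from torchao.quantization import quantize_, FPXWeightOnlyConfig

    # ebits=3, mbits=2 selects fp6_e3_m2 (1 sign + 3 exponent + 2 mantissa bits).
    # Shapes chosen so in_dim % 64 == 0 and out_dim % 256 == 0, matching the
    # check in _fpx_weight_only_transform; a CUDA fp16 model is assumed.
    model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()
    quantize_(model, FPXWeightOnlyConfig(ebits=3, mbits=2))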