@@ -1030,30 +1030,43 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
    return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())


-def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
+@dataclass
+class Float8WeightOnlyConfig(AOBaseConfig):
    """
-    Applies float8 weight-only symmetric per-channel quantization to linear layers.
+    Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers.

    Args:
        weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.

    Note:
        The actual matmul will be computed in original precision of the weight tensor.
-
    """
-    from torchao.dtypes import to_affine_quantized_floatx

-    def apply_float8wo_quant(weight):
-        block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=None,
-            _layout=Float8Layout(mm_config=None),
-        )
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+
+
+# for BC
+float8_weight_only = Float8WeightOnlyConfig
+
+
+@register_quantize_module_handler(Float8WeightOnlyConfig)
+def _float8_weight_only_transform(
+    module: torch.nn.Module, config: Float8WeightOnlyConfig
+) -> torch.nn.Module:
+    from torchao.dtypes import to_affine_quantized_floatx

-    return _get_linear_subclass_inserter(apply_float8wo_quant)
+    weight = module.weight
+    block_size = (1, weight.shape[1])
+    new_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=config.weight_dtype,
+        scale_dtype=None,
+        _layout=Float8Layout(mm_config=None),
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


_fp8_granularities = Union[PerTensor, PerRow]
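
For reference, a minimal usage sketch of the new config-based API (not part of this diff; it assumes `quantize_` from `torchao.quantization` dispatches to the handlers registered via `register_quantize_module_handler` as set up above, and a CUDA device that meets the float8 requirements):

```python
import torch
from torchao.quantization import quantize_

# Hypothetical toy model, just for illustration.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# New style: pass a config instance; quantize_ looks up
# _float8_weight_only_transform for Float8WeightOnlyConfig.
quantize_(model, Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn))

# Old style keeps working because float8_weight_only is now an alias for the
# config class, so this call constructs the exact same config object:
# quantize_(model, float8_weight_only())
```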
@@ -1170,16 +1183,10 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool:
    return is_compatible


-def float8_dynamic_activation_float8_weight(
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+@dataclass
+class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
    """
-    Applies float8 dynamic symmetric quantization to both activations and weights of linear layers.
+    Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers.

    Args:
        activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
@@ -1192,104 +1199,149 @@ def float8_dynamic_activation_float8_weight(
        mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.

    """
-    assert (
-        is_sm_at_least_89() or is_MI300()
-    ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)

-    activation_granularity, weight_granularity = _normalize_granularity(granularity)
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None

-    def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        if isinstance(weight_granularity, PerRow):
-            assert (
-                weight.dtype == torch.bfloat16
-            ), "PerRow quantization only works for bfloat16 precision input weight"
+    def __post_init__(self):
+        assert (
+            is_sm_at_least_89() or is_MI300()
+        ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)

-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )

-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
+# for bc
+float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig

-        quantized_weight = to_linear_activation_quantized(
-            quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
-        )
-        return quantized_weight

-    return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)
+@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig)
+def _float8_dynamic_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
+):
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+    weight = module.weight

+    activation_granularity, weight_granularity = _normalize_granularity(granularity)

-def float8_static_activation_float8_weight(
-    scale: torch.Tensor,
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    if isinstance(weight_granularity, PerRow):
+        assert (
+            weight.dtype == torch.bfloat16
+        ), "PerRow quantization only works for bfloat16 precision input weight"
+
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )
+
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }
+
+    quantized_weight = to_linear_activation_quantized(
+        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
+@dataclass
+class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
    """
-    Applies float8 static symmetric quantization to
+    Configuration for applying float8 static symmetric quantization to both activations and weights of linear layers.

    Args:
        scale (torch.Tensor): The scale tensor for activation quantization.
        activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
        weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
        mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
    """
-    assert (
-        is_sm_at_least_89() or is_MI300()
-    ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)

+    scale: torch.Tensor
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None
+
+    def __post_init__(self):
+        assert (
+            is_sm_at_least_89() or is_MI300()
+        ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)
+
+
+# for bc
+float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
+
+
+@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
+def _float8_static_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
+):
+    scale = config.scale
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+
+    weight = module.weight
    activation_granularity, weight_granularity = _normalize_granularity(granularity)
    assert isinstance(
        activation_granularity, PerTensor
    ), "Static quantization only supports PerTensor granularity"

-    def apply_float8_static_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )

-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
-
-        quantized_weight = (
-            to_weight_tensor_with_linear_activation_quantization_metadata(
-                quantized_weight,
-                input_quant_func,
-                scale=scale,
-                zero_point=None,
-                quant_kwargs=input_quant_kwargs,
-            )
-        )
-        return quantized_weight
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }

-    return _get_linear_subclass_inserter(apply_float8_static_activation_quant)
+    quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
+        quantized_weight,
+        input_quant_func,
+        scale=scale,
+        zero_point=None,
+        quant_kwargs=input_quant_kwargs,
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
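
Similarly, a hedged sketch of driving the two float8 activation-plus-weight configs above (not part of this diff; it assumes `quantize_`, `PerRow`, and the config classes are importable from `torchao.quantization`, hardware with SM >= 8.9 or MI300+, and a calibration-derived activation scale, for which a placeholder tensor is used here):

```python
import torch
from torchao.quantization import quantize_, PerRow

# Hypothetical toy models; PerRow weight granularity requires bfloat16 weights,
# per the assert in _float8_dynamic_activation_float8_weight_transform.
dyn_model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
static_model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Dynamic: activation scales are computed on the fly at inference time.
quantize_(dyn_model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# Static: the activation scale is fixed ahead of time and only PerTensor
# granularity is supported. The scale below is a placeholder; a real one
# would come from calibration data.
act_scale = torch.tensor(1.0, device="cuda")
quantize_(static_model, Float8StaticActivationFloat8WeightConfig(scale=act_scale))
```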