Commit e10e222

config migration: float*
Summary: TODO write me

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 26d4a037f251363bb27638078134920d279df1a9
ghstack-comment-id: 2649492752
Pull Request resolved: #1694
1 parent b678cc6 commit e10e222

File tree

3 files changed: +180 −105 lines

  test/dtypes/test_affine_quantized.py (+11 −3)
  test/quantization/test_quant_api.py (+20 −5)
  torchao/quantization/quant_api.py (+149 −97)

test/dtypes/test_affine_quantized.py

+11 −3

@@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
+        def _apply(module, config_or_subclass_inserter):
+            if isinstance(config_or_subclass_inserter, AOBaseConfig):
+                quantize_(module, config_or_subclass_inserter)
+            else:
+                # TODO(#1690): delete this once config migration is done
+                module = config_or_subclass_inserter(module)
+            return module
+
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to("cuda")

         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to(device="cuda")

         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.cuda()

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
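The _apply shim exists because, mid-migration, get_quantization_functions returns a mix of new-style AOBaseConfig instances and old-style subclass-inserter callables. After this commit the two float8 spellings below are equivalent; a minimal usage sketch, assuming the quantize_ API and the names this commit introduces:

import copy

import torch

from torchao.quantization import quantize_
from torchao.quantization.quant_api import Float8WeightOnlyConfig, float8_weight_only

m1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
m2 = copy.deepcopy(m1)

# float8_weight_only is now a backwards-compatibility alias for the
# Float8WeightOnlyConfig class, so both calls build a config object and
# dispatch through the handler registered for that config type.
quantize_(m1, Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn))
quantize_(m2, float8_weight_only())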

test/quantization/test_quant_api.py

+20 −5

@@ -30,6 +30,9 @@
     Quantizer,
     TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
+    float8_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -784,9 +787,21 @@ def test_int4wo_cpu(self, dtype, x_dim):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]

+    # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_int4_weight_only_numerics(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            int4_weight_only(),
+            float8_weight_only(),
+            float8_dynamic_activation_float8_weight(),
+            float8_static_activation_float8_weight(
+                scale=torch.tensor([1.0], device="cuda")
+            ),
+        ],
+    )
+    def test_workflow_e2e_numerics(self, config):
         """
         Simple test of e2e int4_weight_only workflow, comparing numerics
         to a bfloat16 baseline.
@@ -796,16 +811,16 @@ def test_int4_weight_only_numerics(self):
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
         m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
-        m_int4_wo = copy.deepcopy(m_ref)
+        m_q = copy.deepcopy(m_ref)

         # quantize
-        quantize_(m_int4_wo, int4_weight_only())
+        quantize_(m_q, config)

         with torch.no_grad():
             y_ref = m_ref(x)
-            y_int4_wo = m_int4_wo(x)
+            y_q = m_q(x)

-        sqnr = compute_error(y_ref, y_int4_wo)
+        sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 20, f"SQNR {sqnr} is too low"
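The assertion gates on SQNR (signal-to-quantization-noise ratio) rather than exact equality, since every parametrized workflow is lossy. torchao's compute_error reports SQNR in decibels; a minimal sketch of the quantity being asserted on, assuming the standard definition:

import torch

def sqnr_db(ref: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    # 20 * log10(||ref|| / ||ref - quantized||), in dB: higher means the
    # quantized output tracks the bfloat16 baseline more closely. The test
    # requires at least 20 dB for every config.
    return 20 * torch.log10(
        torch.linalg.norm(ref) / torch.linalg.norm(ref - quantized)
    )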

torchao/quantization/quant_api.py

+149 −97

@@ -1030,30 +1030,43 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())


-def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
+@dataclass
+class Float8WeightOnlyConfig(AOBaseConfig):
     """
-    Applies float8 weight-only symmetric per-channel quantization to linear layers.
+    Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers.

     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.

     Note:
         The actual matmul will be computed in original precision of the weight tensor.
-
     """
-    from torchao.dtypes import to_affine_quantized_floatx

-    def apply_float8wo_quant(weight):
-        block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=None,
-            _layout=Float8Layout(mm_config=None),
-        )
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+
+
+# for BC
+float8_weight_only = Float8WeightOnlyConfig
+
+
+@register_quantize_module_handler(Float8WeightOnlyConfig)
+def _float8_weight_only_transform(
+    module: torch.nn.Module, config: Float8WeightOnlyConfig
+) -> torch.nn.Module:
+    from torchao.dtypes import to_affine_quantized_floatx

-    return _get_linear_subclass_inserter(apply_float8wo_quant)
+    weight = module.weight
+    block_size = (1, weight.shape[1])
+    new_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=config.weight_dtype,
+        scale_dtype=None,
+        _layout=Float8Layout(mm_config=None),
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 _fp8_granularities = Union[PerTensor, PerRow]
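This hunk establishes the pattern the rest of the commit repeats for each float8 workflow: the user-facing function becomes a dataclass config, and a transform registered via register_quantize_module_handler performs the module surgery when quantize_ sees that config type. A minimal sketch of such type-keyed dispatch (hypothetical names; torchao's real registry lives behind register_quantize_module_handler and quantize_):

from dataclasses import dataclass
from typing import Callable, Dict, Type

import torch

_HANDLERS: Dict[Type, Callable] = {}  # config class -> module transform

def register_handler(config_cls: Type) -> Callable:
    def decorator(fn: Callable) -> Callable:
        _HANDLERS[config_cls] = fn
        return fn
    return decorator

@dataclass
class DemoConfig:
    weight_dtype: torch.dtype = torch.float8_e4m3fn

@register_handler(DemoConfig)
def _demo_transform(module: torch.nn.Module, config: DemoConfig) -> torch.nn.Module:
    # a real handler swaps module.weight for a quantized tensor, as
    # _float8_weight_only_transform above does
    return module

def apply_config(module: torch.nn.Module, config) -> torch.nn.Module:
    # quantize_-style dispatch: look up the handler by the config's type
    return _HANDLERS[type(config)](module, config)

quantized = apply_config(torch.nn.Linear(8, 8), DemoConfig())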
@@ -1170,16 +1183,10 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool:
     return is_compatible


-def float8_dynamic_activation_float8_weight(
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+@dataclass
+class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     """
-    Applies float8 dynamic symmetric quantization to both activations and weights of linear layers.
+    Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers.

     Args:
         activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
@@ -1192,104 +1199,149 @@ def float8_dynamic_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.

     """
-    assert (
-        is_sm_at_least_89() or is_MI300()
-    ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)

-    activation_granularity, weight_granularity = _normalize_granularity(granularity)
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None

-    def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        if isinstance(weight_granularity, PerRow):
-            assert (
-                weight.dtype == torch.bfloat16
-            ), "PerRow quantization only works for bfloat16 precision input weight"
+    def __post_init__(self):
+        assert (
+            is_sm_at_least_89() or is_MI300()
+        ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)

-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )

-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
+# for bc
+float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig

-        quantized_weight = to_linear_activation_quantized(
-            quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
-        )
-        return quantized_weight

-    return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)
+@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig)
+def _float8_dynamic_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
+):
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+    weight = module.weight

+    activation_granularity, weight_granularity = _normalize_granularity(granularity)

-def float8_static_activation_float8_weight(
-    scale: torch.Tensor,
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    if isinstance(weight_granularity, PerRow):
+        assert (
+            weight.dtype == torch.bfloat16
+        ), "PerRow quantization only works for bfloat16 precision input weight"
+
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )
+
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }
+
+    quantized_weight = to_linear_activation_quantized(
+        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
+@dataclass
+class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     """
-    Applies float8 static symmetric quantization to
+    Configuration for applying float8 static symmetric quantization to

     Args:
         scale (torch.Tensor): The scale tensor for activation quantization.
         activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
-    assert (
-        is_sm_at_least_89() or is_MI300()
-    ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)

+    scale: torch.Tensor
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None
+
+    def __post_init__(self):
+        assert (
+            is_sm_at_least_89() or is_MI300()
+        ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)
+
+
+# for bc
+float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
+
+
+@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
+def _float8_static_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
+):
+    scale = config.scale
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+
+    weight = module.weight
     activation_granularity, weight_granularity = _normalize_granularity(granularity)
     assert isinstance(
         activation_granularity, PerTensor
     ), "Static quantization only supports PerTensor granularity"

-    def apply_float8_static_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )

-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
-
-        quantized_weight = (
-            to_weight_tensor_with_linear_activation_quantization_metadata(
-                quantized_weight,
-                input_quant_func,
-                scale=scale,
-                zero_point=None,
-                quant_kwargs=input_quant_kwargs,
-            )
-        )
-        return quantized_weight
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }

-    return _get_linear_subclass_inserter(apply_float8_static_activation_quant)
+    quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
+        quantized_weight,
+        input_quant_func,
+        scale=scale,
+        zero_point=None,
+        quant_kwargs=input_quant_kwargs,
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
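The static-activation config is the only one of the three with a required argument, the precomputed activation scale. A hedged end-to-end sketch mirroring the parametrized test above (requires CUDA with SM 8.9+ or MI300, per the __post_init__ assertion; the fixed scale stands in for one obtained by calibration):

import torch

from torchao.quantization import quantize_
from torchao.quantization.quant_api import Float8StaticActivationFloat8WeightConfig

m = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()

# per-tensor activation scale; static quantization asserts PerTensor granularity
config = Float8StaticActivationFloat8WeightConfig(
    scale=torch.tensor([1.0], device="cuda")
)
quantize_(m, config)

x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = m(x)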
