
Commit c9b493f

replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs
ghstack-source-id: f655d60
ghstack-comment-id: 2608105249
Pull Request resolved: #1599

1 parent fa20ed1, commit c9b493f
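The call-site change, in brief: float8 weights previously went through the generic floatx constructor; they now use a dedicated float8 entry point, and the scale dtype is no longer passed in. A minimal before/after sketch, mirroring the call-site updates in the diffs below (weight, block_size, target_dtype, and _layout as defined at those sites):

    # before: generic floatx constructor, caller supplies scale_dtype
    from torchao.dtypes import to_affine_quantized_floatx  # removed by this commit
    weight = to_affine_quantized_floatx(
        input_float=weight,
        block_size=block_size,
        target_dtype=target_dtype,  # e.g. torch.float8_e4m3fn
        _layout=_layout,
        scale_dtype=torch.float32,
    )

    # after: dedicated float8 constructor; the scale is computed internally
    from torchao.dtypes import to_affine_quantized_float8
    weight = to_affine_quantized_float8(
        input_float=weight,
        block_size=block_size,
        target_dtype=target_dtype,
        _layout=_layout,
    )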

File tree

7 files changed (+105, -131 lines)


docs/source/api_ref_dtypes.rst (+1, -1)

@@ -13,7 +13,7 @@ torchao.dtypes
     to_nf4
     to_affine_quantized_intx
     to_affine_quantized_intx_static
-    to_affine_quantized_floatx
+    to_affine_quantized_float8
     to_affine_quantized_floatx_static
     to_affine_quantized_fpx
     NF4Tensor

torchao/dtypes/__init__.py (+4, -9)

@@ -1,16 +1,14 @@
 from . import affine_quantized_tensor_ops
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
-    to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future
     to_affine_quantized_fpx,
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from .floatx import (
-    Float8Layout,
-)
+from .float8 import to_affine_quantized_float8
+from .floatx import Float8Layout
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
     BlockSparseLayout,
@@ -24,10 +22,7 @@
     UintxLayout,
     to_marlinqqq_quantized_intx,
 )
-from .utils import (
-    Layout,
-    PlainLayout,
-)
+from .utils import Layout, PlainLayout
 
 __all__ = [
     "NF4Tensor",
@@ -36,8 +31,8 @@
     "to_affine_quantized_intx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_fpx",
-    "to_affine_quantized_floatx",
     "to_affine_quantized_floatx_static",
+    "to_affine_quantized_float8",
     "to_marlinqqq_quantized_intx",
     "Layout",
     "PlainLayout",

torchao/dtypes/affine_quantized_tensor.py (+22, -70)

@@ -10,13 +10,10 @@
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
-    choose_qparams_affine_float8,
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
-    dequantize_affine_floatx,
     quantize_affine,
-    quantize_affine_float8,
     quantize_affine_floatx,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
@@ -28,7 +25,6 @@
     "AffineQuantizedTensor",
     "register_layout",
     "to_affine_quantized_intx",
-    "to_affine_quantized_floatx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_floatx_static",
     "to_affine_quantized_fpx",
@@ -121,40 +117,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         if output_dtype is None:
             output_dtype = self.dtype
 
-        from torchao.dtypes.floatx import FloatxTensorCoreLayout
-
-        if isinstance(self._layout, FloatxTensorCoreLayout):
-            int_data, scale = self.tensor_impl.get_plain()
-            return dequantize_affine_floatx(
-                int_data,
-                scale,
-                self._layout.ebits,
-                self._layout.mbits,
-                output_dtype=output_dtype,
-            )
-        else:
-            data, scale, zero_point = self.tensor_impl.get_plain()
-            dq = dequantize_affine(
-                data,
-                self.block_size,
-                scale,
-                zero_point,
-                data.dtype,
-                self.quant_min,
-                self.quant_max,
-                self.zero_point_domain,
-                output_dtype=output_dtype,
-            )
-            from torchao.dtypes.uintx import TensorCoreTiledLayout
+        data, scale, zero_point = self.tensor_impl.get_plain()
+        dq = dequantize_affine(
+            data,
+            self.block_size,
+            scale,
+            zero_point,
+            data.dtype,
+            self.quant_min,
+            self.quant_max,
+            self.zero_point_domain,
+            output_dtype=output_dtype,
+        )
+        from torchao.dtypes.uintx import TensorCoreTiledLayout
 
-            if isinstance(self._layout, TensorCoreTiledLayout):
-                # need to return to original shape if tensor was padded
-                # in preprocessing
-                # TODO: we could add an API for this if there are more use cases
-                # (e.g. dequant_post_process) in TensorImpl or Layout
-                for dim, dim_size in enumerate(self.shape):
-                    dq = dq.narrow(dim, 0, dim_size)
-            return dq
+        if isinstance(self._layout, TensorCoreTiledLayout):
+            # need to return to original shape if tensor was padded
+            # in preprocessing
+            # TODO: we could add an API for this if there are more use cases
+            # (e.g. dequant_post_process) in TensorImpl or Layout
+            for dim, dim_size in enumerate(self.shape):
+                dq = dq.narrow(dim, 0, dim_size)
+        return dq
 
     def __tensor_flatten__(self):
         return ["tensor_impl"], [
@@ -272,7 +256,7 @@ def from_hp_to_intx(
         # Note: output will be uint8 tensor for sub byte tensors for now
 
         data = _layout.post_process(data)
-        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
         tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
         return cls(
             tensor_impl,
@@ -417,36 +401,6 @@ def from_hp_to_fpx(
         tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
         return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
 
-    @classmethod
-    def from_hp_to_float8(
-        cls,
-        input_float: torch.Tensor,
-        target_dtype: torch.dtype,
-        block_size: Tuple[int, ...],
-        _layout: Layout = PlainLayout(),
-    ):
-        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
-        original_shape = input_float.shape
-        scale = choose_qparams_affine_float8(
-            input_float,
-            target_dtype,
-            target_dtype,
-        )
-        fp8_data = quantize_affine_float8(
-            input_float,
-            scale,
-            target_dtype,
-        )
-        fp8_data = _layout.post_process(fp8_data)
-        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
-        tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
-        return cls(
-            tensor_impl,
-            block_size,
-            original_shape,
-            dtype=input_float.dtype,
-        )
-
     @property
     def _layout(self) -> Layout:
         return self.tensor_impl._layout
@@ -500,9 +454,7 @@ def _apply_fn_to_data(self, fn):
 
 to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
-to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
-to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
 # experimental will be merged in to floatx
 to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
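One subtle fix above: from_hp_to_intx now looks up the tensor-impl constructor via cls.get_tensor_impl_constructor rather than the module-level get_tensor_impl_constructor, so the lookup resolves against whichever class the classmethod is invoked on, and subclasses (such as the new Float8Tensor below) inherit the construction machinery. A simplified sketch of that pattern, with hypothetical names (the real registry maps Layout types to TensorImpl constructors):

    # Simplified, hypothetical illustration of classmethod-based registry lookup.
    class Base:
        _registry = {"plain": lambda data: f"plain impl of {data}"}

        @classmethod
        def get_tensor_impl_constructor(cls, layout_type: str):
            # resolves against cls, so subclasses can extend the registry
            return cls._registry[layout_type]

    class Float8Like(Base):
        _registry = {**Base._registry, "float8": lambda data: f"float8 impl of {data}"}

    print(Float8Like.get_tensor_impl_constructor("float8")("w"))  # float8 impl of w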

torchao/dtypes/floatx/float8_layout.py (+54)

@@ -18,6 +18,12 @@
     addmm_float8_unwrapped_inference,
     preprocess_data,
 )
+from torchao.quantization.quant_primitives import (
+    FP8_TYPES,
+    choose_qparams_affine_float8,
+    dequantize_affine_float8,
+    quantize_affine_float8,
+)
 from torchao.utils import _is_float8_type, fill_defaults
 
 aten = torch.ops.aten
@@ -209,6 +215,51 @@ def __repr__(self):
     )
 
 
+class Float8Tensor(AffineQuantizedTensor):
+    """
+    Float8 quantized tensor subclass which inherits AffineQuantizedTensor class.
+    """
+
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
+        int_data, scale = self.tensor_impl.get_plain()
+        return dequantize_affine_float8(
+            int_data,
+            scale,
+            output_dtype=output_dtype,
+        )
+
+    @classmethod
+    def from_hp_to_float8(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        block_size: Tuple[int, ...],
+        _layout: Layout = Float8Layout(),
+    ):
+        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
+        original_shape = input_float.shape
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+        )
+        fp8_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+        fp8_data = _layout.post_process(fp8_data)
+        tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            dtype=input_float.dtype,
+        )
+
+
 ##########################
 # Float8 Dispatch Kernels
 ##########################
@@ -311,3 +362,6 @@ def _linear_fp_act_fp8_weight_impl(
     bias: Optional[torch.Tensor],
 ):
     return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
+
+
+to_affine_quantized_float8 = Float8Tensor.from_hp_to_float8
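End to end, the new entry point quantizes a high-precision tensor to float8 and dequantizes through dequantize_affine_float8. A minimal sketch, assuming a PyTorch build with float8 dtypes (>= 2.1) and per-tensor scaling via a block size spanning the whole tensor:

    import torch
    from torchao.dtypes import to_affine_quantized_float8

    w = torch.randn(128, 256, dtype=torch.bfloat16)
    w_fp8 = to_affine_quantized_float8(
        input_float=w,
        target_dtype=torch.float8_e4m3fn,  # must be one of FP8_TYPES
        block_size=w.shape,                # one scale for the whole tensor
    )
    # Float8Tensor.dequantize routes through dequantize_affine_float8
    w_back = w_fp8.dequantize(output_dtype=torch.bfloat16)
    print((w - w_back).abs().max())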

torchao/prototype/quantization/autoquant_v2.py (+6, -14)

@@ -27,14 +27,8 @@
 from torchao.quantization.autoquant import (
     AutoQuantizableLinearWeight as AutoQuantizableLinearWeightV1,
 )
-from torchao.quantization.granularity import (
-    PerRow,
-    PerTensor,
-)
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from torchao.quantization.granularity import PerRow, PerTensor
+from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.quantization.subclass import (  # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -991,7 +985,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -1015,12 +1009,11 @@ def get_per_token_block_size(x):
             activation_dtype=input_target_dtype,
         )
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls
@@ -1040,7 +1033,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -1058,12 +1051,11 @@ def get_weight_block_size(x):
             activation_dtype=input_target_dtype,
         )
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls

torchao/quantization/autoquant.py (+6, -14)

@@ -18,10 +18,7 @@
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.quantization.utils import (
     compute_error,
     quantize_activation_per_token_absmax,
@@ -34,10 +31,7 @@
     is_sm_at_least_90,
 )
 
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .subclass import (  # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -969,7 +963,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, BFloat16Ten
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -995,12 +989,11 @@ def get_per_token_block_size(x):
         }
         block_size = get_weight_block_size(weight)
 
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = to_linear_activation_quantized(
             weight, input_quant_func, quant_kwargs=input_quant_kwargs
@@ -1025,7 +1018,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -1043,12 +1036,11 @@ def get_weight_block_size(x):
             "activation_dtype": input_target_dtype,
         }
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls
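Dropping scale_dtype=torch.float32 at these call sites is safe because the new constructor computes the scale itself via choose_qparams_affine_float8. A rough, hypothetical simplification of what the float8 primitives compute for a per-tensor scale (not torchao's actual kernels):

    import torch

    def fake_quant_float8(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
        # scale chosen from the tensor's absolute maximum, kept in float32
        finfo = torch.finfo(float8_dtype)
        scale = x.abs().amax().float() / finfo.max
        xq = (x.float() / scale).clamp(finfo.min, finfo.max).to(float8_dtype)
        return xq, scale

    xq, scale = fake_quant_float8(torch.randn(16, 16))
    x_back = xq.float() * scale  # roughly what dequantize_affine_float8 amounts to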
