
Commit ae88acf

replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs
ghstack-source-id: 61cc8c2
ghstack-comment-id: 2608105249
Pull Request resolved: #1599
1 parent ce39c63 commit ae88acf
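
As a quick migration note: callers that previously used to_affine_quantized_floatx for float8 should now call to_affine_quantized_float8, which takes the target float8 dtype and a block size directly (the scale is derived internally via choose_qparams_affine_float8). A minimal sketch, assuming a PyTorch build with float8 dtype support; the shapes and block size below are illustrative, not from this commit:

import torch
from torchao.dtypes import to_affine_quantized_float8

# high-precision weight to quantize (illustrative shape)
weight = torch.randn(128, 256, dtype=torch.bfloat16)

# block_size=(1, N) gives one scale per row (rowwise scaling)
fp8_weight = to_affine_quantized_float8(
    weight,
    target_dtype=torch.float8_e4m3fn,
    block_size=(1, weight.shape[1]),
)

# dequantize() round-trips to the original high-precision dtype
assert fp8_weight.dequantize().dtype == torch.bfloat16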

File tree: 9 files changed, +133 −149 lines

docs/source/api_ref_dtypes.rst (+1 −1)

@@ -13,7 +13,7 @@ torchao.dtypes
     to_nf4
     to_affine_quantized_intx
     to_affine_quantized_intx_static
-    to_affine_quantized_floatx
+    to_affine_quantized_float8
     to_affine_quantized_floatx_static
     to_affine_quantized_fpx
     NF4Tensor

torchao/dtypes/__init__.py (+4 −9)

@@ -1,16 +1,13 @@
 from . import affine_quantized_tensor_ops
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
-    to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future
     to_affine_quantized_fpx,
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from .floatx import (
-    Float8Layout,
-)
+from .floatx import Float8Layout, Float8QuantizedTensor, to_affine_quantized_float8
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
     BlockSparseLayout,
@@ -24,20 +21,18 @@
     UintxLayout,
     to_marlinqqq_quantized_intx,
 )
-from .utils import (
-    Layout,
-    PlainLayout,
-)
+from .utils import Layout, PlainLayout
 
 __all__ = [
     "NF4Tensor",
     "to_nf4",
     "AffineQuantizedTensor",
+    "Float8QuantizedTensor",
     "to_affine_quantized_intx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_fpx",
-    "to_affine_quantized_floatx",
     "to_affine_quantized_floatx_static",
+    "to_affine_quantized_float8",
     "to_marlinqqq_quantized_intx",
     "Layout",
     "PlainLayout",

torchao/dtypes/affine_quantized_tensor.py (+22 −70)

@@ -10,13 +10,10 @@
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
-    choose_qparams_affine_float8,
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
-    dequantize_affine_floatx,
     quantize_affine,
-    quantize_affine_float8,
     quantize_affine_floatx,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
@@ -28,7 +25,6 @@
     "AffineQuantizedTensor",
     "register_layout",
     "to_affine_quantized_intx",
-    "to_affine_quantized_floatx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_floatx_static",
     "to_affine_quantized_fpx",
@@ -121,40 +117,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         if output_dtype is None:
             output_dtype = self.dtype
 
-        from torchao.dtypes.floatx import FloatxTensorCoreLayout
-
-        if isinstance(self._layout, FloatxTensorCoreLayout):
-            int_data, scale = self.tensor_impl.get_plain()
-            return dequantize_affine_floatx(
-                int_data,
-                scale,
-                self._layout.ebits,
-                self._layout.mbits,
-                output_dtype=output_dtype,
-            )
-        else:
-            data, scale, zero_point = self.tensor_impl.get_plain()
-            dq = dequantize_affine(
-                data,
-                self.block_size,
-                scale,
-                zero_point,
-                data.dtype,
-                self.quant_min,
-                self.quant_max,
-                self.zero_point_domain,
-                output_dtype=output_dtype,
-            )
-            from torchao.dtypes.uintx import TensorCoreTiledLayout
+        data, scale, zero_point = self.tensor_impl.get_plain()
+        dq = dequantize_affine(
+            data,
+            self.block_size,
+            scale,
+            zero_point,
+            data.dtype,
+            self.quant_min,
+            self.quant_max,
+            self.zero_point_domain,
+            output_dtype=output_dtype,
+        )
+        from torchao.dtypes.uintx import TensorCoreTiledLayout
 
-            if isinstance(self._layout, TensorCoreTiledLayout):
-                # need to return to original shape if tensor was padded
-                # in preprocessing
-                # TODO: we could add an API for this if there are more use cases
-                # (e.g. dequant_post_process) in TensorImpl or Layout
-                for dim, dim_size in enumerate(self.shape):
-                    dq = dq.narrow(dim, 0, dim_size)
-            return dq
+        if isinstance(self._layout, TensorCoreTiledLayout):
+            # need to return to original shape if tensor was padded
+            # in preprocessing
+            # TODO: we could add an API for this if there are more use cases
+            # (e.g. dequant_post_process) in TensorImpl or Layout
+            for dim, dim_size in enumerate(self.shape):
+                dq = dq.narrow(dim, 0, dim_size)
+        return dq
 
     def __tensor_flatten__(self):
         return ["tensor_impl"], [
@@ -272,7 +256,7 @@ def from_hp_to_intx(
         # Note: output will be uint8 tensor for sub byte tensors for now
 
         data = _layout.post_process(data)
-        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
         tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
         return cls(
             tensor_impl,
@@ -417,36 +401,6 @@ def from_hp_to_fpx(
         tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
         return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
 
-    @classmethod
-    def from_hp_to_float8(
-        cls,
-        input_float: torch.Tensor,
-        target_dtype: torch.dtype,
-        block_size: Tuple[int, ...],
-        _layout: Layout = PlainLayout(),
-    ):
-        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
-        original_shape = input_float.shape
-        scale = choose_qparams_affine_float8(
-            input_float,
-            target_dtype,
-            target_dtype,
-        )
-        fp8_data = quantize_affine_float8(
-            input_float,
-            scale,
-            target_dtype,
-        )
-        fp8_data = _layout.post_process(fp8_data)
-        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
-        tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
-        return cls(
-            tensor_impl,
-            block_size,
-            original_shape,
-            dtype=input_float.dtype,
-        )
-
     @property
     def _layout(self) -> Layout:
         return self.tensor_impl._layout
@@ -500,9 +454,7 @@ def _apply_fn_to_data(self, fn):
 
 to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
-to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
-to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
 # experimental will be merged in to floatx
 to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
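
One change above is easy to miss: from_hp_to_intx now resolves the tensor impl constructor through cls.get_tensor_impl_constructor rather than the module-level get_tensor_impl_constructor, so the lookup respects the class it is called on and subclasses (such as the new Float8QuantizedTensor) can reuse it. A minimal sketch of that pattern with hypothetical names (Base, Sub, and the registry here are illustrative, not torchao's actual registry):

# hypothetical sketch: a constructor registry resolved via cls, so
# subclasses inherit registered constructors without re-registering
class Base:
    _impl_registry: dict = {}

    @classmethod
    def register(cls, layout_cls, ctor):
        cls._impl_registry[layout_cls] = ctor

    @classmethod
    def get_tensor_impl_constructor(cls, layout_cls):
        return cls._impl_registry[layout_cls]


class Sub(Base):
    pass  # no registry of its own; cls lookup falls back to Base's


class PlainLayout:
    pass


Base.register(PlainLayout, lambda data, scale: (data, scale))
print(Sub.get_tensor_impl_constructor(PlainLayout)("data", "scale"))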

torchao/dtypes/affine_quantized_tensor_ops.py (+2 −6)

@@ -3,9 +3,7 @@
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
-from torchao.dtypes.affine_quantized_tensor import (
-    AffineQuantizedTensor,
-)
+from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
 from torchao.dtypes.floatx.float8_layout import (
     _linear_fp8_act_fp8_weight_check,
     _linear_fp8_act_fp8_weight_impl,
@@ -52,9 +50,7 @@
     _linear_bf16_act_uint4_weight_impl,
 )
 from torchao.quantization.quant_primitives import dequantize_affine
-from torchao.utils import (
-    fill_defaults,
-)
+from torchao.utils import fill_defaults
 
 logger = logging.getLogger(__name__)

torchao/dtypes/floatx/__init__.py (+16 −2)

@@ -1,4 +1,12 @@
-from .float8_layout import Float8Layout
+from .float8_layout import (
+    Float8Layout,
+    Float8QuantizedTensor,
+    _linear_fp8_act_fp8_weight_check,
+    _linear_fp8_act_fp8_weight_impl,
+    _linear_fp_act_fp8_weight_check,
+    _linear_fp_act_fp8_weight_impl,
+    to_affine_quantized_float8,
+)
 from .floatx_tensor_core_layout import (
     FloatxTensorCoreLayout,
     from_scaled_tc_floatx,
@@ -7,7 +15,13 @@
 
 __all__ = [
     "FloatxTensorCoreLayout",
+    "Float8Layout",
+    "Float8QuantizedTensor",
     "to_scaled_tc_floatx",
     "from_scaled_tc_floatx",
-    "Float8Layout",
+    "to_affine_quantized_float8",
+    "_linear_fp8_act_fp8_weight_check",
+    "_linear_fp8_act_fp8_weight_impl",
+    "_linear_fp_act_fp8_weight_check",
+    "_linear_fp_act_fp8_weight_impl",
 ]

torchao/dtypes/floatx/float8_layout.py (+64 −10)

@@ -18,6 +18,12 @@
     addmm_float8_unwrapped_inference,
     preprocess_data,
 )
+from torchao.quantization.quant_primitives import (
+    FP8_TYPES,
+    choose_qparams_affine_float8,
+    dequantize_affine_float8,
+    quantize_affine_float8,
+)
 from torchao.utils import _is_float8_type, fill_defaults
 
 aten = torch.ops.aten
@@ -209,19 +215,64 @@ def __repr__(self):
     )
 
 
+class Float8QuantizedTensor(AffineQuantizedTensor):
+    """
+    Float8 quantized tensor subclass which inherits from AffineQuantizedTensor.
+    """
+
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
+        int_data, scale, _ = self.tensor_impl.get_plain()
+        return dequantize_affine_float8(
+            int_data,
+            scale,
+            output_dtype=output_dtype,
+        )
+
+    @classmethod
+    def from_hp_to_float8(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        block_size: Tuple[int, ...],
+        _layout: Layout = Float8Layout(),
+    ):
+        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
+        original_shape = input_float.shape
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+        )
+        fp8_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+        fp8_data = _layout.post_process(fp8_data)
+        tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            dtype=input_float.dtype,
+        )
+
+
 ##########################
 # Float8 Dispatch Kernels
 ##########################
 
 
 def _linear_fp8_act_fp8_weight_check(
-    input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
-    weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
+    input_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
+    weight_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
     bias: Optional[torch.Tensor],
 ) -> bool:
-    def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
+    def check_aqt(aqt: Union[torch.Tensor, Float8QuantizedTensor]) -> bool:
         return (
-            isinstance(aqt, AffineQuantizedTensor)
+            isinstance(aqt, Float8QuantizedTensor)
             and isinstance(aqt._layout, Float8Layout)
             and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
            and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
@@ -241,8 +292,8 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
 
 
 def _linear_fp8_act_fp8_weight_impl(
-    input_tensor: "AffineQuantizedTensor",
-    weight_tensor: "AffineQuantizedTensor",
+    input_tensor: "Float8QuantizedTensor",
+    weight_tensor: "Float8QuantizedTensor",
     bias: Optional[torch.Tensor],
 ):
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
@@ -285,8 +336,8 @@ def _linear_fp8_act_fp8_weight_impl(
 
 
 def _linear_fp_act_fp8_weight_check(
-    input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
-    weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
+    input_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
+    weight_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
     bias: Optional[torch.Tensor],
 ) -> bool:
     return (
@@ -295,7 +346,7 @@ def _linear_fp_act_fp8_weight_check(
         and input_tensor.is_floating_point()
         and
         # weight is float8 quantized affine quantized tensor
-        isinstance(weight_tensor, AffineQuantizedTensor)
+        isinstance(weight_tensor, Float8QuantizedTensor)
         and isinstance(weight_tensor._layout, Float8Layout)
         and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
         and (
@@ -307,7 +358,10 @@ def _linear_fp_act_fp8_weight_check(
 
 def _linear_fp_act_fp8_weight_impl(
     input_tensor: torch.Tensor,
-    weight_tensor: "AffineQuantizedTensor",
+    weight_tensor: "Float8QuantizedTensor",
     bias: Optional[torch.Tensor],
 ):
     return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
+
+
+to_affine_quantized_float8 = Float8QuantizedTensor.from_hp_to_float8
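
Putting the pieces together, a short sketch of how the new subclass behaves end to end (assumes the exports above and float8 dtype support in PyTorch; values are illustrative):

import torch
from torchao.dtypes import Float8QuantizedTensor, to_affine_quantized_float8

x = torch.randn(64, 64)
# block_size == x.shape -> a single per-tensor scale
qx = to_affine_quantized_float8(x, torch.float8_e4m3fn, block_size=tuple(x.shape))

# the alias constructs the subclass, so the float8-specific dequantize
# (dequantize_affine_float8) and the Float8QuantizedTensor isinstance
# checks in the fp8 linear dispatch both apply
assert isinstance(qx, Float8QuantizedTensor)
assert qx.dequantize().dtype == torch.float32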
