
Commit 9ecdb3b

Move fpx to tensor subclass
1 parent 5d1444b commit 9ecdb3b

4 files changed: +80 -67 lines

torchao/dtypes/__init__.py (+1 -1)

@@ -4,12 +4,12 @@
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future
-    to_affine_quantized_fpx,
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
 from .floatx import (
     Float8Layout,
+    to_affine_quantized_fpx,
 )
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
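For downstream code the public constructor keeps its top-level name; it is now simply re-exported from the floatx submodule. A minimal sanity check of the two import paths, assuming a build that includes this commit:

from torchao.dtypes import to_affine_quantized_fpx as fpx_top
from torchao.dtypes.floatx import to_affine_quantized_fpx as fpx_sub

# Both names should resolve to the same constructor, since the top-level
# torchao/dtypes/__init__.py now re-exports it from .floatx.
assert fpx_top is fpx_sub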

torchao/dtypes/affine_quantized_tensor.py (+21 -66)

@@ -14,12 +14,9 @@
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
-    choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
-    dequantize_affine_floatx,
     quantize_affine,
-    quantize_affine_floatx,
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
@@ -36,7 +33,6 @@
     "to_affine_quantized_floatx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_floatx_static",
-    "to_affine_quantized_fpx",
 ]


@@ -126,40 +122,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         if output_dtype is None:
             output_dtype = self.dtype

-        from torchao.dtypes.floatx import FloatxTensorCoreLayout
-
-        if isinstance(self._layout, FloatxTensorCoreLayout):
-            int_data, scale = self.tensor_impl.get_plain()
-            return dequantize_affine_floatx(
-                int_data,
-                scale,
-                self._layout.ebits,
-                self._layout.mbits,
-                output_dtype=output_dtype,
-            )
-        else:
-            data, scale, zero_point = self.tensor_impl.get_plain()
-            dq = dequantize_affine(
-                data,
-                self.block_size,
-                scale,
-                zero_point,
-                data.dtype,
-                self.quant_min,
-                self.quant_max,
-                self.zero_point_domain,
-                output_dtype=output_dtype,
-            )
-            from torchao.dtypes.uintx import TensorCoreTiledLayout
+        data, scale, zero_point = self.tensor_impl.get_plain()
+        dq = dequantize_affine(
+            data,
+            self.block_size,
+            scale,
+            zero_point,
+            data.dtype,
+            self.quant_min,
+            self.quant_max,
+            self.zero_point_domain,
+            output_dtype=output_dtype,
+        )
+        from torchao.dtypes.uintx import TensorCoreTiledLayout

-            if isinstance(self._layout, TensorCoreTiledLayout):
-                # need to return to original shape if tensor was padded
-                # in preprocessing
-                # TODO: we could add an API for this if there are more use cases
-                # (e.g. dequant_post_process) in TensorImpl or Layout
-                for dim, dim_size in enumerate(self.shape):
-                    dq = dq.narrow(dim, 0, dim_size)
-            return dq
+        if isinstance(self._layout, TensorCoreTiledLayout):
+            # need to return to original shape if tensor was padded
+            # in preprocessing
+            # TODO: we could add an API for this if there are more use cases
+            # (e.g. dequant_post_process) in TensorImpl or Layout
+            for dim, dim_size in enumerate(self.shape):
+                dq = dq.narrow(dim, 0, dim_size)
+        return dq

     def __tensor_flatten__(self):
         return ["tensor_impl"], [
@@ -395,33 +379,6 @@ def from_hp_to_floatx_static(
                 f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
             )

-    @classmethod
-    def from_hp_to_fpx(
-        cls,
-        input_float: torch.Tensor,
-        _layout: Layout,
-    ):
-        from torchao.dtypes.floatx import FloatxTensorCoreLayout
-
-        assert isinstance(
-            _layout, FloatxTensorCoreLayout
-        ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}"
-        original_shape = input_float.shape
-        input_float = _layout.pre_process(input_float)
-        # per axis quantization, where axis = 1
-        block_size = list(input_float.shape)
-        block_size[1] = 1
-
-        ebits, mbits = _layout.ebits, _layout.mbits
-        # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
-        scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
-        floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
-        floatx_packed = _layout.post_process(floatx_unpacked)
-
-        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
-        tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
-        return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
-
     @property
     def _layout(self) -> Layout:
         return self.tensor_impl._layout
@@ -477,8 +434,6 @@ def _apply_fn_to_data(self, fn):
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
 to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
-# experimental will be merged in to floatx
-to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx

 if TORCH_VERSION_AT_LEAST_2_5:
     # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
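The design change in this file is that the base AffineQuantizedTensor no longer inspects the layout type inside dequantize; the floatx behavior is selected by ordinary method overriding on the new subclass instead. A minimal sketch of that pattern, with simplified names that are not the actual torchao classes:

class QuantizedBase:
    def dequantize(self):
        # generic affine path; no isinstance(self._layout, ...) branching needed
        return "generic affine dequant"

class FloatxQuantized(QuantizedBase):
    def dequantize(self):
        # the floatx-specific path lives on the subclass, mirroring
        # FloatxTensor.dequantize in floatx_tensor_core_layout.py
        return "floatx dequant"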

torchao/dtypes/floatx/__init__.py (+2)

@@ -2,6 +2,7 @@
 from .floatx_tensor_core_layout import (
     FloatxTensorCoreLayout,
     from_scaled_tc_floatx,
+    to_affine_quantized_fpx,
     to_scaled_tc_floatx,
 )

@@ -10,4 +11,5 @@
     "to_scaled_tc_floatx",
     "from_scaled_tc_floatx",
     "Float8Layout",
+    "to_affine_quantized_fpx",
 ]

torchao/dtypes/floatx/floatx_tensor_core_layout.py (+56)

@@ -11,6 +11,7 @@

 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
+    get_tensor_impl_constructor,
     register_layout,
 )
 from torchao.dtypes.utils import (
@@ -22,6 +23,11 @@
     _floatx_unpacked_to_f32,
     _n_ones,
 )
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine_floatx,
+    dequantize_affine_floatx,
+    quantize_affine_floatx,
+)

 aten = torch.ops.aten
 _ONES_TABLE = [_n_ones(i) for i in range(8)]
@@ -456,6 +462,53 @@ class FloatxTensorCoreLayout(Layout):
     mbits: int


+class FloatxTensor(AffineQuantizedTensor):
+    """
+    Floatx quantized tensor subclass which inherits AffineQuantizedTensor class.
+
+    To see what happens during choose_qparams_and_quantize_affine_fpx, quantization and dequantization for floatx quantization,
+    please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
+    and check the two quant primitive ops: choose_qparams_affine_floatx, quantize_affine_floatx and dequantize_affine_floatx.
+    """
+
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
+        int_data, scale = self.tensor_impl.get_plain()
+        return dequantize_affine_floatx(
+            int_data,
+            scale,
+            self._layout.ebits,
+            self._layout.mbits,
+            output_dtype=output_dtype,
+        )
+
+    @classmethod
+    def from_hp_to_floatx(
+        cls,
+        input_float: torch.Tensor,
+        _layout: Layout,
+    ):
+        assert isinstance(
+            _layout, FloatxTensorCoreLayout
+        ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}"
+        original_shape = input_float.shape
+        input_float = _layout.pre_process(input_float)
+        # per axis quantization, where axis = 1
+        block_size = list(input_float.shape)
+        block_size[1] = 1
+
+        ebits, mbits = _layout.ebits, _layout.mbits
+        # Note: these ops are hardcoded to have per axis quantization (axis=1) right now
+        scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
+        floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
+        floatx_packed = _layout.post_process(floatx_unpacked)
+
+        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
+        return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
+
+
 @register_layout(FloatxTensorCoreLayout)
 class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl):
     """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b),
@@ -657,3 +710,6 @@ def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias):
         out += bias

     return out.view(*act.shape[:-1], out_dim).to(act.dtype)
+
+
+to_affine_quantized_fpx = FloatxTensor.from_hp_to_floatx
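With this move, to_affine_quantized_fpx is an alias for FloatxTensor.from_hp_to_floatx. A rough usage sketch, assuming a CUDA build where the floatx tensor-core kernels are available and using fp6 (ebits=3, mbits=2) as the example format:

import torch
from torchao.dtypes.floatx import FloatxTensorCoreLayout, to_affine_quantized_fpx

# fp6_e3m2: 1 sign bit, 3 exponent bits, 2 mantissa bits
weight = torch.randn(256, 512, dtype=torch.half, device="cuda")
fpx_weight = to_affine_quantized_fpx(weight, FloatxTensorCoreLayout(ebits=3, mbits=2))

# FloatxTensor.dequantize unpacks the packed data with dequantize_affine_floatx,
# using the per-axis (axis=1) scale chosen during quantization
restored = fpx_weight.dequantize(output_dtype=torch.half)
print(restored.shape)  # torch.Size([256, 512])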
